spark.shuffle.file.buffer.kb
- 100
+ 32
Size of the in-memory buffer for each shuffle file output stream, in kilobytes. These buffers
reduce the number of disk seeks and system calls made in creating intermediate shuffle files.
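
The smaller 32 KB default trades a little extra I/O overhead for lower memory use per open shuffle output stream. The value stays configurable per application; a minimal sketch using the standard SparkConf API (the application name and the 100 KB value are illustrative):

import org.apache.spark.{SparkConf, SparkContext}

// Sketch: raise the shuffle file buffer back to 100 KB for a shuffle-heavy job.
// Larger buffers mean fewer disk seeks and system calls, at the cost of more
// memory held per open shuffle file output stream.
val conf = new SparkConf()
  .setAppName("shuffle-buffer-example")
  .set("spark.shuffle.file.buffer.kb", "100")
val sc = new SparkContext(conf)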
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index 1d5d3762ed8e9..fd0b9556c7d54 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -271,6 +271,7 @@ class PythonMLLibAPI extends Serializable {
.setNumIterations(numIterations)
.setRegParam(regParam)
.setStepSize(stepSize)
+ .setMiniBatchFraction(miniBatchFraction)
if (regType == "l2") {
lrAlg.optimizer.setUpdater(new SquaredL2Updater)
} else if (regType == "l1") {
@@ -341,16 +342,27 @@ class PythonMLLibAPI extends Serializable {
stepSize: Double,
regParam: Double,
miniBatchFraction: Double,
- initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = {
+ initialWeightsBA: Array[Byte],
+ regType: String,
+ intercept: Boolean): java.util.List[java.lang.Object] = {
+ val SVMAlg = new SVMWithSGD()
+ SVMAlg.setIntercept(intercept)
+ SVMAlg.optimizer
+ .setNumIterations(numIterations)
+ .setRegParam(regParam)
+ .setStepSize(stepSize)
+ .setMiniBatchFraction(miniBatchFraction)
+ if (regType == "l2") {
+ SVMAlg.optimizer.setUpdater(new SquaredL2Updater)
+ } else if (regType == "l1") {
+ SVMAlg.optimizer.setUpdater(new L1Updater)
+ } else if (regType != "none") {
+ throw new java.lang.IllegalArgumentException("Invalid value for 'regType' parameter."
+ + " Can only be initialized using the following string values: [l1, l2, none].")
+ }
trainRegressionModel(
(data, initialWeights) =>
- SVMWithSGD.train(
- data,
- numIterations,
- stepSize,
- regParam,
- miniBatchFraction,
- initialWeights),
+ SVMAlg.run(data, initialWeights),
dataBytesJRDD,
initialWeightsBA)
}
@@ -363,15 +375,28 @@ class PythonMLLibAPI extends Serializable {
numIterations: Int,
stepSize: Double,
miniBatchFraction: Double,
- initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = {
+ initialWeightsBA: Array[Byte],
+ regParam: Double,
+ regType: String,
+ intercept: Boolean): java.util.List[java.lang.Object] = {
+ val LogRegAlg = new LogisticRegressionWithSGD()
+ LogRegAlg.setIntercept(intercept)
+ LogRegAlg.optimizer
+ .setNumIterations(numIterations)
+ .setRegParam(regParam)
+ .setStepSize(stepSize)
+ .setMiniBatchFraction(miniBatchFraction)
+ if (regType == "l2") {
+ LogRegAlg.optimizer.setUpdater(new SquaredL2Updater)
+ } else if (regType == "l1") {
+ LogRegAlg.optimizer.setUpdater(new L1Updater)
+ } else if (regType != "none") {
+ throw new java.lang.IllegalArgumentException("Invalid value for 'regType' parameter."
+ + " Can only be initialized using the following string values: [l1, l2, none].")
+ }
trainRegressionModel(
(data, initialWeights) =>
- LogisticRegressionWithSGD.train(
- data,
- numIterations,
- stepSize,
- miniBatchFraction,
- initialWeights),
+ LogRegAlg.run(data, initialWeights),
dataBytesJRDD,
initialWeightsBA)
}
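
The two Python entry points above now configure the learners the same way the Scala API does. A minimal Scala sketch of the equivalent setup for SVMWithSGD (the training RDD `points` is assumed to exist; the numeric values mirror the Python defaults):

import org.apache.spark.mllib.classification.{SVMModel, SVMWithSGD}
import org.apache.spark.mllib.optimization.{L1Updater, SquaredL2Updater}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD

// Sketch: pick the updater from a regType-style string, as
// trainSVMModelWithSGD now does on the JVM side.
def trainSvm(points: RDD[LabeledPoint], regType: String, intercept: Boolean): SVMModel = {
  val alg = new SVMWithSGD()
  alg.setIntercept(intercept)
  alg.optimizer
    .setNumIterations(100)
    .setRegParam(1.0)
    .setStepSize(1.0)
    .setMiniBatchFraction(1.0)
  regType match {
    case "l2"   => alg.optimizer.setUpdater(new SquaredL2Updater)
    case "l1"   => alg.optimizer.setUpdater(new L1Updater)
    case "none" => // keep the default (unregularized) updater
    case other  => throw new IllegalArgumentException(s"Invalid regType: $other")
  }
  alg.run(points)
}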
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
index 87c81e7b0bd2f..3bf44ad7c44e3 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
@@ -19,16 +19,17 @@ package org.apache.spark.mllib.feature
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
-import scala.util.Random
import com.github.fommil.netlib.BLAS.{getInstance => blas}
-import org.apache.spark.{HashPartitioner, Logging}
+
+import org.apache.spark.Logging
import org.apache.spark.SparkContext._
import org.apache.spark.annotation.Experimental
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.rdd.RDDFunctions._
import org.apache.spark.rdd._
-import org.apache.spark.storage.StorageLevel
+import org.apache.spark.util.Utils
+import org.apache.spark.util.random.XORShiftRandom
/**
* Entry in vocabulary
@@ -58,29 +59,63 @@ private case class VocabWord(
* Efficient Estimation of Word Representations in Vector Space
* and
* Distributed Representations of Words and Phrases and their Compositionality.
- * @param size vector dimension
- * @param startingAlpha initial learning rate
- * @param parallelism number of partitions to run Word2Vec (using a small number for accuracy)
- * @param numIterations number of iterations to run, should be smaller than or equal to parallelism
*/
@Experimental
-class Word2Vec(
- val size: Int,
- val startingAlpha: Double,
- val parallelism: Int,
- val numIterations: Int) extends Serializable with Logging {
+class Word2Vec extends Serializable with Logging {
+
+ private var vectorSize = 100
+ private var startingAlpha = 0.025
+ private var numPartitions = 1
+ private var numIterations = 1
+ private var seed = Utils.random.nextLong()
+
+ /**
+ * Sets vector size (default: 100).
+ */
+ def setVectorSize(vectorSize: Int): this.type = {
+ this.vectorSize = vectorSize
+ this
+ }
+
+ /**
+ * Sets initial learning rate (default: 0.025).
+ */
+ def setLearningRate(learningRate: Double): this.type = {
+ this.startingAlpha = learningRate
+ this
+ }
/**
- * Word2Vec with a single thread.
+ * Sets number of partitions (default: 1). Use a small number for accuracy.
*/
- def this(size: Int, startingAlpha: Int) = this(size, startingAlpha, 1, 1)
+ def setNumPartitions(numPartitions: Int): this.type = {
+ require(numPartitions > 0, s"numPartitions must be greater than 0 but got $numPartitions")
+ this.numPartitions = numPartitions
+ this
+ }
+
+ /**
+ * Sets number of iterations (default: 1), which should be smaller than or equal to the number of
+ * partitions.
+ */
+ def setNumIterations(numIterations: Int): this.type = {
+ this.numIterations = numIterations
+ this
+ }
+
+ /**
+ * Sets random seed (default: a random long integer).
+ */
+ def setSeed(seed: Long): this.type = {
+ this.seed = seed
+ this
+ }
private val EXP_TABLE_SIZE = 1000
private val MAX_EXP = 6
private val MAX_CODE_LENGTH = 40
private val MAX_SENTENCE_LENGTH = 1000
- private val layer1Size = size
- private val modelPartitionNum = 100
+ private val layer1Size = vectorSize
/** context words from [-window, window] */
private val window = 5
@@ -94,12 +129,12 @@ class Word2Vec(
private var vocabHash = mutable.HashMap.empty[String, Int]
private var alpha = startingAlpha
- private def learnVocab(words:RDD[String]): Unit = {
+ private def learnVocab(words: RDD[String]): Unit = {
vocab = words.map(w => (w, 1))
.reduceByKey(_ + _)
.map(x => VocabWord(
- x._1,
- x._2,
+ x._1,
+ x._2,
new Array[Int](MAX_CODE_LENGTH),
new Array[Int](MAX_CODE_LENGTH),
0))
@@ -245,23 +280,24 @@ class Word2Vec(
}
}
- val newSentences = sentences.repartition(parallelism).cache()
+ val newSentences = sentences.repartition(numPartitions).cache()
+ val initRandom = new XORShiftRandom(seed)
var syn0Global =
- Array.fill[Float](vocabSize * layer1Size)((Random.nextFloat() - 0.5f) / layer1Size)
+ Array.fill[Float](vocabSize * layer1Size)((initRandom.nextFloat() - 0.5f) / layer1Size)
var syn1Global = new Array[Float](vocabSize * layer1Size)
-
- for(iter <- 1 to numIterations) {
- val (aggSyn0, aggSyn1, _, _) =
- // TODO: broadcast temp instead of serializing it directly
- // or initialize the model in each executor
- newSentences.treeAggregate((syn0Global, syn1Global, 0, 0))(
- seqOp = (c, v) => (c, v) match {
+
+ for (k <- 1 to numIterations) {
+ val partial = newSentences.mapPartitionsWithIndex { case (idx, iter) =>
+ val random = new XORShiftRandom(seed ^ ((idx + 1) << 16) ^ ((-k - 1) << 8))
+ val model = iter.foldLeft((syn0Global, syn1Global, 0, 0)) {
case ((syn0, syn1, lastWordCount, wordCount), sentence) =>
var lwc = lastWordCount
- var wc = wordCount
+ var wc = wordCount
if (wordCount - lastWordCount > 10000) {
lwc = wordCount
- alpha = startingAlpha * (1 - parallelism * wordCount.toDouble / (trainWordsCount + 1))
+ // TODO: discount by iteration?
+ alpha =
+ startingAlpha * (1 - numPartitions * wordCount.toDouble / (trainWordsCount + 1))
if (alpha < startingAlpha * 0.0001) alpha = startingAlpha * 0.0001
logInfo("wordCount = " + wordCount + ", alpha = " + alpha)
}
@@ -269,8 +305,7 @@ class Word2Vec(
var pos = 0
while (pos < sentence.size) {
val word = sentence(pos)
- // TODO: fix random seed
- val b = Random.nextInt(window)
+ val b = random.nextInt(window)
// Train Skip-gram
var a = b
while (a < window * 2 + 1 - b) {
@@ -280,7 +315,7 @@ class Word2Vec(
val lastWord = sentence(c)
val l1 = lastWord * layer1Size
val neu1e = new Array[Float](layer1Size)
- // Hierarchical softmax
+ // Hierarchical softmax
var d = 0
while (d < bcVocab.value(word).codeLen) {
val l2 = bcVocab.value(word).point(d) * layer1Size
@@ -303,44 +338,44 @@ class Word2Vec(
pos += 1
}
(syn0, syn1, lwc, wc)
- },
- combOp = (c1, c2) => (c1, c2) match {
- case ((syn0_1, syn1_1, lwc_1, wc_1), (syn0_2, syn1_2, lwc_2, wc_2)) =>
- val n = syn0_1.length
- val weight1 = 1.0f * wc_1 / (wc_1 + wc_2)
- val weight2 = 1.0f * wc_2 / (wc_1 + wc_2)
- blas.sscal(n, weight1, syn0_1, 1)
- blas.sscal(n, weight1, syn1_1, 1)
- blas.saxpy(n, weight2, syn0_2, 1, syn0_1, 1)
- blas.saxpy(n, weight2, syn1_2, 1, syn1_1, 1)
- (syn0_1, syn1_1, lwc_1 + lwc_2, wc_1 + wc_2)
- })
+ }
+ Iterator(model)
+ }
+ val (aggSyn0, aggSyn1, _, _) =
+ partial.treeReduce { case ((syn0_1, syn1_1, lwc_1, wc_1), (syn0_2, syn1_2, lwc_2, wc_2)) =>
+ val n = syn0_1.length
+ val weight1 = 1.0f * wc_1 / (wc_1 + wc_2)
+ val weight2 = 1.0f * wc_2 / (wc_1 + wc_2)
+ blas.sscal(n, weight1, syn0_1, 1)
+ blas.sscal(n, weight1, syn1_1, 1)
+ blas.saxpy(n, weight2, syn0_2, 1, syn0_1, 1)
+ blas.saxpy(n, weight2, syn1_2, 1, syn1_1, 1)
+ (syn0_1, syn1_1, lwc_1 + lwc_2, wc_1 + wc_2)
+ }
syn0Global = aggSyn0
syn1Global = aggSyn1
}
newSentences.unpersist()
- val wordMap = new Array[(String, Array[Float])](vocabSize)
+ val word2VecMap = mutable.HashMap.empty[String, Array[Float]]
var i = 0
while (i < vocabSize) {
val word = bcVocab.value(i).word
val vector = new Array[Float](layer1Size)
Array.copy(syn0Global, i * layer1Size, vector, 0, layer1Size)
- wordMap(i) = (word, vector)
+ word2VecMap += word -> vector
i += 1
}
- val modelRDD = sc.parallelize(wordMap, modelPartitionNum)
- .partitionBy(new HashPartitioner(modelPartitionNum))
- .persist(StorageLevel.MEMORY_AND_DISK)
-
- new Word2VecModel(modelRDD)
+
+ new Word2VecModel(word2VecMap.toMap)
}
}
/**
* Word2Vec model
-*/
-class Word2VecModel(private val model: RDD[(String, Array[Float])]) extends Serializable {
+ */
+class Word2VecModel private[mllib] (
+ private val model: Map[String, Array[Float]]) extends Serializable {
private def cosineSimilarity(v1: Array[Float], v2: Array[Float]): Double = {
require(v1.length == v2.length, "Vectors should have the same length")
@@ -357,11 +392,12 @@ class Word2VecModel(private val model: RDD[(String, Array[Float])]) extends Seri
* @return vector representation of word
*/
def transform(word: String): Vector = {
- val result = model.lookup(word)
- if (result.isEmpty) {
- throw new IllegalStateException(s"$word not in vocabulary")
+ model.get(word) match {
+ case Some(vec) =>
+ Vectors.dense(vec.map(_.toDouble))
+ case None =>
+ throw new IllegalStateException(s"$word not in vocabulary")
}
- else Vectors.dense(result(0).map(_.toDouble))
}
/**
@@ -392,33 +428,13 @@ class Word2VecModel(private val model: RDD[(String, Array[Float])]) extends Seri
*/
def findSynonyms(vector: Vector, num: Int): Array[(String, Double)] = {
require(num > 0, "Number of similar words should > 0")
- val topK = model.map { case(w, vec) =>
- (cosineSimilarity(vector.toArray.map(_.toFloat), vec), w) }
- .sortByKey(ascending = false)
- .take(num + 1)
- .map(_.swap)
- .tail
-
- topK
- }
-}
-
-object Word2Vec{
- /**
- * Train Word2Vec model
- * @param input RDD of words
- * @param size vector dimension
- * @param startingAlpha initial learning rate
- * @param parallelism number of partitions to run Word2Vec (using a small number for accuracy)
- * @param numIterations number of iterations, should be smaller than or equal to parallelism
- * @return Word2Vec model
- */
- def train[S <: Iterable[String]](
- input: RDD[S],
- size: Int,
- startingAlpha: Double,
- parallelism: Int = 1,
- numIterations:Int = 1): Word2VecModel = {
- new Word2Vec(size,startingAlpha, parallelism, numIterations).fit[S](input)
+ // TODO: optimize top-k
+ val fVector = vector.toArray.map(_.toFloat)
+ model.mapValues(vec => cosineSimilarity(fVector, vec))
+ .toSeq
+ .sortBy(- _._2)
+ .take(num + 1)
+ .tail
+ .toArray
}
}
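
With the constructor arguments replaced by setters, Word2Vec now follows the same builder pattern as other MLlib estimators. A minimal usage sketch (the SparkContext `sc` and the corpus path are assumed; any word queried must actually occur in the training vocabulary):

import org.apache.spark.mllib.feature.Word2Vec

// Sketch: train on an RDD[Seq[String]] of tokenized sentences, then query the model.
// Defaults (vectorSize = 100, learning rate = 0.025, 1 partition, 1 iteration, random
// seed) can all be overridden through the new setters.
val sentences = sc.textFile("data/corpus.txt").map(_.split(" ").toSeq)
val model = new Word2Vec()
  .setVectorSize(50)
  .setNumPartitions(4)
  .setNumIterations(2)
  .setSeed(42L)
  .fit(sentences)

val vector   = model.transform("spark")          // dense vector for one word
val synonyms = model.findSynonyms("spark", 5)    // Array[(String, Double)] of (word, cosine similarity)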
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala
index b5db39b68a223..e34335d89eb75 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala
@@ -30,29 +30,22 @@ class Word2VecSuite extends FunSuite with LocalSparkContext {
val localDoc = Seq(sentence, sentence)
val doc = sc.parallelize(localDoc)
.map(line => line.split(" ").toSeq)
- val size = 10
- val startingAlpha = 0.025
- val window = 2
- val minCount = 2
- val num = 2
-
- val model = Word2Vec.train(doc, size, startingAlpha)
+ val model = new Word2Vec().setVectorSize(10).setSeed(42L).fit(doc)
val syms = model.findSynonyms("a", 2)
- assert(syms.length == num)
+ assert(syms.length == 2)
assert(syms(0)._1 == "b")
assert(syms(1)._1 == "c")
}
-
test("Word2VecModel") {
val num = 2
- val localModel = Seq(
+ val word2VecMap = Map(
("china", Array(0.50f, 0.50f, 0.50f, 0.50f)),
("japan", Array(0.40f, 0.50f, 0.50f, 0.50f)),
("taiwan", Array(0.60f, 0.50f, 0.50f, 0.50f)),
("korea", Array(0.45f, 0.60f, 0.60f, 0.60f))
)
- val model = new Word2VecModel(sc.parallelize(localModel, 2))
+ val model = new Word2VecModel(word2VecMap)
val syms = model.findSynonyms("china", num)
assert(syms.length == num)
assert(syms(0)._1 == "taiwan")
diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py
index 2bbb9c3fca315..5ec1a8084d269 100644
--- a/python/pyspark/mllib/classification.py
+++ b/python/pyspark/mllib/classification.py
@@ -73,11 +73,36 @@ def predict(self, x):
class LogisticRegressionWithSGD(object):
@classmethod
- def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, initialWeights=None):
- """Train a logistic regression model on the given data."""
+ def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0,
+ initialWeights=None, regParam=1.0, regType=None, intercept=False):
+ """
+ Train a logistic regression model on the given data.
+
+ @param data: The training data.
+ @param iterations: The number of iterations (default: 100).
+ @param step: The step parameter used in SGD
+ (default: 1.0).
+ @param miniBatchFraction: Fraction of data to be used for each SGD
+ iteration.
+ @param initialWeights: The initial weights (default: None).
+ @param regParam: The regularizer parameter (default: 1.0).
+ @param regType: The type of regularizer used for training
+ our model.
+ Allowed values: "l1" for using L1Updater,
+ "l2" for using
+ SquaredL2Updater,
+ "none" for no regularizer.
+ (default: "none")
+ @param intercept: Boolean parameter which indicates the use
+ or not of the augmented representation for
+ training data (i.e. whether bias features
+ are activated or not).
+ """
sc = data.context
+ if regType is None:
+ regType = "none"
train_func = lambda d, i: sc._jvm.PythonMLLibAPI().trainLogisticRegressionModelWithSGD(
- d._jrdd, iterations, step, miniBatchFraction, i)
+ d._jrdd, iterations, step, miniBatchFraction, i, regParam, regType, intercept)
return _regression_train_wrapper(sc, train_func, LogisticRegressionModel, data,
initialWeights)
@@ -115,11 +140,35 @@ def predict(self, x):
class SVMWithSGD(object):
@classmethod
def train(cls, data, iterations=100, step=1.0, regParam=1.0,
- miniBatchFraction=1.0, initialWeights=None):
- """Train a support vector machine on the given data."""
+ miniBatchFraction=1.0, initialWeights=None, regType=None, intercept=False):
+ """
+ Train a support vector machine on the given data.
+
+ @param data: The training data.
+ @param iterations: The number of iterations (default: 100).
+ @param step: The step parameter used in SGD
+ (default: 1.0).
+ @param regParam: The regularizer parameter (default: 1.0).
+ @param miniBatchFraction: Fraction of data to be used for each SGD
+ iteration.
+ @param initialWeights: The initial weights (default: None).
+ @param regType: The type of regularizer used for training
+ our model.
+ Allowed values: "l1" for using L1Updater,
+ "l2" for using
+ SquaredL2Updater,
+ "none" for no regularizer.
+ (default: "none")
+ @param intercept: Boolean parameter which indicates the use
+ or not of the augmented representation for
+ training data (i.e. whether bias features
+ are activated or not).
+ """
sc = data.context
+ if regType is None:
+ regType = "none"
train_func = lambda d, i: sc._jvm.PythonMLLibAPI().trainSVMModelWithSGD(
- d._jrdd, iterations, step, regParam, miniBatchFraction, i)
+ d._jrdd, iterations, step, regParam, miniBatchFraction, i, regType, intercept)
return _regression_train_wrapper(sc, train_func, SVMModel, data, initialWeights)
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index 1a829c6fafe03..f1093701ddc89 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -672,12 +672,12 @@ def _infer_schema_type(obj, dataType):
ByteType: (int, long),
ShortType: (int, long),
IntegerType: (int, long),
- LongType: (int, long),
+ LongType: (long,),
FloatType: (float,),
DoubleType: (float,),
DecimalType: (decimal.Decimal,),
StringType: (str, unicode),
- TimestampType: (datetime.datetime, datetime.time, datetime.date),
+ TimestampType: (datetime.datetime,),
ArrayType: (list, tuple, array),
MapType: (dict,),
StructType: (tuple, list),
@@ -1042,12 +1042,15 @@ def applySchema(self, rdd, schema):
[Row(field1=1, field2=u'row1'),..., Row(field1=3, field2=u'row3')]
>>> from datetime import datetime
- >>> rdd = sc.parallelize([(127, -32768, 1.0,
+ >>> rdd = sc.parallelize([(127, -128L, -32768, 32767, 2147483647L, 1.0,
... datetime(2010, 1, 1, 1, 1, 1),
... {"a": 1}, (2,), [1, 2, 3], None)])
>>> schema = StructType([
- ... StructField("byte", ByteType(), False),
- ... StructField("short", ShortType(), False),
+ ... StructField("byte1", ByteType(), False),
+ ... StructField("byte2", ByteType(), False),
+ ... StructField("short1", ShortType(), False),
+ ... StructField("short2", ShortType(), False),
+ ... StructField("int", IntegerType(), False),
... StructField("float", FloatType(), False),
... StructField("time", TimestampType(), False),
... StructField("map",
@@ -1056,11 +1059,19 @@ def applySchema(self, rdd, schema):
... StructType([StructField("b", ShortType(), False)]), False),
... StructField("list", ArrayType(ByteType(), False), False),
... StructField("null", DoubleType(), True)])
- >>> srdd = sqlCtx.applySchema(rdd, schema).map(
- ... lambda x: (x.byte, x.short, x.float, x.time,
+ >>> srdd = sqlCtx.applySchema(rdd, schema)
+ >>> results = srdd.map(
+ ... lambda x: (x.byte1, x.byte2, x.short1, x.short2, x.int, x.float, x.time,
... x.map["a"], x.struct.b, x.list, x.null))
- >>> srdd.collect()[0]
- (127, -32768, 1.0, ...(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None)
+ >>> results.collect()[0]
+ (127, -128, -32768, 32767, 2147483647, 1.0, ...(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None)
+
+ >>> srdd.registerTempTable("table2")
+ >>> sqlCtx.sql(
+ ... "SELECT byte1 - 1 AS byte1, byte2 + 1 AS byte2, " +
+ ... "short1 + 1 AS short1, short2 - 1 AS short2, int - 1 AS int, " +
+ ... "float + 1.1 as float FROM table2").collect()
+ [Row(byte1=126, byte2=-127, short1=-32767, short2=32766, int=2147483646, float=2.1)]
>>> rdd = sc.parallelize([(127, -32768, 1.0,
... datetime(2010, 1, 1, 1, 1, 1),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 2ba68cab115fb..0293d578b0b92 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -48,6 +48,7 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
Batch("Resolution", fixedPoint,
ResolveReferences ::
ResolveRelations ::
+ ResolveSortReferences ::
NewRelationInstances ::
ImplicitGenerate ::
StarExpansion ::
@@ -113,13 +114,58 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
q transformExpressions {
case u @ UnresolvedAttribute(name) =>
// Leave unchanged if resolution fails. Hopefully will be resolved next round.
- val result = q.resolve(name).getOrElse(u)
+ val result = q.resolveChildren(name).getOrElse(u)
logDebug(s"Resolving $u to $result")
result
}
}
}
+ /**
+ * In many dialects of SQL it is valid to sort by attributes that are not present in the SELECT
+ * clause. This rule detects such queries and adds the required attributes to the original
+ * projection, so that they will be available during sorting. Another projection is added to
+ * remove these attributes after sorting.
+ */
+ object ResolveSortReferences extends Rule[LogicalPlan] {
+ def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
+ case s @ Sort(ordering, p @ Project(projectList, child)) if !s.resolved && p.resolved =>
+ val unresolved = ordering.flatMap(_.collect { case UnresolvedAttribute(name) => name })
+ val resolved = unresolved.flatMap(child.resolveChildren)
+ val requiredAttributes = resolved.collect { case a: Attribute => a }.toSet
+
+ val missingInProject = requiredAttributes -- p.output
+ if (missingInProject.nonEmpty) {
+ // Add missing attributes and then project them away after the sort.
+ Project(projectList,
+ Sort(ordering,
+ Project(projectList ++ missingInProject, child)))
+ } else {
+ s // Nothing we can do here. Return original plan.
+ }
+ case s @ Sort(ordering, a @ Aggregate(grouping, aggs, child)) if !s.resolved && a.resolved =>
+ val unresolved = ordering.flatMap(_.collect { case UnresolvedAttribute(name) => name })
+ // A small hack to create an object that will allow us to resolve any references that
+ // refer to named expressions that are present in the grouping expressions.
+ val groupingRelation = LocalRelation(
+ grouping.collect { case ne: NamedExpression => ne.toAttribute }
+ )
+
+ logWarning(s"Grouping expressions: $groupingRelation")
+ val resolved = unresolved.flatMap(groupingRelation.resolve).toSet
+ val missingInAggs = resolved -- a.outputSet
+ logWarning(s"Resolved: $resolved Missing in aggs: $missingInAggs")
+ if (missingInAggs.nonEmpty) {
+ // Add missing grouping exprs and then project them away after the sort.
+ Project(a.output,
+ Sort(ordering,
+ Aggregate(grouping, aggs ++ missingInAggs, child)))
+ } else {
+ s // Nothing we can do here. Return original plan.
+ }
+ }
+ }
+
/**
* Replaces [[UnresolvedFunction]]s with concrete [[catalyst.expressions.Expression Expressions]].
*/
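
ResolveSortReferences is exercised by the new SQLQuerySuite added at the end of this patch. Conceptually, a query that sorts on a column missing from its SELECT list becomes equivalent to the nested form it previously required; a hedged sketch, assuming a SQLContext `sqlContext` and a registered table `src` with key/value columns:

// Before this rule the analyzer could not resolve `value` in the ORDER BY clause,
// because the projection only exposes `key`. The rule widens the inner projection,
// sorts, and projects the extra attribute away again, so these two are equivalent:
val direct = sqlContext.sql("SELECT key FROM src ORDER BY value")
val nested = sqlContext.sql(
  "SELECT key FROM (SELECT key, value FROM src ORDER BY value) a")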
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index 888cb08e95f06..278569f0cb14a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -72,16 +72,29 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] {
def childrenResolved: Boolean = !children.exists(!_.resolved)
/**
- * Optionally resolves the given string to a [[NamedExpression]]. The attribute is expressed as
+ * Optionally resolves the given string to a [[NamedExpression]] using the input from all child
+ * nodes of this LogicalPlan. The attribute is expressed as
* a string in the following form: `[scope].AttributeName.[nested].[fields]...`.
*/
- def resolve(name: String): Option[NamedExpression] = {
+ def resolveChildren(name: String): Option[NamedExpression] =
+ resolve(name, children.flatMap(_.output))
+
+ /**
+ * Optionally resolves the given string to a [[NamedExpression]] based on the output of this
+ * LogicalPlan. The attribute is expressed as string in the following form:
+ * `[scope].AttributeName.[nested].[fields]...`.
+ */
+ def resolve(name: String): Option[NamedExpression] =
+ resolve(name, output)
+
+ /** Performs attribute resolution given a name and a sequence of possible attributes. */
+ protected def resolve(name: String, input: Seq[Attribute]): Option[NamedExpression] = {
val parts = name.split("\\.")
// Collect all attributes that are output by this node's children where either the first part
// matches the name or where the first part matches the scope and the second part matches the
// name. Return these matches along with any remaining parts, which represent dotted access to
// struct fields.
- val options = children.flatMap(_.output).flatMap { option =>
+ val options = input.flatMap { option =>
// If the first part of the desired name matches a qualifier for this possible match, drop it.
val remainingParts =
if (option.qualifiers.contains(parts.head) && parts.size > 1) parts.drop(1) else parts
@@ -89,15 +102,15 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] {
}
options.distinct match {
- case (a, Nil) :: Nil => Some(a) // One match, no nested fields, use it.
+ case Seq((a, Nil)) => Some(a) // One match, no nested fields, use it.
// One match, but we also need to extract the requested nested field.
- case (a, nestedFields) :: Nil =>
+ case Seq((a, nestedFields)) =>
a.dataType match {
case StructType(fields) =>
Some(Alias(nestedFields.foldLeft(a: Expression)(GetField), nestedFields.last)())
case _ => None // Don't know how to resolve these field references
}
- case Nil => None // No matches.
+ case Seq() => None // No matches.
case ambiguousReferences =>
throw new TreeNodeException(
this, s"Ambiguous references to $name: ${ambiguousReferences.mkString(",")}")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index ecd5fbaa0b094..71d338d21d0f2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -491,7 +491,10 @@ class SQLContext(@transient val sparkContext: SparkContext)
new java.sql.Timestamp(c.getTime().getTime())
case (c: Int, ByteType) => c.toByte
+ case (c: Long, ByteType) => c.toByte
case (c: Int, ShortType) => c.toShort
+ case (c: Long, ShortType) => c.toShort
+ case (c: Long, IntegerType) => c.toInt
case (c: Double, FloatType) => c.toFloat
case (c, StringType) if !c.isInstanceOf[String] => c.toString
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala
index c416a745739b3..7e7bb2859bbcd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala
@@ -118,7 +118,7 @@ private[sql] class BinaryColumnBuilder extends ComplexColumnBuilder(BINARY)
private[sql] class GenericColumnBuilder extends ComplexColumnBuilder(GENERIC)
private[sql] object ColumnBuilder {
- val DEFAULT_INITIAL_BUFFER_SIZE = 10 * 1024 * 104
+ val DEFAULT_INITIAL_BUFFER_SIZE = 1024 * 1024
private[columnar] def ensureFreeSpace(orig: ByteBuffer, size: Int) = {
if (orig.remaining >= size) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala
index d008806eedbe1..f631ee76fcd78 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala
@@ -36,9 +36,9 @@ import org.apache.spark.sql.Row
* }}}
*/
private[sql] trait NullableColumnBuilder extends ColumnBuilder {
- private var nulls: ByteBuffer = _
+ protected var nulls: ByteBuffer = _
+ protected var nullCount: Int = _
private var pos: Int = _
- private var nullCount: Int = _
abstract override def initialize(initialSize: Int, columnName: String, useCompression: Boolean) {
nulls = ByteBuffer.allocate(1024)
@@ -78,4 +78,9 @@ private[sql] trait NullableColumnBuilder extends ColumnBuilder {
buffer.rewind()
buffer
}
+
+ protected def buildNonNulls(): ByteBuffer = {
+ nulls.limit(nulls.position()).rewind()
+ super.build()
+ }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala
index 6ad12a0dcb64d..a5826bb033e41 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala
@@ -46,8 +46,6 @@ private[sql] trait CompressibleColumnBuilder[T <: NativeType]
this: NativeColumnBuilder[T] with WithCompressionSchemes =>
- import CompressionScheme._
-
var compressionEncoders: Seq[Encoder[T]] = _
abstract override def initialize(initialSize: Int, columnName: String, useCompression: Boolean) {
@@ -81,28 +79,32 @@ private[sql] trait CompressibleColumnBuilder[T <: NativeType]
}
}
- abstract override def build() = {
- val rawBuffer = super.build()
+ override def build() = {
+ val nonNullBuffer = buildNonNulls()
+ val typeId = nonNullBuffer.getInt()
val encoder: Encoder[T] = {
val candidate = compressionEncoders.minBy(_.compressionRatio)
if (isWorthCompressing(candidate)) candidate else PassThrough.encoder
}
- val headerSize = columnHeaderSize(rawBuffer)
+ // Header = column type ID + null count + null positions
+ val headerSize = 4 + 4 + nulls.limit()
val compressedSize = if (encoder.compressedSize == 0) {
- rawBuffer.limit - headerSize
+ nonNullBuffer.remaining()
} else {
encoder.compressedSize
}
- // Reserves 4 bytes for compression scheme ID
val compressedBuffer = ByteBuffer
+ // Reserves 4 bytes for compression scheme ID
.allocate(headerSize + 4 + compressedSize)
.order(ByteOrder.nativeOrder)
-
- copyColumnHeader(rawBuffer, compressedBuffer)
+ // Write the header
+ .putInt(typeId)
+ .putInt(nullCount)
+ .put(nulls)
logInfo(s"Compressor for [$columnName]: $encoder, ratio: ${encoder.compressionRatio}")
- encoder.compress(rawBuffer, compressedBuffer, columnType)
+ encoder.compress(nonNullBuffer, compressedBuffer, columnType)
}
}
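
After this change the compressed column buffer is laid out as: a 4-byte column type ID, a 4-byte null count, the null positions (4 bytes each), a 4-byte compression scheme ID written by the encoder, and then the compressed body. A small sketch of walking that header with a plain ByteBuffer (the helper name is illustrative):

import java.nio.{ByteBuffer, ByteOrder}

// Sketch: read back the header written by CompressibleColumnBuilder.build().
def readColumnHeader(buffer: ByteBuffer): (Int, Seq[Int], Int) = {
  val dup = buffer.duplicate().order(ByteOrder.nativeOrder)
  val columnTypeId        = dup.getInt()
  val nullCount           = dup.getInt()
  val nullPositions       = Seq.fill(nullCount)(dup.getInt())
  val compressionSchemeId = dup.getInt()
  (columnTypeId, nullPositions, compressionSchemeId)
}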
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala
index ba1810dd2ae66..7797f75177893 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala
@@ -67,22 +67,6 @@ private[sql] object CompressionScheme {
s"Unrecognized compression scheme type ID: $typeId"))
}
- def copyColumnHeader(from: ByteBuffer, to: ByteBuffer) {
- // Writes column type ID
- to.putInt(from.getInt())
-
- // Writes null count
- val nullCount = from.getInt()
- to.putInt(nullCount)
-
- // Writes null positions
- var i = 0
- while (i < nullCount) {
- to.putInt(from.getInt())
- i += 1
- }
- }
-
def columnHeaderSize(columnBuffer: ByteBuffer): Int = {
val header = columnBuffer.duplicate().order(ByteOrder.nativeOrder)
val nullCount = header.getInt(4)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala
index 6d688ea95cfc0..72c19fa31d980 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala
@@ -42,4 +42,3 @@ object TestCompressibleColumnBuilder {
builder
}
}
-
diff --git a/sql/hive-thriftserver/pom.xml b/sql/hive-thriftserver/pom.xml
index 7fac90fdc596d..c6f60c18804a4 100644
--- a/sql/hive-thriftserver/pom.xml
+++ b/sql/hive-thriftserver/pom.xml
@@ -29,7 +29,7 @@
org.apache.spark
spark-hive-thriftserver_2.10
jar
- Spark Project Hive
+ Spark Project Hive Thrift Server
http://spark.apache.org/
hive-thriftserver
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
new file mode 100644
index 0000000000000..635a9fb0d56cb
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.execution
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.sql.{SQLConf, QueryTest}
+import org.apache.spark.sql.execution.{BroadcastHashJoin, ShuffledHashJoin}
+import org.apache.spark.sql.hive.test.TestHive
+import org.apache.spark.sql.hive.test.TestHive._
+
+/**
+ * A collection of hive query tests where we generate the answers ourselves instead of depending on
+ * Hive to generate them (in contrast to HiveQuerySuite). Often this is because the query is
+ * valid, but Hive currently cannot execute it.
+ */
+class SQLQuerySuite extends QueryTest {
+ test("ordering not in select") {
+ checkAnswer(
+ sql("SELECT key FROM src ORDER BY value"),
+ sql("SELECT key FROM (SELECT key, value FROM src ORDER BY value) a").collect().toSeq)
+ }
+
+ test("ordering not in agg") {
+ checkAnswer(
+ sql("SELECT key FROM src GROUP BY key, value ORDER BY value"),
+ sql("""
+ SELECT key
+ FROM (
+ SELECT key, value
+ FROM src
+ GROUP BY key, value
+ ORDER BY value) a""").collect().toSeq)
+ }
+}