diff --git a/core/pom.xml b/core/pom.xml
index 7c60cf10c3dc2..6d8be37037729 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -150,7 +150,7 @@
     <dependency>
       <groupId>org.json4s</groupId>
       <artifactId>json4s-jackson_${scala.binary.version}</artifactId>
-      <version>3.2.6</version>
+      <version>3.2.10</version>
     </dependency>
     <dependency>
       <groupId>colt</groupId>
diff --git a/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala
index a76a070b5b863..8947e66f4577c 100644
--- a/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala
@@ -96,17 +96,23 @@ class JdbcRDD[T: ClassTag](
   override def close() {
     try {
-      if (null != rs && ! rs.isClosed()) rs.close()
+      if (null != rs && ! rs.isClosed()) {
+        rs.close()
+      }
     } catch {
       case e: Exception => logWarning("Exception closing resultset", e)
     }
     try {
-      if (null != stmt && ! stmt.isClosed()) stmt.close()
+      if (null != stmt && ! stmt.isClosed()) {
+        stmt.close()
+      }
     } catch {
       case e: Exception => logWarning("Exception closing statement", e)
     }
     try {
-      if (null != conn && ! stmt.isClosed()) conn.close()
+      if (null != conn && ! conn.isClosed()) {
+        conn.close()
+      }
       logInfo("closed connection")
     } catch {
       case e: Exception => logWarning("Exception closing connection", e)
@@ -120,3 +126,4 @@ object JdbcRDD {
     Array.tabulate[Object](rs.getMetaData.getColumnCount)(i => rs.getObject(i + 1))
   }
 }
+
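Besides wrapping the branch bodies in braces, the JdbcRDD change fixes a real copy-paste bug: the connection branch previously guarded `conn.close()` with `stmt.isClosed()`. A minimal standalone sketch of the same close-quietly pattern (the helper names are illustrative, not part of the patch):

```scala
import java.sql.{Connection, ResultSet, Statement}

// Attempt a close, logging instead of rethrowing, so a failure on one
// resource cannot prevent the remaining resources from being released.
def closeQuietly(label: String)(close: => Unit): Unit =
  try close catch {
    case e: Exception => println(s"Exception closing $label: $e")
  }

def releaseAll(rs: ResultSet, stmt: Statement, conn: Connection): Unit = {
  closeQuietly("resultset") { if (null != rs && !rs.isClosed) rs.close() }
  closeQuietly("statement") { if (null != stmt && !stmt.isClosed) stmt.close() }
  closeQuietly("connection") { if (null != conn && !conn.isClosed) conn.close() }
}
```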
diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
index 9a356d0dbaf17..24db2f287a47b 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
@@ -40,7 +40,7 @@ private[spark] class SortShuffleWriter[K, V, C](
   private val ser = Serializer.getSerializer(dep.serializer.orNull)

   private val conf = SparkEnv.get.conf
-  private val fileBufferSize = conf.getInt("spark.shuffle.file.buffer.kb", 100) * 1024
+  private val fileBufferSize = conf.getInt("spark.shuffle.file.buffer.kb", 32) * 1024

   private var sorter: ExternalSorter[K, V, _] = null
   private var outputFile: File = null
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
index 28aa35bc7e147..f9fdffae8bd8f 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
@@ -73,7 +73,7 @@ class ShuffleBlockManager(blockManager: BlockManager) extends Logging {
   val sortBasedShuffle =
     conf.get("spark.shuffle.manager", "") == classOf[SortShuffleManager].getName

-  private val bufferSize = conf.getInt("spark.shuffle.file.buffer.kb", 100) * 1024
+  private val bufferSize = conf.getInt("spark.shuffle.file.buffer.kb", 32) * 1024

   /**
    * Contains all the state related to a particular shuffle. This includes a pool of unused
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
index cc0423856cefb..260a5c3888aa7 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
@@ -101,7 +101,7 @@ class ExternalAppendOnlyMap[K, V, C](
   private var _memoryBytesSpilled = 0L
   private var _diskBytesSpilled = 0L

-  private val fileBufferSize = sparkConf.getInt("spark.shuffle.file.buffer.kb", 100) * 1024
+  private val fileBufferSize = sparkConf.getInt("spark.shuffle.file.buffer.kb", 32) * 1024
   private val keyComparator = new HashComparator[K]
   private val ser = serializer.newInstance()
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
index 101c83b264f63..3f93afd57b3ad 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
@@ -84,7 +84,7 @@ private[spark] class ExternalSorter[K, V, C](
   private val conf = SparkEnv.get.conf

   private val spillingEnabled = conf.getBoolean("spark.shuffle.spill", true)
-  private val fileBufferSize = conf.getInt("spark.shuffle.file.buffer.kb", 100) * 1024
+  private val fileBufferSize = conf.getInt("spark.shuffle.file.buffer.kb", 32) * 1024

   // Size of object batches when reading/writing from serializers.
   //
diff --git a/docs/configuration.md b/docs/configuration.md
index a915e5a4961b6..5e3eb0f0871af 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -266,7 +266,7 @@ Apart from these, the following properties are also available, and may be useful
 <tr>
   <td>spark.shuffle.file.buffer.kb</td>
-  <td>100</td>
+  <td>32</td>
   <td>
     Size of the in-memory buffer for each shuffle file output stream, in kilobytes. These buffers
    reduce the number of disk seeks and system calls made in creating intermediate shuffle files.
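The same 32 KB default now appears in all four call sites above plus the docs. Jobs that benefited from the old 100 KB buffers can restore them explicitly; a minimal sketch, assuming a standard SparkConf setup:

```scala
import org.apache.spark.SparkConf

// Restore the previous 100 KB write buffer for a workload with few, large
// shuffle outputs; the new 32 KB default trades throughput for memory when
// many shuffle file streams are open at once.
val conf = new SparkConf()
  .setAppName("shuffle-buffer-example")
  .set("spark.shuffle.file.buffer.kb", "100")

// Each disk writer then allocates getInt("spark.shuffle.file.buffer.kb", 32) * 1024
// bytes, matching the fileBufferSize computation in the writers patched above.
```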
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index 1d5d3762ed8e9..fd0b9556c7d54 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -271,6 +271,7 @@ class PythonMLLibAPI extends Serializable {
       .setNumIterations(numIterations)
       .setRegParam(regParam)
       .setStepSize(stepSize)
+      .setMiniBatchFraction(miniBatchFraction)
     if (regType == "l2") {
       lrAlg.optimizer.setUpdater(new SquaredL2Updater)
     } else if (regType == "l1") {
@@ -341,16 +342,27 @@ class PythonMLLibAPI extends Serializable {
       stepSize: Double,
       regParam: Double,
       miniBatchFraction: Double,
-      initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = {
+      initialWeightsBA: Array[Byte],
+      regType: String,
+      intercept: Boolean): java.util.List[java.lang.Object] = {
+    val SVMAlg = new SVMWithSGD()
+    SVMAlg.setIntercept(intercept)
+    SVMAlg.optimizer
+      .setNumIterations(numIterations)
+      .setRegParam(regParam)
+      .setStepSize(stepSize)
+      .setMiniBatchFraction(miniBatchFraction)
+    if (regType == "l2") {
+      SVMAlg.optimizer.setUpdater(new SquaredL2Updater)
+    } else if (regType == "l1") {
+      SVMAlg.optimizer.setUpdater(new L1Updater)
+    } else if (regType != "none") {
+      throw new java.lang.IllegalArgumentException("Invalid value for 'regType' parameter."
+        + " Can only be initialized using the following string values: [l1, l2, none].")
+    }
     trainRegressionModel(
       (data, initialWeights) =>
-        SVMWithSGD.train(
-          data,
-          numIterations,
-          stepSize,
-          regParam,
-          miniBatchFraction,
-          initialWeights),
+        SVMAlg.run(data, initialWeights),
       dataBytesJRDD,
       initialWeightsBA)
   }
@@ -363,15 +375,28 @@ class PythonMLLibAPI extends Serializable {
       numIterations: Int,
       stepSize: Double,
       miniBatchFraction: Double,
-      initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = {
+      initialWeightsBA: Array[Byte],
+      regParam: Double,
+      regType: String,
+      intercept: Boolean): java.util.List[java.lang.Object] = {
+    val LogRegAlg = new LogisticRegressionWithSGD()
+    LogRegAlg.setIntercept(intercept)
+    LogRegAlg.optimizer
+      .setNumIterations(numIterations)
+      .setRegParam(regParam)
+      .setStepSize(stepSize)
+      .setMiniBatchFraction(miniBatchFraction)
+    if (regType == "l2") {
+      LogRegAlg.optimizer.setUpdater(new SquaredL2Updater)
+    } else if (regType == "l1") {
+      LogRegAlg.optimizer.setUpdater(new L1Updater)
+    } else if (regType != "none") {
+      throw new java.lang.IllegalArgumentException("Invalid value for 'regType' parameter."
+        + " Can only be initialized using the following string values: [l1, l2, none].")
+    }
     trainRegressionModel(
       (data, initialWeights) =>
-        LogisticRegressionWithSGD.train(
-          data,
-          numIterations,
-          stepSize,
-          miniBatchFraction,
-          initialWeights),
+        LogRegAlg.run(data, initialWeights),
       dataBytesJRDD,
       initialWeightsBA)
   }
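The same regType validation block is now repeated in three trainers. A sketch of how it could be factored out (the helper is hypothetical, not part of the patch):

```scala
import org.apache.spark.mllib.optimization.{L1Updater, SquaredL2Updater, Updater}

// Translate the Python-side regType string into an Updater, mirroring the
// branches in trainSVMModelWithSGD and trainLogisticRegressionModelWithSGD.
def updaterFor(regType: String): Option[Updater] = regType match {
  case "l2" => Some(new SquaredL2Updater)
  case "l1" => Some(new L1Updater)
  case "none" => None
  case other => throw new IllegalArgumentException(
    s"Invalid value for 'regType' parameter: $other." +
      " Can only be initialized using the following string values: [l1, l2, none].")
}
```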
+ + " Can only be initialized using the following string values: [l1, l2, none].") + } trainRegressionModel( (data, initialWeights) => - LogisticRegressionWithSGD.train( - data, - numIterations, - stepSize, - miniBatchFraction, - initialWeights), + LogRegAlg.run(data, initialWeights), dataBytesJRDD, initialWeightsBA) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index 87c81e7b0bd2f..3bf44ad7c44e3 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -19,16 +19,17 @@ package org.apache.spark.mllib.feature import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -import scala.util.Random import com.github.fommil.netlib.BLAS.{getInstance => blas} -import org.apache.spark.{HashPartitioner, Logging} + +import org.apache.spark.Logging import org.apache.spark.SparkContext._ import org.apache.spark.annotation.Experimental import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.rdd.RDDFunctions._ import org.apache.spark.rdd._ -import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.Utils +import org.apache.spark.util.random.XORShiftRandom /** * Entry in vocabulary @@ -58,29 +59,63 @@ private case class VocabWord( * Efficient Estimation of Word Representations in Vector Space * and * Distributed Representations of Words and Phrases and their Compositionality. - * @param size vector dimension - * @param startingAlpha initial learning rate - * @param parallelism number of partitions to run Word2Vec (using a small number for accuracy) - * @param numIterations number of iterations to run, should be smaller than or equal to parallelism */ @Experimental -class Word2Vec( - val size: Int, - val startingAlpha: Double, - val parallelism: Int, - val numIterations: Int) extends Serializable with Logging { +class Word2Vec extends Serializable with Logging { + + private var vectorSize = 100 + private var startingAlpha = 0.025 + private var numPartitions = 1 + private var numIterations = 1 + private var seed = Utils.random.nextLong() + + /** + * Sets vector size (default: 100). + */ + def setVectorSize(vectorSize: Int): this.type = { + this.vectorSize = vectorSize + this + } + + /** + * Sets initial learning rate (default: 0.025). + */ + def setLearningRate(learningRate: Double): this.type = { + this.startingAlpha = learningRate + this + } /** - * Word2Vec with a single thread. + * Sets number of partitions (default: 1). Use a small number for accuracy. */ - def this(size: Int, startingAlpha: Int) = this(size, startingAlpha, 1, 1) + def setNumPartitions(numPartitions: Int): this.type = { + require(numPartitions > 0, s"numPartitions must be greater than 0 but got $numPartitions") + this.numPartitions = numPartitions + this + } + + /** + * Sets number of iterations (default: 1), which should be smaller than or equal to number of + * partitions. + */ + def setNumIterations(numIterations: Int): this.type = { + this.numIterations = numIterations + this + } + + /** + * Sets random seed (default: a random long integer). 
+   */
+  def setSeed(seed: Long): this.type = {
+    this.seed = seed
+    this
+  }

   private val EXP_TABLE_SIZE = 1000
   private val MAX_EXP = 6
   private val MAX_CODE_LENGTH = 40
   private val MAX_SENTENCE_LENGTH = 1000
-  private val layer1Size = size
-  private val modelPartitionNum = 100
+  private val layer1Size = vectorSize

   /** context words from [-window, window] */
   private val window = 5
@@ -94,12 +129,12 @@ class Word2Vec(
   private var vocabHash = mutable.HashMap.empty[String, Int]
   private var alpha = startingAlpha

-  private def learnVocab(words:RDD[String]): Unit = {
+  private def learnVocab(words: RDD[String]): Unit = {
     vocab = words.map(w => (w, 1))
       .reduceByKey(_ + _)
       .map(x => VocabWord(
-            x._1,
-            x._2,
+        x._1,
+        x._2,
         new Array[Int](MAX_CODE_LENGTH),
         new Array[Int](MAX_CODE_LENGTH),
         0))
@@ -245,23 +280,24 @@ class Word2Vec(
       }
     }

-    val newSentences = sentences.repartition(parallelism).cache()
+    val newSentences = sentences.repartition(numPartitions).cache()
+    val initRandom = new XORShiftRandom(seed)
     var syn0Global =
-      Array.fill[Float](vocabSize * layer1Size)((Random.nextFloat() - 0.5f) / layer1Size)
+      Array.fill[Float](vocabSize * layer1Size)((initRandom.nextFloat() - 0.5f) / layer1Size)
     var syn1Global = new Array[Float](vocabSize * layer1Size)
-
-    for(iter <- 1 to numIterations) {
-      val (aggSyn0, aggSyn1, _, _) =
-        // TODO: broadcast temp instead of serializing it directly
-        // or initialize the model in each executor
-        newSentences.treeAggregate((syn0Global, syn1Global, 0, 0))(
-          seqOp = (c, v) => (c, v) match {
+
+    for (k <- 1 to numIterations) {
+      val partial = newSentences.mapPartitionsWithIndex { case (idx, iter) =>
+        val random = new XORShiftRandom(seed ^ ((idx + 1) << 16) ^ ((-k - 1) << 8))
+        val model = iter.foldLeft((syn0Global, syn1Global, 0, 0)) {
           case ((syn0, syn1, lastWordCount, wordCount), sentence) =>
             var lwc = lastWordCount
-            var wc = wordCount
+            var wc = wordCount
             if (wordCount - lastWordCount > 10000) {
               lwc = wordCount
-              alpha = startingAlpha * (1 - parallelism * wordCount.toDouble / (trainWordsCount + 1))
+              // TODO: discount by iteration?
+              alpha =
+                startingAlpha * (1 - numPartitions * wordCount.toDouble / (trainWordsCount + 1))
               if (alpha < startingAlpha * 0.0001) alpha = startingAlpha * 0.0001
               logInfo("wordCount = " + wordCount + ", alpha = " + alpha)
             }
@@ -269,8 +305,7 @@ class Word2Vec(
             var pos = 0
             while (pos < sentence.size) {
               val word = sentence(pos)
-              // TODO: fix random seed
-              val b = Random.nextInt(window)
+              val b = random.nextInt(window)
               // Train Skip-gram
               var a = b
               while (a < window * 2 + 1 - b) {
@@ -280,7 +315,7 @@ class Word2Vec(
                   val lastWord = sentence(c)
                   val l1 = lastWord * layer1Size
                   val neu1e = new Array[Float](layer1Size)
-                  // Hierarchical softmax
+                  // Hierarchical softmax
                   var d = 0
                   while (d < bcVocab.value(word).codeLen) {
                     val l2 = bcVocab.value(word).point(d) * layer1Size
@@ -303,44 +338,44 @@ class Word2Vec(
               pos += 1
             }
             (syn0, syn1, lwc, wc)
-          },
-          combOp = (c1, c2) => (c1, c2) match {
-            case ((syn0_1, syn1_1, lwc_1, wc_1), (syn0_2, syn1_2, lwc_2, wc_2)) =>
-              val n = syn0_1.length
-              val weight1 = 1.0f * wc_1 / (wc_1 + wc_2)
-              val weight2 = 1.0f * wc_2 / (wc_1 + wc_2)
-              blas.sscal(n, weight1, syn0_1, 1)
-              blas.sscal(n, weight1, syn1_1, 1)
-              blas.saxpy(n, weight2, syn0_2, 1, syn0_1, 1)
-              blas.saxpy(n, weight2, syn1_2, 1, syn1_1, 1)
-              (syn0_1, syn1_1, lwc_1 + lwc_2, wc_1 + wc_2)
-          })
+        }
+        Iterator(model)
+      }
+      val (aggSyn0, aggSyn1, _, _) =
+        partial.treeReduce { case ((syn0_1, syn1_1, lwc_1, wc_1), (syn0_2, syn1_2, lwc_2, wc_2)) =>
+          val n = syn0_1.length
+          val weight1 = 1.0f * wc_1 / (wc_1 + wc_2)
+          val weight2 = 1.0f * wc_2 / (wc_1 + wc_2)
+          blas.sscal(n, weight1, syn0_1, 1)
+          blas.sscal(n, weight1, syn1_1, 1)
+          blas.saxpy(n, weight2, syn0_2, 1, syn0_1, 1)
+          blas.saxpy(n, weight2, syn1_2, 1, syn1_1, 1)
+          (syn0_1, syn1_1, lwc_1 + lwc_2, wc_1 + wc_2)
+        }
       syn0Global = aggSyn0
       syn1Global = aggSyn1
     }
     newSentences.unpersist()

-    val wordMap = new Array[(String, Array[Float])](vocabSize)
+    val word2VecMap = mutable.HashMap.empty[String, Array[Float]]
     var i = 0
     while (i < vocabSize) {
       val word = bcVocab.value(i).word
       val vector = new Array[Float](layer1Size)
       Array.copy(syn0Global, i * layer1Size, vector, 0, layer1Size)
-      wordMap(i) = (word, vector)
+      word2VecMap += word -> vector
       i += 1
     }
-    val modelRDD = sc.parallelize(wordMap, modelPartitionNum)
-      .partitionBy(new HashPartitioner(modelPartitionNum))
-      .persist(StorageLevel.MEMORY_AND_DISK)
-
-    new Word2VecModel(modelRDD)
+
+    new Word2VecModel(word2VecMap.toMap)
   }
 }

 /**
  * Word2Vec model
-*/
-class Word2VecModel(private val model: RDD[(String, Array[Float])]) extends Serializable {
+ */
+class Word2VecModel private[mllib] (
+    private val model: Map[String, Array[Float]]) extends Serializable {

   private def cosineSimilarity(v1: Array[Float], v2: Array[Float]): Double = {
     require(v1.length == v2.length, "Vectors should have the same length")
@@ -357,11 +392,12 @@ class Word2VecModel(private val model: RDD[(String, Array[Float])]) extends Seri
    * @return vector representation of word
    */
   def transform(word: String): Vector = {
-    val result = model.lookup(word)
-    if (result.isEmpty) {
-      throw new IllegalStateException(s"$word not in vocabulary")
+    model.get(word) match {
+      case Some(vec) =>
+        Vectors.dense(vec.map(_.toDouble))
+      case None =>
+        throw new IllegalStateException(s"$word not in vocabulary")
     }
-    else Vectors.dense(result(0).map(_.toDouble))
   }

   /**
@@ -392,33 +428,13 @@ class Word2VecModel(private val model: RDD[(String, Array[Float])]) extends Seri
    */
   def findSynonyms(vector: Vector, num: Int): Array[(String, Double)] = {
     require(num > 0, "Number of similar words should > 0")
-    val topK = model.map { case(w, vec) =>
-      (cosineSimilarity(vector.toArray.map(_.toFloat), vec), w) }
-      .sortByKey(ascending = false)
-      .take(num + 1)
-      .map(_.swap)
-      .tail
-
-    topK
-  }
-}
-
-object Word2Vec{
-  /**
-   * Train Word2Vec model
-   * @param input RDD of words
-   * @param size vector dimension
-   * @param startingAlpha initial learning rate
-   * @param parallelism number of partitions to run Word2Vec (using a small number for accuracy)
-   * @param numIterations number of iterations, should be smaller than or equal to parallelism
-   * @return Word2Vec model
-   */
-  def train[S <: Iterable[String]](
-      input: RDD[S],
-      size: Int,
-      startingAlpha: Double,
-      parallelism: Int = 1,
-      numIterations:Int = 1): Word2VecModel = {
-    new Word2Vec(size,startingAlpha, parallelism, numIterations).fit[S](input)
+    // TODO: optimize top-k
+    val fVector = vector.toArray.map(_.toFloat)
+    model.mapValues(vec => cosineSimilarity(fVector, vec))
+      .toSeq
+      .sortBy(- _._2)
+      .take(num + 1)
+      .tail
+      .toArray
   }
 }
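With the constructor parameters replaced by setters, callers now configure training fluently. A usage sketch under the new API (the toy corpus is illustrative only; real input needs enough word occurrences to build a vocabulary):

```scala
import org.apache.spark.mllib.feature.Word2Vec

// sc is an existing SparkContext; each element is a pre-tokenized sentence.
val corpus = sc.parallelize(Seq.fill(1000)("a b c a b c".split(" ").toSeq))

val model = new Word2Vec()
  .setVectorSize(10)     // default: 100
  .setLearningRate(0.025)
  .setNumPartitions(2)   // keep small for accuracy
  .setNumIterations(1)   // should be <= number of partitions
  .setSeed(42L)          // fixed seed for reproducible runs
  .fit(corpus)

val synonyms: Array[(String, Double)] = model.findSynonyms("a", 2)
```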
words should > 0") - val topK = model.map { case(w, vec) => - (cosineSimilarity(vector.toArray.map(_.toFloat), vec), w) } - .sortByKey(ascending = false) - .take(num + 1) - .map(_.swap) - .tail - - topK - } -} - -object Word2Vec{ - /** - * Train Word2Vec model - * @param input RDD of words - * @param size vector dimension - * @param startingAlpha initial learning rate - * @param parallelism number of partitions to run Word2Vec (using a small number for accuracy) - * @param numIterations number of iterations, should be smaller than or equal to parallelism - * @return Word2Vec model - */ - def train[S <: Iterable[String]]( - input: RDD[S], - size: Int, - startingAlpha: Double, - parallelism: Int = 1, - numIterations:Int = 1): Word2VecModel = { - new Word2Vec(size,startingAlpha, parallelism, numIterations).fit[S](input) + // TODO: optimize top-k + val fVector = vector.toArray.map(_.toFloat) + model.mapValues(vec => cosineSimilarity(fVector, vec)) + .toSeq + .sortBy(- _._2) + .take(num + 1) + .tail + .toArray } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala index b5db39b68a223..e34335d89eb75 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala @@ -30,29 +30,22 @@ class Word2VecSuite extends FunSuite with LocalSparkContext { val localDoc = Seq(sentence, sentence) val doc = sc.parallelize(localDoc) .map(line => line.split(" ").toSeq) - val size = 10 - val startingAlpha = 0.025 - val window = 2 - val minCount = 2 - val num = 2 - - val model = Word2Vec.train(doc, size, startingAlpha) + val model = new Word2Vec().setVectorSize(10).setSeed(42L).fit(doc) val syms = model.findSynonyms("a", 2) - assert(syms.length == num) + assert(syms.length == 2) assert(syms(0)._1 == "b") assert(syms(1)._1 == "c") } - test("Word2VecModel") { val num = 2 - val localModel = Seq( + val word2VecMap = Map( ("china", Array(0.50f, 0.50f, 0.50f, 0.50f)), ("japan", Array(0.40f, 0.50f, 0.50f, 0.50f)), ("taiwan", Array(0.60f, 0.50f, 0.50f, 0.50f)), ("korea", Array(0.45f, 0.60f, 0.60f, 0.60f)) ) - val model = new Word2VecModel(sc.parallelize(localModel, 2)) + val model = new Word2VecModel(word2VecMap) val syms = model.findSynonyms("china", num) assert(syms.length == num) assert(syms(0)._1 == "taiwan") diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py index 2bbb9c3fca315..5ec1a8084d269 100644 --- a/python/pyspark/mllib/classification.py +++ b/python/pyspark/mllib/classification.py @@ -73,11 +73,36 @@ def predict(self, x): class LogisticRegressionWithSGD(object): @classmethod - def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, initialWeights=None): - """Train a logistic regression model on the given data.""" + def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, + initialWeights=None, regParam=1.0, regType=None, intercept=False): + """ + Train a logistic regression model on the given data. + + @param data: The training data. + @param iterations: The number of iterations (default: 100). + @param step: The step parameter used in SGD + (default: 1.0). + @param miniBatchFraction: Fraction of data to be used for each SGD + iteration. + @param initialWeights: The initial weights (default: None). + @param regParam: The regularizer parameter (default: 1.0). + @param regType: The type of regularizer used for training + our model. 
+ Allowed values: "l1" for using L1Updater, + "l2" for using + SquaredL2Updater, + "none" for no regularizer. + (default: "none") + @param intercept: Boolean parameter which indicates the use + or not of the augmented representation for + training data (i.e. whether bias features + are activated or not). + """ sc = data.context + if regType is None: + regType = "none" train_func = lambda d, i: sc._jvm.PythonMLLibAPI().trainLogisticRegressionModelWithSGD( - d._jrdd, iterations, step, miniBatchFraction, i) + d._jrdd, iterations, step, miniBatchFraction, i, regParam, regType, intercept) return _regression_train_wrapper(sc, train_func, LogisticRegressionModel, data, initialWeights) @@ -115,11 +140,35 @@ def predict(self, x): class SVMWithSGD(object): @classmethod def train(cls, data, iterations=100, step=1.0, regParam=1.0, - miniBatchFraction=1.0, initialWeights=None): - """Train a support vector machine on the given data.""" + miniBatchFraction=1.0, initialWeights=None, regType=None, intercept=False): + """ + Train a support vector machine on the given data. + + @param data: The training data. + @param iterations: The number of iterations (default: 100). + @param step: The step parameter used in SGD + (default: 1.0). + @param regParam: The regularizer parameter (default: 1.0). + @param miniBatchFraction: Fraction of data to be used for each SGD + iteration. + @param initialWeights: The initial weights (default: None). + @param regType: The type of regularizer used for training + our model. + Allowed values: "l1" for using L1Updater, + "l2" for using + SquaredL2Updater, + "none" for no regularizer. + (default: "none") + @param intercept: Boolean parameter which indicates the use + or not of the augmented representation for + training data (i.e. whether bias features + are activated or not). + """ sc = data.context + if regType is None: + regType = "none" train_func = lambda d, i: sc._jvm.PythonMLLibAPI().trainSVMModelWithSGD( - d._jrdd, iterations, step, regParam, miniBatchFraction, i) + d._jrdd, iterations, step, regParam, miniBatchFraction, i, regType, intercept) return _regression_train_wrapper(sc, train_func, SVMModel, data, initialWeights) diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index 1a829c6fafe03..f1093701ddc89 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -672,12 +672,12 @@ def _infer_schema_type(obj, dataType): ByteType: (int, long), ShortType: (int, long), IntegerType: (int, long), - LongType: (int, long), + LongType: (long,), FloatType: (float,), DoubleType: (float,), DecimalType: (decimal.Decimal,), StringType: (str, unicode), - TimestampType: (datetime.datetime, datetime.time, datetime.date), + TimestampType: (datetime.datetime,), ArrayType: (list, tuple, array), MapType: (dict,), StructType: (tuple, list), @@ -1042,12 +1042,15 @@ def applySchema(self, rdd, schema): [Row(field1=1, field2=u'row1'),..., Row(field1=3, field2=u'row3')] >>> from datetime import datetime - >>> rdd = sc.parallelize([(127, -32768, 1.0, + >>> rdd = sc.parallelize([(127, -128L, -32768, 32767, 2147483647L, 1.0, ... datetime(2010, 1, 1, 1, 1, 1), ... {"a": 1}, (2,), [1, 2, 3], None)]) >>> schema = StructType([ - ... StructField("byte", ByteType(), False), - ... StructField("short", ShortType(), False), + ... StructField("byte1", ByteType(), False), + ... StructField("byte2", ByteType(), False), + ... StructField("short1", ShortType(), False), + ... StructField("short2", ShortType(), False), + ... StructField("int", IntegerType(), False), ... 
StructField("float", FloatType(), False), ... StructField("time", TimestampType(), False), ... StructField("map", @@ -1056,11 +1059,19 @@ def applySchema(self, rdd, schema): ... StructType([StructField("b", ShortType(), False)]), False), ... StructField("list", ArrayType(ByteType(), False), False), ... StructField("null", DoubleType(), True)]) - >>> srdd = sqlCtx.applySchema(rdd, schema).map( - ... lambda x: (x.byte, x.short, x.float, x.time, + >>> srdd = sqlCtx.applySchema(rdd, schema) + >>> results = srdd.map( + ... lambda x: (x.byte1, x.byte2, x.short1, x.short2, x.int, x.float, x.time, ... x.map["a"], x.struct.b, x.list, x.null)) - >>> srdd.collect()[0] - (127, -32768, 1.0, ...(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None) + >>> results.collect()[0] + (127, -128, -32768, 32767, 2147483647, 1.0, ...(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None) + + >>> srdd.registerTempTable("table2") + >>> sqlCtx.sql( + ... "SELECT byte1 - 1 AS byte1, byte2 + 1 AS byte2, " + + ... "short1 + 1 AS short1, short2 - 1 AS short2, int - 1 AS int, " + + ... "float + 1.1 as float FROM table2").collect() + [Row(byte1=126, byte2=-127, short1=-32767, short2=32766, int=2147483646, float=2.1)] >>> rdd = sc.parallelize([(127, -32768, 1.0, ... datetime(2010, 1, 1, 1, 1, 1), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 2ba68cab115fb..0293d578b0b92 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -48,6 +48,7 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool Batch("Resolution", fixedPoint, ResolveReferences :: ResolveRelations :: + ResolveSortReferences :: NewRelationInstances :: ImplicitGenerate :: StarExpansion :: @@ -113,13 +114,58 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool q transformExpressions { case u @ UnresolvedAttribute(name) => // Leave unchanged if resolution fails. Hopefully will be resolved next round. - val result = q.resolve(name).getOrElse(u) + val result = q.resolveChildren(name).getOrElse(u) logDebug(s"Resolving $u to $result") result } } } + /** + * In many dialects of SQL is it valid to sort by attributes that are not present in the SELECT + * clause. This rule detects such queries and adds the required attributes to the original + * projection, so that they will be available during sorting. Another projection is added to + * remove these attributes after sorting. + */ + object ResolveSortReferences extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + case s @ Sort(ordering, p @ Project(projectList, child)) if !s.resolved && p.resolved => + val unresolved = ordering.flatMap(_.collect { case UnresolvedAttribute(name) => name }) + val resolved = unresolved.flatMap(child.resolveChildren) + val requiredAttributes = resolved.collect { case a: Attribute => a }.toSet + + val missingInProject = requiredAttributes -- p.output + if (missingInProject.nonEmpty) { + // Add missing attributes and then project them away after the sort. + Project(projectList, + Sort(ordering, + Project(projectList ++ missingInProject, child))) + } else { + s // Nothing we can do here. Return original plan. 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 2ba68cab115fb..0293d578b0b92 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -48,6 +48,7 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
       Batch("Resolution", fixedPoint,
         ResolveReferences ::
         ResolveRelations ::
+        ResolveSortReferences ::
         NewRelationInstances ::
         ImplicitGenerate ::
         StarExpansion ::
@@ -113,13 +114,58 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
       q transformExpressions {
         case u @ UnresolvedAttribute(name) =>
           // Leave unchanged if resolution fails. Hopefully will be resolved next round.
-          val result = q.resolve(name).getOrElse(u)
+          val result = q.resolveChildren(name).getOrElse(u)
           logDebug(s"Resolving $u to $result")
           result
       }
     }
   }

+  /**
+   * In many dialects of SQL it is valid to sort by attributes that are not present in the SELECT
+   * clause. This rule detects such queries and adds the required attributes to the original
+   * projection, so that they will be available during sorting. Another projection is added to
+   * remove these attributes after sorting.
+   */
+  object ResolveSortReferences extends Rule[LogicalPlan] {
+    def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
+      case s @ Sort(ordering, p @ Project(projectList, child)) if !s.resolved && p.resolved =>
+        val unresolved = ordering.flatMap(_.collect { case UnresolvedAttribute(name) => name })
+        val resolved = unresolved.flatMap(child.resolveChildren)
+        val requiredAttributes = resolved.collect { case a: Attribute => a }.toSet
+
+        val missingInProject = requiredAttributes -- p.output
+        if (missingInProject.nonEmpty) {
+          // Add missing attributes and then project them away after the sort.
+          Project(projectList,
+            Sort(ordering,
+              Project(projectList ++ missingInProject, child)))
+        } else {
+          s // Nothing we can do here. Return original plan.
+        }
+      case s @ Sort(ordering, a @ Aggregate(grouping, aggs, child)) if !s.resolved && a.resolved =>
+        val unresolved = ordering.flatMap(_.collect { case UnresolvedAttribute(name) => name })
+        // A small hack to create an object that will allow us to resolve any references that
+        // refer to named expressions that are present in the grouping expressions.
+        val groupingRelation = LocalRelation(
+          grouping.collect { case ne: NamedExpression => ne.toAttribute }
+        )
+
+        logWarning(s"Grouping expressions: $groupingRelation")
+        val resolved = unresolved.flatMap(groupingRelation.resolve).toSet
+        val missingInAggs = resolved -- a.outputSet
+        logWarning(s"Resolved: $resolved Missing in aggs: $missingInAggs")
+        if (missingInAggs.nonEmpty) {
+          // Add missing grouping exprs and then project them away after the sort.
+          Project(a.output,
+            Sort(ordering,
+              Aggregate(grouping, aggs ++ missingInAggs, child)))
+        } else {
+          s // Nothing we can do here. Return original plan.
+        }
+    }
+  }
+
   /**
    * Replaces [[UnresolvedFunction]]s with concrete [[catalyst.expressions.Expression Expressions]].
    */
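The net effect of the rule, using the TestHive helpers that the new SQLQuerySuite at the end of this patch exercises:

```scala
import org.apache.spark.sql.hive.test.TestHive._

// Ordering by a column the SELECT clause drops now analyzes cleanly.
// Conceptually the rule rewrites
//   Sort(value, Project(key, src))
// into
//   Project(key, Sort(value, Project(key, value, src)))
// so `value` exists during the sort and is projected away afterwards.
val sorted = sql("SELECT key FROM src ORDER BY value").collect()
```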
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index 888cb08e95f06..278569f0cb14a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -72,16 +72,29 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] {
   def childrenResolved: Boolean = !children.exists(!_.resolved)

   /**
-   * Optionally resolves the given string to a [[NamedExpression]]. The attribute is expressed as
+   * Optionally resolves the given string to a [[NamedExpression]] using the input from all child
+   * nodes of this LogicalPlan. The attribute is expressed as
    * as string in the following form: `[scope].AttributeName.[nested].[fields]...`.
    */
-  def resolve(name: String): Option[NamedExpression] = {
+  def resolveChildren(name: String): Option[NamedExpression] =
+    resolve(name, children.flatMap(_.output))
+
+  /**
+   * Optionally resolves the given string to a [[NamedExpression]] based on the output of this
+   * LogicalPlan. The attribute is expressed as string in the following form:
+   * `[scope].AttributeName.[nested].[fields]...`.
+   */
+  def resolve(name: String): Option[NamedExpression] =
+    resolve(name, output)
+
+  /** Performs attribute resolution given a name and a sequence of possible attributes. */
+  protected def resolve(name: String, input: Seq[Attribute]): Option[NamedExpression] = {
     val parts = name.split("\\.")
     // Collect all attributes that are output by this nodes children where either the first part
     // matches the name or where the first part matches the scope and the second part matches the
     // name. Return these matches along with any remaining parts, which represent dotted access to
     // struct fields.
-    val options = children.flatMap(_.output).flatMap { option =>
+    val options = input.flatMap { option =>
       // If the first part of the desired name matches a qualifier for this possible match, drop it.
       val remainingParts =
         if (option.qualifiers.contains(parts.head) && parts.size > 1) parts.drop(1) else parts
@@ -89,15 +102,15 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] {
     }

     options.distinct match {
-      case (a, Nil) :: Nil => Some(a) // One match, no nested fields, use it.
+      case Seq((a, Nil)) => Some(a) // One match, no nested fields, use it.
       // One match, but we also need to extract the requested nested field.
-      case (a, nestedFields) :: Nil =>
+      case Seq((a, nestedFields)) =>
         a.dataType match {
           case StructType(fields) =>
             Some(Alias(nestedFields.foldLeft(a: Expression)(GetField), nestedFields.last)())
           case _ => None // Don't know how to resolve these field references
         }
-      case Nil => None // No matches.
+      case Seq() => None // No matches.
       case ambiguousReferences =>
         throw new TreeNodeException(
           this, s"Ambiguous references to $name: ${ambiguousReferences.mkString(",")}")
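A simplified sketch of the name matching that the shared resolve(name, input) helper performs (illustrative only; Catalyst's version also handles nested struct fields and ambiguity errors):

```scala
// "t.key" drops the leading part when it matches an attribute qualifier,
// then compares what remains against the attribute name.
def matches(name: String, attrName: String, qualifiers: Seq[String]): Boolean = {
  val parts = name.split("\\.")
  val remainingParts =
    if (qualifiers.contains(parts.head) && parts.size > 1) parts.drop(1) else parts
  remainingParts.head == attrName
}

assert(matches("t.key", "key", Seq("t")))  // qualified reference
assert(matches("key", "key", Seq("t")))    // bare reference
assert(!matches("u.key", "key", Seq("t"))) // wrong qualifier does not match
```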
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index ecd5fbaa0b094..71d338d21d0f2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -491,7 +491,10 @@ class SQLContext(@transient val sparkContext: SparkContext)
         new java.sql.Timestamp(c.getTime().getTime())

       case (c: Int, ByteType) => c.toByte
+      case (c: Long, ByteType) => c.toByte
       case (c: Int, ShortType) => c.toShort
+      case (c: Long, ShortType) => c.toShort
+      case (c: Long, IntegerType) => c.toInt
       case (c: Double, FloatType) => c.toFloat
       case (c, StringType) if !c.isInstanceOf[String] => c.toString
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala
index c416a745739b3..7e7bb2859bbcd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala
@@ -118,7 +118,7 @@ private[sql] class BinaryColumnBuilder extends ComplexColumnBuilder(BINARY)
 private[sql] class GenericColumnBuilder extends ComplexColumnBuilder(GENERIC)

 private[sql] object ColumnBuilder {
-  val DEFAULT_INITIAL_BUFFER_SIZE = 10 * 1024 * 104
+  val DEFAULT_INITIAL_BUFFER_SIZE = 1024 * 1024

   private[columnar] def ensureFreeSpace(orig: ByteBuffer, size: Int) = {
     if (orig.remaining >= size) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala
index d008806eedbe1..f631ee76fcd78 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala
@@ -36,9 +36,9 @@ import org.apache.spark.sql.Row
  * }}}
  */
 private[sql] trait NullableColumnBuilder extends ColumnBuilder {
-  private var nulls: ByteBuffer = _
+  protected var nulls: ByteBuffer = _
+  protected var nullCount: Int = _
   private var pos: Int = _
-  private var nullCount: Int = _

   abstract override def initialize(initialSize: Int, columnName: String, useCompression: Boolean) {
     nulls = ByteBuffer.allocate(1024)
@@ -78,4 +78,9 @@ private[sql] trait NullableColumnBuilder extends ColumnBuilder {
     buffer.rewind()
     buffer
   }
+
+  protected def buildNonNulls(): ByteBuffer = {
+    nulls.limit(nulls.position()).rewind()
+    super.build()
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala
index 6ad12a0dcb64d..a5826bb033e41 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala
@@ -46,8 +46,6 @@ private[sql] trait CompressibleColumnBuilder[T <: NativeType]

   this: NativeColumnBuilder[T] with WithCompressionSchemes =>

-  import CompressionScheme._
-
   var compressionEncoders: Seq[Encoder[T]] = _

   abstract override def initialize(initialSize: Int, columnName: String, useCompression: Boolean) {
@@ -81,28 +79,32 @@ private[sql] trait CompressibleColumnBuilder[T <: NativeType]
     }
   }

-  abstract override def build() = {
-    val rawBuffer = super.build()
+  override def build() = {
+    val nonNullBuffer = buildNonNulls()
+    val typeId = nonNullBuffer.getInt()
     val encoder: Encoder[T] = {
       val candidate = compressionEncoders.minBy(_.compressionRatio)
       if (isWorthCompressing(candidate)) candidate else PassThrough.encoder
     }

-    val headerSize = columnHeaderSize(rawBuffer)
+    // Header = column type ID + null count + null positions
+    val headerSize = 4 + 4 + nulls.limit()
     val compressedSize = if (encoder.compressedSize == 0) {
-      rawBuffer.limit - headerSize
+      nonNullBuffer.remaining()
     } else {
       encoder.compressedSize
     }

-    // Reserves 4 bytes for compression scheme ID
     val compressedBuffer = ByteBuffer
+      // Reserves 4 bytes for compression scheme ID
       .allocate(headerSize + 4 + compressedSize)
       .order(ByteOrder.nativeOrder)
-
-    copyColumnHeader(rawBuffer, compressedBuffer)
+      // Write the header
+      .putInt(typeId)
+      .putInt(nullCount)
+      .put(nulls)

     logInfo(s"Compressor for [$columnName]: $encoder, ratio: ${encoder.compressionRatio}")
-    encoder.compress(rawBuffer, compressedBuffer, columnType)
+    encoder.compress(nonNullBuffer, compressedBuffer, columnType)
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala
index ba1810dd2ae66..7797f75177893 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala
@@ -67,22 +67,6 @@ private[sql] object CompressionScheme {
       s"Unrecognized compression scheme type ID: $typeId"))
   }

-  def copyColumnHeader(from: ByteBuffer, to: ByteBuffer) {
-    // Writes column type ID
-    to.putInt(from.getInt())
-
-    // Writes null count
-    val nullCount = from.getInt()
-    to.putInt(nullCount)
-
-    // Writes null positions
-    var i = 0
-    while (i < nullCount) {
-      to.putInt(from.getInt())
-      i += 1
-    }
-  }
-
   def columnHeaderSize(columnBuffer: ByteBuffer): Int = {
     val header = columnBuffer.duplicate().order(ByteOrder.nativeOrder)
     val nullCount = header.getInt(4)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala
index 6d688ea95cfc0..72c19fa31d980 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala
@@ -42,4 +42,3 @@ object TestCompressibleColumnBuilder {
     builder
   }
 }
-
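The rewritten build() and the new buildNonNulls() agree on a fixed header layout; a small sketch of the resulting byte budget (the helper name is illustrative):

```scala
// Compressed column layout written by CompressibleColumnBuilder.build():
//   [column type ID: 4 bytes][null count: 4 bytes][null positions: 4 bytes each]
//   [compression scheme ID: 4 bytes][compressed non-null data]
def headerSize(nullCount: Int): Int = 4 + 4 + 4 * nullCount

// A column with 3 nulls spends 16 bytes on the null header plus 4 more on
// the scheme ID before any data, matching `headerSize = 4 + 4 + nulls.limit()`
// above, since each null position is one Int.
assert(headerSize(3) == 16)
```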
diff --git a/sql/hive-thriftserver/pom.xml b/sql/hive-thriftserver/pom.xml
index 7fac90fdc596d..c6f60c18804a4 100644
--- a/sql/hive-thriftserver/pom.xml
+++ b/sql/hive-thriftserver/pom.xml
@@ -29,7 +29,7 @@
   <groupId>org.apache.spark</groupId>
   <artifactId>spark-hive-thriftserver_2.10</artifactId>
   <packaging>jar</packaging>
-  <name>Spark Project Hive</name>
+  <name>Spark Project Hive Thrift Server</name>
   <url>http://spark.apache.org/</url>
   <properties>
     <sbt.project.name>hive-thriftserver</sbt.project.name>
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
new file mode 100644
index 0000000000000..635a9fb0d56cb
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.execution
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.sql.{SQLConf, QueryTest}
+import org.apache.spark.sql.execution.{BroadcastHashJoin, ShuffledHashJoin}
+import org.apache.spark.sql.hive.test.TestHive
+import org.apache.spark.sql.hive.test.TestHive._
+
+/**
+ * A collection of hive query tests where we generate the answers ourselves instead of depending on
+ * Hive to generate them (in contrast to HiveQuerySuite).  Often this is because the query is
+ * valid, but Hive currently cannot execute it.
+ */
+class SQLQuerySuite extends QueryTest {
+  test("ordering not in select") {
+    checkAnswer(
+      sql("SELECT key FROM src ORDER BY value"),
+      sql("SELECT key FROM (SELECT key, value FROM src ORDER BY value) a").collect().toSeq)
+  }
+
+  test("ordering not in agg") {
+    checkAnswer(
+      sql("SELECT key FROM src GROUP BY key, value ORDER BY value"),
+      sql("""
+        SELECT key
+        FROM (
+          SELECT key, value
+          FROM src
+          GROUP BY key, value
+          ORDER BY value) a""").collect().toSeq)
+  }
+}