From 103d8cce78533b38b4f8060b30f7f455113bc6b5 Mon Sep 17 00:00:00 2001 From: Bimal Tandel Date: Wed, 29 Jul 2015 16:54:58 -0700 Subject: [PATCH 01/50] [SPARK-8921] [MLLIB] Add @since tags to mllib.stat Author: Bimal Tandel Closes #7730 from BimalTandel/branch_spark_8921 and squashes the following commits: 3ea230a [Bimal Tandel] Spark 8921 add @since tags --- .../spark/mllib/stat/KernelDensity.scala | 5 ++++ .../stat/MultivariateOnlineSummarizer.scala | 27 +++++++++++++++++++ .../stat/MultivariateStatisticalSummary.scala | 9 +++++++ .../apache/spark/mllib/stat/Statistics.scala | 20 ++++++++++++-- .../distribution/MultivariateGaussian.scala | 9 +++++-- 5 files changed, 66 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/KernelDensity.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/KernelDensity.scala index 58a50f9c19f14..93a6753efd4d9 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/KernelDensity.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/KernelDensity.scala @@ -37,6 +37,7 @@ import org.apache.spark.rdd.RDD * .setBandwidth(3.0) * val densities = kd.estimate(Array(-1.0, 2.0, 5.0)) * }}} + * @since 1.4.0 */ @Experimental class KernelDensity extends Serializable { @@ -51,6 +52,7 @@ class KernelDensity extends Serializable { /** * Sets the bandwidth (standard deviation) of the Gaussian kernel (default: `1.0`). + * @since 1.4.0 */ def setBandwidth(bandwidth: Double): this.type = { require(bandwidth > 0, s"Bandwidth must be positive, but got $bandwidth.") @@ -60,6 +62,7 @@ class KernelDensity extends Serializable { /** * Sets the sample to use for density estimation. + * @since 1.4.0 */ def setSample(sample: RDD[Double]): this.type = { this.sample = sample @@ -68,6 +71,7 @@ class KernelDensity extends Serializable { /** * Sets the sample to use for density estimation (for Java users). + * @since 1.4.0 */ def setSample(sample: JavaRDD[java.lang.Double]): this.type = { this.sample = sample.rdd.asInstanceOf[RDD[Double]] @@ -76,6 +80,7 @@ class KernelDensity extends Serializable { /** * Estimates probability density function at the given array of points. + * @since 1.4.0 */ def estimate(points: Array[Double]): Array[Double] = { val sample = this.sample diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala index d321cc554c1cc..62da9f2ef22a3 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala @@ -33,6 +33,7 @@ import org.apache.spark.mllib.linalg.{Vectors, Vector} * Reference: [[http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance variance-wiki]] * Zero elements (including explicit zero values) are skipped when calling add(), * to have time complexity O(nnz) instead of O(n) for each column. + * @since 1.1.0 */ @DeveloperApi class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with Serializable { @@ -52,6 +53,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S * * @param sample The sample in dense/sparse vector format to be added into this summarizer. * @return This MultivariateOnlineSummarizer object. 
+ * @since 1.1.0 */ def add(sample: Vector): this.type = { if (n == 0) { @@ -107,6 +109,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S * * @param other The other MultivariateOnlineSummarizer to be merged. * @return This MultivariateOnlineSummarizer object. + * @since 1.1.0 */ def merge(other: MultivariateOnlineSummarizer): this.type = { if (this.totalCnt != 0 && other.totalCnt != 0) { @@ -149,6 +152,9 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S this } + /** + * @since 1.1.0 + */ override def mean: Vector = { require(totalCnt > 0, s"Nothing has been added to this summarizer.") @@ -161,6 +167,9 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S Vectors.dense(realMean) } + /** + * @since 1.1.0 + */ override def variance: Vector = { require(totalCnt > 0, s"Nothing has been added to this summarizer.") @@ -183,14 +192,23 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S Vectors.dense(realVariance) } + /** + * @since 1.1.0 + */ override def count: Long = totalCnt + /** + * @since 1.1.0 + */ override def numNonzeros: Vector = { require(totalCnt > 0, s"Nothing has been added to this summarizer.") Vectors.dense(nnz) } + /** + * @since 1.1.0 + */ override def max: Vector = { require(totalCnt > 0, s"Nothing has been added to this summarizer.") @@ -202,6 +220,9 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S Vectors.dense(currMax) } + /** + * @since 1.1.0 + */ override def min: Vector = { require(totalCnt > 0, s"Nothing has been added to this summarizer.") @@ -213,6 +234,9 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S Vectors.dense(currMin) } + /** + * @since 1.2.0 + */ override def normL2: Vector = { require(totalCnt > 0, s"Nothing has been added to this summarizer.") @@ -227,6 +251,9 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S Vectors.dense(realMagnitude) } + /** + * @since 1.2.0 + */ override def normL1: Vector = { require(totalCnt > 0, s"Nothing has been added to this summarizer.") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.scala index 6a364c93284af..3bb49f12289e1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.scala @@ -21,46 +21,55 @@ import org.apache.spark.mllib.linalg.Vector /** * Trait for multivariate statistical summary of a data matrix. + * @since 1.0.0 */ trait MultivariateStatisticalSummary { /** * Sample mean vector. + * @since 1.0.0 */ def mean: Vector /** * Sample variance vector. Should return a zero vector if the sample size is 1. + * @since 1.0.0 */ def variance: Vector /** * Sample size. + * @since 1.0.0 */ def count: Long /** * Number of nonzero elements (including explicitly presented zero values) in each column. + * @since 1.0.0 */ def numNonzeros: Vector /** * Maximum value of each column. + * @since 1.0.0 */ def max: Vector /** * Minimum value of each column. 
+ * @since 1.0.0 */ def min: Vector /** * Euclidean magnitude of each column + * @since 1.2.0 */ def normL2: Vector /** * L1 norm of each column + * @since 1.2.0 */ def normL1: Vector } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala index 90332028cfb3a..f84502919e381 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala @@ -32,6 +32,7 @@ import org.apache.spark.rdd.RDD /** * :: Experimental :: * API for statistical functions in MLlib. + * @since 1.1.0 */ @Experimental object Statistics { @@ -41,6 +42,7 @@ object Statistics { * * @param X an RDD[Vector] for which column-wise summary statistics are to be computed. * @return [[MultivariateStatisticalSummary]] object containing column-wise summary statistics. + * @since 1.1.0 */ def colStats(X: RDD[Vector]): MultivariateStatisticalSummary = { new RowMatrix(X).computeColumnSummaryStatistics() @@ -52,6 +54,7 @@ object Statistics { * * @param X an RDD[Vector] for which the correlation matrix is to be computed. * @return Pearson correlation matrix comparing columns in X. + * @since 1.1.0 */ def corr(X: RDD[Vector]): Matrix = Correlations.corrMatrix(X) @@ -68,6 +71,7 @@ object Statistics { * @param method String specifying the method to use for computing correlation. * Supported: `pearson` (default), `spearman` * @return Correlation matrix comparing columns in X. + * @since 1.1.0 */ def corr(X: RDD[Vector], method: String): Matrix = Correlations.corrMatrix(X, method) @@ -81,10 +85,14 @@ object Statistics { * @param x RDD[Double] of the same cardinality as y. * @param y RDD[Double] of the same cardinality as x. * @return A Double containing the Pearson correlation between the two input RDD[Double]s + * @since 1.1.0 */ def corr(x: RDD[Double], y: RDD[Double]): Double = Correlations.corr(x, y) - /** Java-friendly version of [[corr()]] */ + /** + * Java-friendly version of [[corr()]] + * @since 1.4.1 + */ def corr(x: JavaRDD[java.lang.Double], y: JavaRDD[java.lang.Double]): Double = corr(x.rdd.asInstanceOf[RDD[Double]], y.rdd.asInstanceOf[RDD[Double]]) @@ -101,10 +109,14 @@ object Statistics { * Supported: `pearson` (default), `spearman` * @return A Double containing the correlation between the two input RDD[Double]s using the * specified method. + * @since 1.1.0 */ def corr(x: RDD[Double], y: RDD[Double], method: String): Double = Correlations.corr(x, y, method) - /** Java-friendly version of [[corr()]] */ + /** + * Java-friendly version of [[corr()]] + * @since 1.4.1 + */ def corr(x: JavaRDD[java.lang.Double], y: JavaRDD[java.lang.Double], method: String): Double = corr(x.rdd.asInstanceOf[RDD[Double]], y.rdd.asInstanceOf[RDD[Double]], method) @@ -121,6 +133,7 @@ object Statistics { * `expected` is rescaled if the `expected` sum differs from the `observed` sum. * @return ChiSquaredTest object containing the test statistic, degrees of freedom, p-value, * the method used, and the null hypothesis. + * @since 1.1.0 */ def chiSqTest(observed: Vector, expected: Vector): ChiSqTestResult = { ChiSqTest.chiSquared(observed, expected) @@ -135,6 +148,7 @@ object Statistics { * @param observed Vector containing the observed categorical counts/relative frequencies. * @return ChiSquaredTest object containing the test statistic, degrees of freedom, p-value, * the method used, and the null hypothesis. 
+ * @since 1.1.0 */ def chiSqTest(observed: Vector): ChiSqTestResult = ChiSqTest.chiSquared(observed) @@ -145,6 +159,7 @@ object Statistics { * @param observed The contingency matrix (containing either counts or relative frequencies). * @return ChiSquaredTest object containing the test statistic, degrees of freedom, p-value, * the method used, and the null hypothesis. + * @since 1.1.0 */ def chiSqTest(observed: Matrix): ChiSqTestResult = ChiSqTest.chiSquaredMatrix(observed) @@ -157,6 +172,7 @@ object Statistics { * Real-valued features will be treated as categorical for each distinct value. * @return an array containing the ChiSquaredTestResult for every feature against the label. * The order of the elements in the returned array reflects the order of input features. + * @since 1.1.0 */ def chiSqTest(data: RDD[LabeledPoint]): Array[ChiSqTestResult] = { ChiSqTest.chiSquaredFeatures(data) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala index cf51b24ff777f..9aa7763d7890d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala @@ -32,6 +32,7 @@ import org.apache.spark.mllib.util.MLUtils * * @param mu The mean vector of the distribution * @param sigma The covariance matrix of the distribution + * @since 1.3.0 */ @DeveloperApi class MultivariateGaussian ( @@ -60,12 +61,16 @@ class MultivariateGaussian ( */ private val (rootSigmaInv: DBM[Double], u: Double) = calculateCovarianceConstants - /** Returns density of this multivariate Gaussian at given point, x */ + /** Returns density of this multivariate Gaussian at given point, x + * @since 1.3.0 + */ def pdf(x: Vector): Double = { pdf(x.toBreeze) } - /** Returns the log-density of this multivariate Gaussian at given point, x */ + /** Returns the log-density of this multivariate Gaussian at given point, x + * @since 1.3.0 + */ def logpdf(x: Vector): Double = { logpdf(x.toBreeze) } From 37c2d1927cebdd19a14c054f670cb0fb9a263586 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Wed, 29 Jul 2015 18:18:29 -0700 Subject: [PATCH 02/50] [SPARK-9016] [ML] make random forest classifiers implement classification trait Implement the classification trait for RandomForestClassifiers. The plan is to use this in the future to providing thresholding for RandomForestClassifiers (as well as other classifiers that implement that trait). 
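For readers skimming the diff below: the core behavioral change is that the random forest model now exposes per-class vote counts through predictRaw instead of returning only a single majority-vote label. A minimal sketch of that voting logic, written against the public mllib linalg API (the standalone helper name and form are illustrative, not part of the patch):

    import org.apache.spark.mllib.linalg.{Vector, Vectors}

    // Accumulate one vote per tree into a class-count vector; the predicted
    // class is then the index with the most votes (argmax of rawPrediction).
    def rawPredictionFromVotes(treePredictions: Seq[Int], numClasses: Int): Vector = {
      val votes = new Array[Double](numClasses)
      treePredictions.foreach(c => votes(c) += 1.0) // each tree currently has weight 1.0
      Vectors.dense(votes)
    }

    // e.g. three trees voting (1, 1, 0) over 2 classes -> [1.0, 2.0] -> class 1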
Author: Holden Karau Closes #7432 from holdenk/SPARK-9016-make-random-forest-classifiers-implement-classification-trait and squashes the following commits: bf22fa6 [Holden Karau] Add missing imports for testing suite e948f0d [Holden Karau] Check the prediction generation from rawprediciton 25320c3 [Holden Karau] Don't supply numClasses when not needed, assert model classes are as expected 1a67e04 [Holden Karau] Use old decission tree stuff instead 673e0c3 [Holden Karau] Merge branch 'master' into SPARK-9016-make-random-forest-classifiers-implement-classification-trait 0d15b96 [Holden Karau] FIx typo 5eafad4 [Holden Karau] add a constructor for rootnode + num classes fc6156f [Holden Karau] scala style fix 2597915 [Holden Karau] take num classes in constructor 3ccfe4a [Holden Karau] Merge in master, make pass numClasses through randomforest for training 222a10b [Holden Karau] Increase numtrees to 3 in the python test since before the two were equal and the argmax was selecting the last one 16aea1c [Holden Karau] Make tests match the new models b454a02 [Holden Karau] Make the Tree classifiers extends the Classifier base class 77b4114 [Holden Karau] Import vectors lib --- .../RandomForestClassifier.scala | 30 ++++++++++--------- .../RandomForestClassifierSuite.scala | 18 ++++++++--- python/pyspark/ml/classification.py | 4 +-- 3 files changed, 32 insertions(+), 20 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index fc0693f67cc2e..bc19bd6df894f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -25,7 +25,7 @@ import org.apache.spark.ml.{PredictionModel, Predictor} import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.tree.{DecisionTreeModel, RandomForestParams, TreeClassifierParams, TreeEnsembleModel} import org.apache.spark.ml.util.{Identifiable, MetadataUtils} -import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel} @@ -43,7 +43,7 @@ import org.apache.spark.sql.types.DoubleType */ @Experimental final class RandomForestClassifier(override val uid: String) - extends Predictor[Vector, RandomForestClassifier, RandomForestClassificationModel] + extends Classifier[Vector, RandomForestClassifier, RandomForestClassificationModel] with RandomForestParams with TreeClassifierParams { def this() = this(Identifiable.randomUID("rfc")) @@ -98,7 +98,7 @@ final class RandomForestClassifier(override val uid: String) val trees = RandomForest.run(oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed) .map(_.asInstanceOf[DecisionTreeClassificationModel]) - new RandomForestClassificationModel(trees) + new RandomForestClassificationModel(trees, numClasses) } override def copy(extra: ParamMap): RandomForestClassifier = defaultCopy(extra) @@ -125,8 +125,9 @@ object RandomForestClassifier { @Experimental final class RandomForestClassificationModel private[ml] ( override val uid: String, - private val _trees: Array[DecisionTreeClassificationModel]) - extends PredictionModel[Vector, RandomForestClassificationModel] + private val _trees: 
Array[DecisionTreeClassificationModel], + override val numClasses: Int) + extends ClassificationModel[Vector, RandomForestClassificationModel] with TreeEnsembleModel with Serializable { require(numTrees > 0, "RandomForestClassificationModel requires at least 1 tree.") @@ -135,8 +136,8 @@ final class RandomForestClassificationModel private[ml] ( * Construct a random forest classification model, with all trees weighted equally. * @param trees Component trees */ - def this(trees: Array[DecisionTreeClassificationModel]) = - this(Identifiable.randomUID("rfc"), trees) + def this(trees: Array[DecisionTreeClassificationModel], numClasses: Int) = + this(Identifiable.randomUID("rfc"), trees, numClasses) override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]] @@ -153,20 +154,20 @@ final class RandomForestClassificationModel private[ml] ( dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) } - override protected def predict(features: Vector): Double = { + override protected def predictRaw(features: Vector): Vector = { // TODO: When we add a generic Bagging class, handle transform there: SPARK-7128 // Classifies using majority votes. // Ignore the weights since all are 1.0 for now. - val votes = mutable.Map.empty[Int, Double] + val votes = new Array[Double](numClasses) _trees.view.foreach { tree => val prediction = tree.rootNode.predict(features).toInt - votes(prediction) = votes.getOrElse(prediction, 0.0) + 1.0 // 1.0 = weight + votes(prediction) = votes(prediction) + 1.0 // 1.0 = weight } - votes.maxBy(_._2)._1 + Vectors.dense(votes) } override def copy(extra: ParamMap): RandomForestClassificationModel = { - copyValues(new RandomForestClassificationModel(uid, _trees), extra) + copyValues(new RandomForestClassificationModel(uid, _trees, numClasses), extra) } override def toString: String = { @@ -185,7 +186,8 @@ private[ml] object RandomForestClassificationModel { def fromOld( oldModel: OldRandomForestModel, parent: RandomForestClassifier, - categoricalFeatures: Map[Int, Int]): RandomForestClassificationModel = { + categoricalFeatures: Map[Int, Int], + numClasses: Int): RandomForestClassificationModel = { require(oldModel.algo == OldAlgo.Classification, "Cannot convert RandomForestModel" + s" with algo=${oldModel.algo} (old API) to RandomForestClassificationModel (new API).") val newTrees = oldModel.trees.map { tree => @@ -193,6 +195,6 @@ private[ml] object RandomForestClassificationModel { DecisionTreeClassificationModel.fromOld(tree, null, categoricalFeatures) } val uid = if (parent != null) parent.uid else Identifiable.randomUID("rfc") - new RandomForestClassificationModel(uid, newTrees) + new RandomForestClassificationModel(uid, newTrees, numClasses) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala index 1b6b69c7dc71e..ab711c8e4b215 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala @@ -21,13 +21,13 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.ml.impl.TreeTests import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.tree.LeafNode -import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import 
org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest} import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Row} /** * Test suite for [[RandomForestClassifier]]. @@ -66,7 +66,7 @@ class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkConte test("params") { ParamsSuite.checkParams(new RandomForestClassifier) val model = new RandomForestClassificationModel("rfc", - Array(new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0)))) + Array(new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0))), 2) ParamsSuite.checkParams(model) } @@ -167,9 +167,19 @@ private object RandomForestClassifierSuite { val newModel = rf.fit(newData) // Use parent from newTree since this is not checked anyways. val oldModelAsNew = RandomForestClassificationModel.fromOld( - oldModel, newModel.parent.asInstanceOf[RandomForestClassifier], categoricalFeatures) + oldModel, newModel.parent.asInstanceOf[RandomForestClassifier], categoricalFeatures, + numClasses) TreeTests.checkEqual(oldModelAsNew, newModel) assert(newModel.hasParent) assert(!newModel.trees.head.asInstanceOf[DecisionTreeClassificationModel].hasParent) + assert(newModel.numClasses == numClasses) + val results = newModel.transform(newData) + results.select("rawPrediction", "prediction").collect().foreach { + case Row(raw: Vector, prediction: Double) => { + assert(raw.size == numClasses) + val predFromRaw = raw.toArray.zipWithIndex.maxBy(_._1)._2 + assert(predFromRaw == prediction) + } + } } } diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 89117e492846b..5a82bc286d1e8 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -299,9 +299,9 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred >>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed") >>> si_model = stringIndexer.fit(df) >>> td = si_model.transform(df) - >>> rf = RandomForestClassifier(numTrees=2, maxDepth=2, labelCol="indexed", seed=42) + >>> rf = RandomForestClassifier(numTrees=3, maxDepth=2, labelCol="indexed", seed=42) >>> model = rf.fit(td) - >>> allclose(model.treeWeights, [1.0, 1.0]) + >>> allclose(model.treeWeights, [1.0, 1.0, 1.0]) True >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"]) >>> model.transform(test0).head().prediction From 2a9fe4a4e7acbe4c9d3b6c6e61ff46d1472ee5f4 Mon Sep 17 00:00:00 2001 From: sethah Date: Wed, 29 Jul 2015 18:23:07 -0700 Subject: [PATCH 03/50] [SPARK-6129] [MLLIB] [DOCS] Added user guide for evaluation metrics Author: sethah Closes #7655 from sethah/Working_on_6129 and squashes the following commits: 253db2d [sethah] removed number formatting from example code b769cab [sethah] rewording threshold section d5dad4d [sethah] adding some explanations of concepts to the eval metrics user guide 3a61ff9 [sethah] Removing unnecessary latex commands from metrics guide c9dd058 [sethah] Cleaning up and formatting metrics user guide section 6f31c21 [sethah] All example code for metrics section done 98813fe [sethah] Most java and python example code added. 
Further latex formatting
53a24fc [sethah] Adding documentations of metrics for ML algorithms to user guide
---
 docs/mllib-evaluation-metrics.md | 1497 ++++++++++++++++++++++++++++++
 docs/mllib-guide.md              |    1 +
 2 files changed, 1498 insertions(+)
 create mode 100644 docs/mllib-evaluation-metrics.md

diff --git a/docs/mllib-evaluation-metrics.md b/docs/mllib-evaluation-metrics.md
new file mode 100644
index 0000000000000..4ca0bb06b26a6
--- /dev/null
+++ b/docs/mllib-evaluation-metrics.md
@@ -0,0 +1,1497 @@
---
layout: global
title: Evaluation Metrics - MLlib
displayTitle: MLlib - Evaluation Metrics
---

* Table of contents
{:toc}

Spark's MLlib comes with a number of machine learning algorithms that can be used to learn from and make predictions
on data. When these algorithms are applied to build machine learning models, the performance of the models needs to be
evaluated on some criteria, which depend on the application and its requirements. Spark's MLlib also provides a
suite of metrics for evaluating the performance of machine learning models.

Specific machine learning algorithms fall under broader types of machine learning applications like classification,
regression, clustering, etc. Each of these types has well-established metrics for performance evaluation, and the
metrics currently available in Spark's MLlib are detailed in this section.

## Classification model evaluation

While there are many different types of classification algorithms, the evaluation of classification models shares
similar principles. In a [supervised classification problem](https://en.wikipedia.org/wiki/Statistical_classification),
there exists a true output and a model-generated predicted output for each data point. For this reason, the result for
each data point can be assigned to one of four categories:

* True Positive (TP) - label is positive and prediction is also positive
* True Negative (TN) - label is negative and prediction is also negative
* False Positive (FP) - label is negative but prediction is positive
* False Negative (FN) - label is positive but prediction is negative

These four numbers are the building blocks for most classifier evaluation metrics. A fundamental point when considering
classifier evaluation is that pure accuracy (i.e. was the prediction correct or incorrect) is not generally a good metric,
because a dataset may be highly unbalanced. For example, if a model is designed to predict fraud from
a dataset where 95% of the data points are _not fraud_ and 5% of the data points are _fraud_, then a naive classifier
that predicts _not fraud_, regardless of input, will be 95% accurate. For this reason, metrics like
[precision and recall](https://en.wikipedia.org/wiki/Precision_and_recall) are typically used because they take into
account the *type* of error. In most applications there is some desired balance between precision and recall, which can
be captured by combining the two into a single metric, called the [F-measure](https://en.wikipedia.org/wiki/F1_score).

### Binary classification

[Binary classifiers](https://en.wikipedia.org/wiki/Binary_classification) are used to separate the elements of a given
dataset into one of two possible groups (e.g. fraud or not fraud); binary classification is a special case of multiclass
classification. Most binary classification metrics can be generalized to multiclass classification metrics.
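As a quick, hypothetical illustration of how these four counts combine into the metrics used throughout this section (the counts below are made up; Spark computes these quantities for you via the metrics classes shown later):

{% highlight scala %}
// Hypothetical counts at one decision threshold.
val (tp, fp, fn) = (90.0, 10.0, 30.0)

val precision = tp / (tp + fp)        // 0.9
val recall    = tp / (tp + fn)        // 0.75

// F-measure with beta = 1 (the harmonic mean of precision and recall).
val beta = 1.0
val fMeasure = (1 + beta * beta) * precision * recall /
  (beta * beta * precision + recall)  // ~0.818
{% endhighlight %}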
#### Threshold tuning

It is important to understand that many classification models actually output a "score" (often a probability) for
each class, where a higher score indicates higher likelihood. In the binary case, the model may output a probability for
each class: $P(Y=1|X)$ and $P(Y=0|X)$. Instead of simply taking the higher probability, there may be some cases where
the model might need to be tuned so that it only predicts a class when the probability is very high (e.g. only block a
credit card transaction if the model predicts fraud with >90% probability). Therefore, there is a prediction *threshold*
which determines what the predicted class will be based on the probabilities that the model outputs.

Tuning the prediction threshold will change the precision and recall of the model and is an important part of model
optimization. To visualize how precision, recall, and other metrics change as a function of the threshold, it is
common practice to plot competing metrics against one another, parameterized by threshold. A P-R curve plots (precision,
recall) points for different threshold values, while a
[receiver operating characteristic](https://en.wikipedia.org/wiki/Receiver_operating_characteristic), or ROC, curve
plots (recall, false positive rate) points.

**Available metrics**
<table class="table">
  <thead>
    <tr><th>Metric</th><th>Definition</th></tr>
  </thead>
  <tbody>
    <tr>
      <td>Precision (Positive Predictive Value)</td>
      <td>$PPV=\frac{TP}{TP + FP}$</td>
    </tr>
    <tr>
      <td>Recall (True Positive Rate)</td>
      <td>$TPR=\frac{TP}{P}=\frac{TP}{TP + FN}$</td>
    </tr>
    <tr>
      <td>F-measure</td>
      <td>$F(\beta) = \left(1 + \beta^2\right) \cdot \left(\frac{PPV \cdot TPR}{\beta^2 \cdot PPV + TPR}\right)$</td>
    </tr>
    <tr>
      <td>Receiver Operating Characteristic (ROC)</td>
      <td>$FPR(T)=\int^\infty_{T} P_0(T)\,dT \\ TPR(T)=\int^\infty_{T} P_1(T)\,dT$</td>
    </tr>
    <tr>
      <td>Area Under ROC Curve</td>
      <td>$AUROC=\int^1_{0} \frac{TP}{P} d\left(\frac{FP}{N}\right)$</td>
    </tr>
    <tr>
      <td>Area Under Precision-Recall Curve</td>
      <td>$AUPRC=\int^1_{0} \frac{TP}{TP+FP} d\left(\frac{TP}{P}\right)$</td>
    </tr>
  </tbody>
</table>
**Examples**
+The following code snippets illustrate how to load a sample dataset, train a binary classification algorithm on the +data, and evaluate the performance of the algorithm by several binary evaluation metrics. + +
{% highlight scala %}
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.MLUtils

// Load training data in LIBSVM format
val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_binary_classification_data.txt")

// Split data into training (60%) and test (40%)
val Array(training, test) = data.randomSplit(Array(0.6, 0.4), seed = 11L)
training.cache()

// Run training algorithm to build the model
val model = new LogisticRegressionWithLBFGS()
  .setNumClasses(2)
  .run(training)

// Clear the prediction threshold so the model will return probabilities
model.clearThreshold

// Compute raw scores on the test set
val predictionAndLabels = test.map { case LabeledPoint(label, features) =>
  val prediction = model.predict(features)
  (prediction, label)
}

// Instantiate metrics object
val metrics = new BinaryClassificationMetrics(predictionAndLabels)

// Precision by threshold
val precision = metrics.precisionByThreshold
precision.foreach { case (t, p) =>
  println(s"Threshold: $t, Precision: $p")
}

// Recall by threshold
val recall = metrics.recallByThreshold
recall.foreach { case (t, r) =>
  println(s"Threshold: $t, Recall: $r")
}

// Precision-Recall Curve
val PRC = metrics.pr

// F-measure
val f1Score = metrics.fMeasureByThreshold
f1Score.foreach { case (t, f) =>
  println(s"Threshold: $t, F-score: $f, Beta = 1")
}

val beta = 0.5
val fScore = metrics.fMeasureByThreshold(beta)
fScore.foreach { case (t, f) =>
  println(s"Threshold: $t, F-score: $f, Beta = 0.5")
}

// AUPRC
val auPRC = metrics.areaUnderPR
println("Area under precision-recall curve = " + auPRC)

// Compute thresholds used in ROC and PR curves
val thresholds = precision.map(_._1)

// ROC Curve
val roc = metrics.roc

// AUROC
val auROC = metrics.areaUnderROC
println("Area under ROC = " + auROC)

{% endhighlight %}
+ +
+ +{% highlight java %} +import scala.Tuple2; + +import org.apache.spark.api.java.*; +import org.apache.spark.rdd.RDD; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.mllib.classification.LogisticRegressionModel; +import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS; +import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.util.MLUtils; +import org.apache.spark.SparkConf; +import org.apache.spark.SparkContext; + +public class BinaryClassification { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("Binary Classification Metrics"); + SparkContext sc = new SparkContext(conf); + String path = "data/mllib/sample_binary_classification_data.txt"; + JavaRDD data = MLUtils.loadLibSVMFile(sc, path).toJavaRDD(); + + // Split initial RDD into two... [60% training data, 40% testing data]. + JavaRDD[] splits = data.randomSplit(new double[] {0.6, 0.4}, 11L); + JavaRDD training = splits[0].cache(); + JavaRDD test = splits[1]; + + // Run training algorithm to build the model. + final LogisticRegressionModel model = new LogisticRegressionWithLBFGS() + .setNumClasses(2) + .run(training.rdd()); + + // Clear the prediction threshold so the model will return probabilities + model.clearThreshold(); + + // Compute raw scores on the test set. + JavaRDD> predictionAndLabels = test.map( + new Function>() { + public Tuple2 call(LabeledPoint p) { + Double prediction = model.predict(p.features()); + return new Tuple2(prediction, p.label()); + } + } + ); + + // Get evaluation metrics. + BinaryClassificationMetrics metrics = new BinaryClassificationMetrics(predictionAndLabels.rdd()); + + // Precision by threshold + JavaRDD> precision = metrics.precisionByThreshold().toJavaRDD(); + System.out.println("Precision by threshold: " + precision.toArray()); + + // Recall by threshold + JavaRDD> recall = metrics.recallByThreshold().toJavaRDD(); + System.out.println("Recall by threshold: " + recall.toArray()); + + // F Score by threshold + JavaRDD> f1Score = metrics.fMeasureByThreshold().toJavaRDD(); + System.out.println("F1 Score by threshold: " + f1Score.toArray()); + + JavaRDD> f2Score = metrics.fMeasureByThreshold(2.0).toJavaRDD(); + System.out.println("F2 Score by threshold: " + f2Score.toArray()); + + // Precision-recall curve + JavaRDD> prc = metrics.pr().toJavaRDD(); + System.out.println("Precision-recall curve: " + prc.toArray()); + + // Thresholds + JavaRDD thresholds = precision.map( + new Function, Double>() { + public Double call (Tuple2 t) { + return new Double(t._1().toString()); + } + } + ); + + // ROC Curve + JavaRDD> roc = metrics.roc().toJavaRDD(); + System.out.println("ROC curve: " + roc.toArray()); + + // AUPRC + System.out.println("Area under precision-recall curve = " + metrics.areaUnderPR()); + + // AUROC + System.out.println("Area under ROC = " + metrics.areaUnderROC()); + + // Save and load model + model.save(sc, "myModelPath"); + LogisticRegressionModel sameModel = LogisticRegressionModel.load(sc, "myModelPath"); + } +} + +{% endhighlight %} + +
+ +
+ +{% highlight python %} +from pyspark.mllib.classification import LogisticRegressionWithLBFGS +from pyspark.mllib.evaluation import BinaryClassificationMetrics +from pyspark.mllib.regression import LabeledPoint +from pyspark.mllib.util import MLUtils + +# Several of the methods available in scala are currently missing from pyspark + +# Load training data in LIBSVM format +data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_binary_classification_data.txt") + +# Split data into training (60%) and test (40%) +training, test = data.randomSplit([0.6, 0.4], seed = 11L) +training.cache() + +# Run training algorithm to build the model +model = LogisticRegressionWithLBFGS.train(training) + +# Compute raw scores on the test set +predictionAndLabels = test.map(lambda lp: (float(model.predict(lp.features)), lp.label)) + +# Instantiate metrics object +metrics = BinaryClassificationMetrics(predictionAndLabels) + +# Area under precision-recall curve +print "Area under PR = %s" % metrics.areaUnderPR + +# Area under ROC curve +print "Area under ROC = %s" % metrics.areaUnderROC + +{% endhighlight %} + +
+
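Not shown in the tabs above: once a suitable operating point has been chosen from these threshold curves, it can be applied back to the Scala model with `setThreshold`. A minimal sketch, reusing the `model` and `test` values from the Scala tab; the 0.9 cut-off is only an example:

{% highlight scala %}
// Predict class 1 only when the predicted probability exceeds 0.9.
model.setThreshold(0.9)

val thresholdedPredictionAndLabels = test.map { case LabeledPoint(label, features) =>
  (model.predict(features), label)
}
{% endhighlight %}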
### Multiclass classification

A [multiclass classification](https://en.wikipedia.org/wiki/Multiclass_classification) describes a classification
problem where there are $M \gt 2$ possible labels for each data point (the case where $M=2$ is the binary
classification problem). An example is classifying handwriting samples as the digits 0 to 9, which has 10 possible classes.

For multiclass metrics, the notion of positives and negatives is slightly different. Predictions and labels can still
be positive or negative, but they must be considered in the context of a particular class. Each label and prediction
take on the value of one of the multiple classes and so they are said to be positive for their particular class and negative
for all other classes. So, a true positive occurs whenever the prediction and the label match, while a true negative
occurs when neither the prediction nor the label take on the value of a given class. By this convention, there can be
multiple true negatives for a given data sample. The extension of false negatives and false positives from the former
definitions of positive and negative labels is straightforward.

#### Label based metrics

As opposed to binary classification, where there are only two possible labels, multiclass classification problems have
many possible labels, and so the concept of label-based metrics is introduced. Overall precision measures precision
across all labels - the number of times any class was predicted correctly (true positives) normalized by the number of
data points. Precision by label considers only one class, and measures the number of times a specific label was predicted
correctly normalized by the number of times that label appears in the output.

**Available metrics**

Define the class, or label, set as

$$L = \{\ell_0, \ell_1, \ldots, \ell_{M-1} \} $$

The true output vector $\mathbf{y}$ consists of $N$ elements

$$\mathbf{y}_0, \mathbf{y}_1, \ldots, \mathbf{y}_{N-1} \in L $$

A multiclass prediction algorithm generates a prediction vector $\hat{\mathbf{y}}$ of $N$ elements

$$\hat{\mathbf{y}}_0, \hat{\mathbf{y}}_1, \ldots, \hat{\mathbf{y}}_{N-1} \in L $$

For this section, a modified delta function $\hat{\delta}(x)$ will prove useful

$$\hat{\delta}(x) = \begin{cases}1 & \text{if $x = 0$}, \\ 0 & \text{otherwise}.\end{cases}$$
<table class="table">
  <thead>
    <tr><th>Metric</th><th>Definition</th></tr>
  </thead>
  <tbody>
    <tr>
      <td>Confusion Matrix</td>
      <td>
        $C_{ij} = \sum_{k=0}^{N-1} \hat{\delta}(\mathbf{y}_k-\ell_i) \cdot \hat{\delta}(\hat{\mathbf{y}}_k - \ell_j)\\ \\
        \left( \begin{array}{ccc}
          \sum_{k=0}^{N-1} \hat{\delta}(\mathbf{y}_k-\ell_1) \cdot \hat{\delta}(\hat{\mathbf{y}}_k - \ell_1) & \ldots &
          \sum_{k=0}^{N-1} \hat{\delta}(\mathbf{y}_k-\ell_1) \cdot \hat{\delta}(\hat{\mathbf{y}}_k - \ell_N) \\
          \vdots & \ddots & \vdots \\
          \sum_{k=0}^{N-1} \hat{\delta}(\mathbf{y}_k-\ell_N) \cdot \hat{\delta}(\hat{\mathbf{y}}_k - \ell_1) & \ldots &
          \sum_{k=0}^{N-1} \hat{\delta}(\mathbf{y}_k-\ell_N) \cdot \hat{\delta}(\hat{\mathbf{y}}_k - \ell_N)
        \end{array} \right)$
      </td>
    </tr>
    <tr>
      <td>Overall Precision</td>
      <td>$PPV = \frac{TP}{TP + FP} = \frac{1}{N}\sum_{i=0}^{N-1} \hat{\delta}\left(\hat{\mathbf{y}}_i - \mathbf{y}_i\right)$</td>
    </tr>
    <tr>
      <td>Overall Recall</td>
      <td>$TPR = \frac{TP}{TP + FN} = \frac{1}{N}\sum_{i=0}^{N-1} \hat{\delta}\left(\hat{\mathbf{y}}_i - \mathbf{y}_i\right)$</td>
    </tr>
    <tr>
      <td>Overall F1-measure</td>
      <td>$F1 = 2 \cdot \left(\frac{PPV \cdot TPR}{PPV + TPR}\right)$</td>
    </tr>
    <tr>
      <td>Precision by label</td>
      <td>$PPV(\ell) = \frac{TP}{TP + FP} =
          \frac{\sum_{i=0}^{N-1} \hat{\delta}(\hat{\mathbf{y}}_i - \ell) \cdot \hat{\delta}(\mathbf{y}_i - \ell)}
          {\sum_{i=0}^{N-1} \hat{\delta}(\hat{\mathbf{y}}_i - \ell)}$</td>
    </tr>
    <tr>
      <td>Recall by label</td>
      <td>$TPR(\ell)=\frac{TP}{P} =
          \frac{\sum_{i=0}^{N-1} \hat{\delta}(\hat{\mathbf{y}}_i - \ell) \cdot \hat{\delta}(\mathbf{y}_i - \ell)}
          {\sum_{i=0}^{N-1} \hat{\delta}(\mathbf{y}_i - \ell)}$</td>
    </tr>
    <tr>
      <td>F-measure by label</td>
      <td>$F(\beta, \ell) = \left(1 + \beta^2\right) \cdot \left(\frac{PPV(\ell) \cdot TPR(\ell)}
          {\beta^2 \cdot PPV(\ell) + TPR(\ell)}\right)$</td>
    </tr>
    <tr>
      <td>Weighted precision</td>
      <td>$PPV_{w}= \frac{1}{N} \sum\nolimits_{\ell \in L} PPV(\ell)
          \cdot \sum_{i=0}^{N-1} \hat{\delta}(\mathbf{y}_i-\ell)$</td>
    </tr>
    <tr>
      <td>Weighted recall</td>
      <td>$TPR_{w}= \frac{1}{N} \sum\nolimits_{\ell \in L} TPR(\ell)
          \cdot \sum_{i=0}^{N-1} \hat{\delta}(\mathbf{y}_i-\ell)$</td>
    </tr>
    <tr>
      <td>Weighted F-measure</td>
      <td>$F_{w}(\beta)= \frac{1}{N} \sum\nolimits_{\ell \in L} F(\beta, \ell)
          \cdot \sum_{i=0}^{N-1} \hat{\delta}(\mathbf{y}_i-\ell)$</td>
    </tr>
  </tbody>
</table>
**Examples**
+The following code snippets illustrate how to load a sample dataset, train a multiclass classification algorithm on +the data, and evaluate the performance of the algorithm by several multiclass classification evaluation metrics. + +
+ +{% highlight scala %} +import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS +import org.apache.spark.mllib.evaluation.MulticlassMetrics +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.util.MLUtils + +// Load training data in LIBSVM format +val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_multiclass_classification_data.txt") + +// Split data into training (60%) and test (40%) +val Array(training, test) = data.randomSplit(Array(0.6, 0.4), seed = 11L) +training.cache() + +// Run training algorithm to build the model +val model = new LogisticRegressionWithLBFGS() + .setNumClasses(3) + .run(training) + +// Compute raw scores on the test set +val predictionAndLabels = test.map { case LabeledPoint(label, features) => + val prediction = model.predict(features) + (prediction, label) +} + +// Instantiate metrics object +val metrics = new MulticlassMetrics(predictionAndLabels) + +// Confusion matrix +println("Confusion matrix:") +println(metrics.confusionMatrix) + +// Overall Statistics +val precision = metrics.precision +val recall = metrics.recall // same as true positive rate +val f1Score = metrics.fMeasure +println("Summary Statistics") +println(s"Precision = $precision") +println(s"Recall = $recall") +println(s"F1 Score = $f1Score") + +// Precision by label +val labels = metrics.labels +labels.foreach { l => + println(s"Precision($l) = " + metrics.precision(l)) +} + +// Recall by label +labels.foreach { l => + println(s"Recall($l) = " + metrics.recall(l)) +} + +// False positive rate by label +labels.foreach { l => + println(s"FPR($l) = " + metrics.falsePositiveRate(l)) +} + +// F-measure by label +labels.foreach { l => + println(s"F1-Score($l) = " + metrics.fMeasure(l)) +} + +// Weighted stats +println(s"Weighted precision: ${metrics.weightedPrecision}") +println(s"Weighted recall: ${metrics.weightedRecall}") +println(s"Weighted F1 score: ${metrics.weightedFMeasure}") +println(s"Weighted false positive rate: ${metrics.weightedFalsePositiveRate}") + +{% endhighlight %} + +
+ +
+ +{% highlight java %} +import scala.Tuple2; + +import org.apache.spark.api.java.*; +import org.apache.spark.rdd.RDD; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.mllib.classification.LogisticRegressionModel; +import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS; +import org.apache.spark.mllib.evaluation.MulticlassMetrics; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.util.MLUtils; +import org.apache.spark.mllib.linalg.Matrix; +import org.apache.spark.SparkConf; +import org.apache.spark.SparkContext; + +public class MulticlassClassification { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("Multiclass Classification Metrics"); + SparkContext sc = new SparkContext(conf); + String path = "data/mllib/sample_multiclass_classification_data.txt"; + JavaRDD data = MLUtils.loadLibSVMFile(sc, path).toJavaRDD(); + + // Split initial RDD into two... [60% training data, 40% testing data]. + JavaRDD[] splits = data.randomSplit(new double[] {0.6, 0.4}, 11L); + JavaRDD training = splits[0].cache(); + JavaRDD test = splits[1]; + + // Run training algorithm to build the model. + final LogisticRegressionModel model = new LogisticRegressionWithLBFGS() + .setNumClasses(3) + .run(training.rdd()); + + // Compute raw scores on the test set. + JavaRDD> predictionAndLabels = test.map( + new Function>() { + public Tuple2 call(LabeledPoint p) { + Double prediction = model.predict(p.features()); + return new Tuple2(prediction, p.label()); + } + } + ); + + // Get evaluation metrics. + MulticlassMetrics metrics = new MulticlassMetrics(predictionAndLabels.rdd()); + + // Confusion matrix + Matrix confusion = metrics.confusionMatrix(); + System.out.println("Confusion matrix: \n" + confusion); + + // Overall statistics + System.out.println("Precision = " + metrics.precision()); + System.out.println("Recall = " + metrics.recall()); + System.out.println("F1 Score = " + metrics.fMeasure()); + + // Stats by labels + for (int i = 0; i < metrics.labels().length; i++) { + System.out.format("Class %f precision = %f\n", metrics.labels()[i], metrics.precision(metrics.labels()[i])); + System.out.format("Class %f recall = %f\n", metrics.labels()[i], metrics.recall(metrics.labels()[i])); + System.out.format("Class %f F1 score = %f\n", metrics.labels()[i], metrics.fMeasure(metrics.labels()[i])); + } + + //Weighted stats + System.out.format("Weighted precision = %f\n", metrics.weightedPrecision()); + System.out.format("Weighted recall = %f\n", metrics.weightedRecall()); + System.out.format("Weighted F1 score = %f\n", metrics.weightedFMeasure()); + System.out.format("Weighted false positive rate = %f\n", metrics.weightedFalsePositiveRate()); + + // Save and load model + model.save(sc, "myModelPath"); + LogisticRegressionModel sameModel = LogisticRegressionModel.load(sc, "myModelPath"); + } +} + +{% endhighlight %} + +
+ +
+ +{% highlight python %} +from pyspark.mllib.classification import LogisticRegressionWithLBFGS +from pyspark.mllib.util import MLUtils +from pyspark.mllib.evaluation import MulticlassMetrics + +# Load training data in LIBSVM format +data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_multiclass_classification_data.txt") + +# Split data into training (60%) and test (40%) +training, test = data.randomSplit([0.6, 0.4], seed = 11L) +training.cache() + +# Run training algorithm to build the model +model = LogisticRegressionWithLBFGS.train(training, numClasses=3) + +# Compute raw scores on the test set +predictionAndLabels = test.map(lambda lp: (float(model.predict(lp.features)), lp.label)) + +# Instantiate metrics object +metrics = MulticlassMetrics(predictionAndLabels) + +# Overall statistics +precision = metrics.precision() +recall = metrics.recall() +f1Score = metrics.fMeasure() +print "Summary Stats" +print "Precision = %s" % precision +print "Recall = %s" % recall +print "F1 Score = %s" % f1Score + +# Statistics by class +labels = data.map(lambda lp: lp.label).distinct().collect() +for label in sorted(labels): + print "Class %s precision = %s" % (label, metrics.precision(label)) + print "Class %s recall = %s" % (label, metrics.recall(label)) + print "Class %s F1 Measure = %s" % (label, metrics.fMeasure(label, beta=1.0)) + +# Weighted stats +print "Weighted recall = %s" % metrics.weightedRecall +print "Weighted precision = %s" % metrics.weightedPrecision +print "Weighted F(1) Score = %s" % metrics.weightedFMeasure() +print "Weighted F(0.5) Score = %s" % metrics.weightedFMeasure(beta=0.5) +print "Weighted false positive rate = %s" % metrics.weightedFalsePositiveRate +{% endhighlight %} + +
+
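As a small sanity check on the weighted metrics defined above, here is a sketch that reuses the `predictionAndLabels` and `metrics` values from the Scala tab and recomputes weighted precision as a label-frequency-weighted average of the per-label precisions:

{% highlight scala %}
// Count how often each true label occurs.
val labelCounts = predictionAndLabels.map(_._2).countByValue()
val total = labelCounts.values.sum.toDouble

// Weighted precision: per-label precision weighted by label frequency.
val weightedPrecisionByHand = metrics.labels.map { l =>
  metrics.precision(l) * labelCounts.getOrElse(l, 0L) / total
}.sum

// This should match metrics.weightedPrecision.
println(s"Recomputed weighted precision: $weightedPrecisionByHand")
{% endhighlight %}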
+ +### Multilabel classification + +A [multilabel classification](https://en.wikipedia.org/wiki/Multi-label_classification) problem involves mapping +each sample in a dataset to a set of class labels. In this type of classification problem, the labels are not +mutually exclusive. For example, when classifying a set of news articles into topics, a single article might be both +science and politics. + +Because the labels are not mutually exclusive, the predictions and true labels are now vectors of label *sets*, rather +than vectors of labels. Multilabel metrics, therefore, extend the fundamental ideas of precision, recall, etc. to +operations on sets. For example, a true positive for a given class now occurs when that class exists in the predicted +set and it exists in the true label set, for a specific data point. + +**Available metrics** + +Here we define a set $D$ of $N$ documents + +$$D = \left\{d_0, d_1, ..., d_{N-1}\right\}$$ + +Define $L_0, L_1, ..., L_{N-1}$ to be a family of label sets and $P_0, P_1, ..., P_{N-1}$ +to be a family of prediction sets where $L_i$ and $P_i$ are the label set and prediction set, respectively, that +correspond to document $d_i$. + +The set of all unique labels is given by + +$$L = \bigcup_{k=0}^{N-1} L_k$$ + +The following definition of indicator function $I_A(x)$ on a set $A$ will be necessary + +$$I_A(x) = \begin{cases}1 & \text{if $x \in A$}, \\ 0 & \text{otherwise}.\end{cases}$$ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
<table class="table">
  <thead>
    <tr><th>Metric</th><th>Definition</th></tr>
  </thead>
  <tbody>
    <tr>
      <td>Precision</td>
      <td>$\frac{1}{N} \sum_{i=0}^{N-1} \frac{\left|P_i \cap L_i\right|}{\left|P_i\right|}$</td>
    </tr>
    <tr>
      <td>Recall</td>
      <td>$\frac{1}{N} \sum_{i=0}^{N-1} \frac{\left|L_i \cap P_i\right|}{\left|L_i\right|}$</td>
    </tr>
    <tr>
      <td>Accuracy</td>
      <td>$\frac{1}{N} \sum_{i=0}^{N - 1} \frac{\left|L_i \cap P_i \right|}
          {\left|L_i\right| + \left|P_i\right| - \left|L_i \cap P_i \right|}$</td>
    </tr>
    <tr>
      <td>Precision by label</td>
      <td>$PPV(\ell)=\frac{TP}{TP + FP}=
          \frac{\sum_{i=0}^{N-1} I_{P_i}(\ell) \cdot I_{L_i}(\ell)}
          {\sum_{i=0}^{N-1} I_{P_i}(\ell)}$</td>
    </tr>
    <tr>
      <td>Recall by label</td>
      <td>$TPR(\ell)=\frac{TP}{P}=
          \frac{\sum_{i=0}^{N-1} I_{P_i}(\ell) \cdot I_{L_i}(\ell)}
          {\sum_{i=0}^{N-1} I_{L_i}(\ell)}$</td>
    </tr>
    <tr>
      <td>F1-measure by label</td>
      <td>$F1(\ell) = 2
          \cdot \left(\frac{PPV(\ell) \cdot TPR(\ell)}
          {PPV(\ell) + TPR(\ell)}\right)$</td>
    </tr>
    <tr>
      <td>Hamming Loss</td>
      <td>$\frac{1}{N \cdot \left|L\right|} \sum_{i=0}^{N - 1} \left|L_i\right| + \left|P_i\right| - 2\left|L_i
          \cap P_i\right|$</td>
    </tr>
    <tr>
      <td>Subset Accuracy</td>
      <td>$\frac{1}{N} \sum_{i=0}^{N-1} I_{\{L_i\}}(P_i)$</td>
    </tr>
    <tr>
      <td>F1 Measure</td>
      <td>$\frac{1}{N} \sum_{i=0}^{N-1} 2 \frac{\left|P_i \cap L_i\right|}{\left|P_i\right| + \left|L_i\right|}$</td>
    </tr>
    <tr>
      <td>Micro precision</td>
      <td>$\frac{TP}{TP + FP}=\frac{\sum_{i=0}^{N-1} \left|P_i \cap L_i\right|}
          {\sum_{i=0}^{N-1} \left|P_i \cap L_i\right| + \sum_{i=0}^{N-1} \left|P_i - L_i\right|}$</td>
    </tr>
    <tr>
      <td>Micro recall</td>
      <td>$\frac{TP}{TP + FN}=\frac{\sum_{i=0}^{N-1} \left|P_i \cap L_i\right|}
          {\sum_{i=0}^{N-1} \left|P_i \cap L_i\right| + \sum_{i=0}^{N-1} \left|L_i - P_i\right|}$</td>
    </tr>
    <tr>
      <td>Micro F1 Measure</td>
      <td>$2 \cdot \frac{TP}{2 \cdot TP + FP + FN}=2 \cdot \frac{\sum_{i=0}^{N-1} \left|P_i \cap L_i\right|}{2 \cdot
          \sum_{i=0}^{N-1} \left|P_i \cap L_i\right| + \sum_{i=0}^{N-1} \left|L_i - P_i\right| + \sum_{i=0}^{N-1}
          \left|P_i - L_i\right|}$</td>
    </tr>
  </tbody>
</table>
**Examples**

The following code snippets illustrate how to evaluate the performance of a multilabel classifier. The examples
use the fake prediction and label data for multilabel classification that is shown below, followed by a small worked
example of the per-document precision term.

Document predictions:

* doc 0 - predict 0, 1 - class 0, 2
* doc 1 - predict 0, 2 - class 0, 1
* doc 2 - predict none - class 0
* doc 3 - predict 2 - class 2
* doc 4 - predict 2, 0 - class 2, 0
* doc 5 - predict 0, 1, 2 - class 0, 1
* doc 6 - predict 1 - class 1, 2

Predicted classes:

* class 0 - doc 0, 1, 4, 5 (total 4)
* class 1 - doc 0, 5, 6 (total 3)
* class 2 - doc 1, 3, 4, 5 (total 4)

True classes:

* class 0 - doc 0, 1, 2, 4, 5 (total 5)
* class 1 - doc 1, 5, 6 (total 3)
* class 2 - doc 0, 3, 4, 6 (total 4)
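To make the set-based definitions concrete, here is a tiny, self-contained computation of the per-document precision term $\left|P_i \cap L_i\right| / \left|P_i\right|$ for document 0 of this example data (pure Scala, independent of the MultilabelMetrics API):

{% highlight scala %}
// Document 0: predicted labels {0, 1}, true labels {0, 2}.
val predicted = Set(0.0, 1.0)
val actual    = Set(0.0, 2.0)

// |P intersect L| / |P| = 1 / 2 = 0.5 for this document; the overall
// precision averages this quantity over all documents.
val docPrecision = (predicted intersect actual).size.toDouble / predicted.size
{% endhighlight %}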
+ +
+ +{% highlight scala %} +import org.apache.spark.mllib.evaluation.MultilabelMetrics +import org.apache.spark.rdd.RDD; + +val scoreAndLabels: RDD[(Array[Double], Array[Double])] = sc.parallelize( + Seq((Array(0.0, 1.0), Array(0.0, 2.0)), + (Array(0.0, 2.0), Array(0.0, 1.0)), + (Array(), Array(0.0)), + (Array(2.0), Array(2.0)), + (Array(2.0, 0.0), Array(2.0, 0.0)), + (Array(0.0, 1.0, 2.0), Array(0.0, 1.0)), + (Array(1.0), Array(1.0, 2.0))), 2) + +// Instantiate metrics object +val metrics = new MultilabelMetrics(scoreAndLabels) + +// Summary stats +println(s"Recall = ${metrics.recall}") +println(s"Precision = ${metrics.precision}") +println(s"F1 measure = ${metrics.f1Measure}") +println(s"Accuracy = ${metrics.accuracy}") + +// Individual label stats +metrics.labels.foreach(label => println(s"Class $label precision = ${metrics.precision(label)}")) +metrics.labels.foreach(label => println(s"Class $label recall = ${metrics.recall(label)}")) +metrics.labels.foreach(label => println(s"Class $label F1-score = ${metrics.f1Measure(label)}")) + +// Micro stats +println(s"Micro recall = ${metrics.microRecall}") +println(s"Micro precision = ${metrics.microPrecision}") +println(s"Micro F1 measure = ${metrics.microF1Measure}") + +// Hamming loss +println(s"Hamming loss = ${metrics.hammingLoss}") + +// Subset accuracy +println(s"Subset accuracy = ${metrics.subsetAccuracy}") + +{% endhighlight %} + +
+ +
+ +{% highlight java %} +import scala.Tuple2; + +import org.apache.spark.api.java.*; +import org.apache.spark.rdd.RDD; +import org.apache.spark.mllib.evaluation.MultilabelMetrics; +import org.apache.spark.SparkConf; +import java.util.Arrays; +import java.util.List; + +public class MultilabelClassification { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("Multilabel Classification Metrics"); + JavaSparkContext sc = new JavaSparkContext(conf); + + List> data = Arrays.asList( + new Tuple2(new double[]{0.0, 1.0}, new double[]{0.0, 2.0}), + new Tuple2(new double[]{0.0, 2.0}, new double[]{0.0, 1.0}), + new Tuple2(new double[]{}, new double[]{0.0}), + new Tuple2(new double[]{2.0}, new double[]{2.0}), + new Tuple2(new double[]{2.0, 0.0}, new double[]{2.0, 0.0}), + new Tuple2(new double[]{0.0, 1.0, 2.0}, new double[]{0.0, 1.0}), + new Tuple2(new double[]{1.0}, new double[]{1.0, 2.0}) + ); + JavaRDD> scoreAndLabels = sc.parallelize(data); + + // Instantiate metrics object + MultilabelMetrics metrics = new MultilabelMetrics(scoreAndLabels.rdd()); + + // Summary stats + System.out.format("Recall = %f\n", metrics.recall()); + System.out.format("Precision = %f\n", metrics.precision()); + System.out.format("F1 measure = %f\n", metrics.f1Measure()); + System.out.format("Accuracy = %f\n", metrics.accuracy()); + + // Stats by labels + for (int i = 0; i < metrics.labels().length - 1; i++) { + System.out.format("Class %1.1f precision = %f\n", metrics.labels()[i], metrics.precision(metrics.labels()[i])); + System.out.format("Class %1.1f recall = %f\n", metrics.labels()[i], metrics.recall(metrics.labels()[i])); + System.out.format("Class %1.1f F1 score = %f\n", metrics.labels()[i], metrics.f1Measure(metrics.labels()[i])); + } + + // Micro stats + System.out.format("Micro recall = %f\n", metrics.microRecall()); + System.out.format("Micro precision = %f\n", metrics.microPrecision()); + System.out.format("Micro F1 measure = %f\n", metrics.microF1Measure()); + + // Hamming loss + System.out.format("Hamming loss = %f\n", metrics.hammingLoss()); + + // Subset accuracy + System.out.format("Subset accuracy = %f\n", metrics.subsetAccuracy()); + + } +} + +{% endhighlight %} + +
+ +
+ +{% highlight python %} +from pyspark.mllib.evaluation import MultilabelMetrics + +scoreAndLabels = sc.parallelize([ + ([0.0, 1.0], [0.0, 2.0]), + ([0.0, 2.0], [0.0, 1.0]), + ([], [0.0]), + ([2.0], [2.0]), + ([2.0, 0.0], [2.0, 0.0]), + ([0.0, 1.0, 2.0], [0.0, 1.0]), + ([1.0], [1.0, 2.0])]) + +# Instantiate metrics object +metrics = MultilabelMetrics(scoreAndLabels) + +# Summary stats +print "Recall = %s" % metrics.recall() +print "Precision = %s" % metrics.precision() +print "F1 measure = %s" % metrics.f1Measure() +print "Accuracy = %s" % metrics.accuracy + +# Individual label stats +labels = scoreAndLabels.flatMap(lambda x: x[1]).distinct().collect() +for label in labels: + print "Class %s precision = %s" % (label, metrics.precision(label)) + print "Class %s recall = %s" % (label, metrics.recall(label)) + print "Class %s F1 Measure = %s" % (label, metrics.f1Measure(label)) + +# Micro stats +print "Micro precision = %s" % metrics.microPrecision +print "Micro recall = %s" % metrics.microRecall +print "Micro F1 measure = %s" % metrics.microF1Measure + +# Hamming loss +print "Hamming loss = %s" % metrics.hammingLoss + +# Subset accuracy +print "Subset accuracy = %s" % metrics.subsetAccuracy + +{% endhighlight %} + +
+
### Ranking systems

The role of a ranking algorithm (often thought of as a [recommender system](https://en.wikipedia.org/wiki/Recommender_system))
is to return to the user a set of relevant items or documents based on some training data. The definition of relevance
may vary and is usually application specific. Ranking system metrics aim to quantify the effectiveness of these
rankings or recommendations in various contexts. Some metrics compare a set of recommended documents to a ground truth
set of relevant documents, while other metrics may incorporate numerical ratings explicitly.

**Available metrics**

A ranking system usually deals with a set of $M$ users

$$U = \left\{u_0, u_1, ..., u_{M-1}\right\}$$

Each user ($u_i$) has a set of $N$ ground-truth relevant documents

$$D_i = \left\{d_0, d_1, ..., d_{N-1}\right\}$$

and a list of $Q$ recommended documents, in order of decreasing relevance

$$R_i = \left[r_0, r_1, ..., r_{Q-1}\right]$$

The goal of the ranking system is to produce the most relevant set of documents for each user. The relevance of the
sets and the effectiveness of the algorithms can be measured using the metrics listed below.

It is necessary to define a function which, provided a recommended document and a set of ground truth relevant
documents, returns a relevance score for the recommended document.

$$rel_D(r) = \begin{cases}1 & \text{if $r \in D$}, \\ 0 & \text{otherwise}.\end{cases}$$
+<table class="table">
+  <thead>
+    <tr><th>Metric</th><th>Definition</th><th>Notes</th></tr>
+  </thead>
+  <tbody>
+    <tr>
+      <td>Precision at k</td>
+      <td>
+        $p(k)=\frac{1}{M} \sum_{i=0}^{M-1} {\frac{1}{k} \sum_{j=0}^{\text{min}(\left|D\right|, k) - 1} rel_{D_i}(R_i(j))}$
+      </td>
+      <td>
+        Precision at k is a measure of how many of the first k recommended documents are in the set of true relevant
+        documents, averaged across all users. In this metric, the order of the recommendations is not taken into account.
+      </td>
+    </tr>
+    <tr>
+      <td>Mean Average Precision</td>
+      <td>
+        $MAP=\frac{1}{M} \sum_{i=0}^{M-1} {\frac{1}{\left|D_i\right|} \sum_{j=0}^{Q-1} \frac{rel_{D_i}(R_i(j))}{j + 1} \sum_{k=0}^{j} rel_{D_i}(R_i(k))}$
+      </td>
+      <td>
+        MAP is a measure of how many of the recommended documents are in the set of true relevant documents, where the
+        order of the recommendations is taken into account (i.e. a relevant document that is ranked low incurs a higher penalty).
+      </td>
+    </tr>
+    <tr>
+      <td>Normalized Discounted Cumulative Gain</td>
+      <td>
+        $NDCG(k)=\frac{1}{M} \sum_{i=0}^{M-1} {\frac{1}{IDCG(D_i, k)}\sum_{j=0}^{n-1}
+          \frac{rel_{D_i}(R_i(j))}{\text{ln}(j+2)}} \\
+        \text{Where} \\
+        \hspace{5 mm} n = \text{min}\left(\text{max}\left(|R_i|,|D_i|\right),k\right) \\
+        \hspace{5 mm} IDCG(D, k) = \sum_{j=0}^{\text{min}(\left|D\right|, k) - 1} \frac{1}{\text{ln}(j+2)}$
+      </td>
+      <td>
+        NDCG at k is a measure of how many of the first k recommended documents are in the set of true relevant
+        documents, averaged across all users. In contrast to precision at k, this metric takes into account the order
+        of the recommendations (documents are assumed to be in order of decreasing relevance).
+      </td>
+    </tr>
+  </tbody>
+</table>
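To make the definitions above concrete, here is a minimal sketch (an editorial addition, not part of the patch) that evaluates a single hand-made recommendation list with `RankingMetrics`. The document IDs and the commented values are illustrative assumptions only, and `sc` is assumed to be an existing `SparkContext`.

{% highlight scala %}
import org.apache.spark.mllib.evaluation.RankingMetrics

// One user: the ranked recommendations and the ground truth relevant set.
// The IDs below are made up purely to illustrate the definitions above.
val predictionAndLabels = sc.parallelize(Seq(
  (Array(1, 6, 2, 7, 8, 3, 9, 10, 4, 5), Array(1, 2, 3, 4, 5))
))

val metrics = new RankingMetrics(predictionAndLabels)

// Precision at 5: two of the first five recommendations (1 and 2) are relevant, so 2/5 = 0.4
println(s"Precision at 5 = ${metrics.precisionAt(5)}")

// Mean average precision over this single user
println(s"Mean average precision = ${metrics.meanAveragePrecision}")

// NDCG at 5, using the logarithmic discount described above
println(s"NDCG at 5 = ${metrics.ndcgAt(5)}")
{% endhighlight %}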
+
+**Examples**
+
+The following code snippets illustrate how to load a sample dataset, train an alternating least squares recommendation
+model on the data, and evaluate the performance of the recommender by several ranking metrics. A brief summary of the
+methodology is provided below.
+
+MovieLens ratings are on a scale of 1-5:
+
+ * 5: Must see
+ * 4: Will enjoy
+ * 3: It's okay
+ * 2: Fairly bad
+ * 1: Awful
+
+So we should not recommend a movie if the predicted rating is less than 3.
+To map ratings to confidence scores, we use:
+
+ * 5 -> 2.5
+ * 4 -> 1.5
+ * 3 -> 0.5
+ * 2 -> -0.5
+ * 1 -> -1.5.
+
+This mapping means that unobserved entries are generally between "It's okay" and "Fairly bad". The semantics of 0 in this
+expanded world of non-positive weights are "the same as never having interacted at all."
+
+ +
+ +{% highlight scala %} +import org.apache.spark.mllib.evaluation.{RegressionMetrics, RankingMetrics} +import org.apache.spark.mllib.recommendation.{ALS, Rating} + +// Read in the ratings data +val ratings = sc.textFile("data/mllib/sample_movielens_data.txt").map { line => + val fields = line.split("::") + Rating(fields(0).toInt, fields(1).toInt, fields(2).toDouble - 2.5) +}.cache() + +// Map ratings to 1 or 0, 1 indicating a movie that should be recommended +val binarizedRatings = ratings.map(r => Rating(r.user, r.product, if (r.rating > 0) 1.0 else 0.0)).cache() + +// Summarize ratings +val numRatings = ratings.count() +val numUsers = ratings.map(_.user).distinct().count() +val numMovies = ratings.map(_.product).distinct().count() +println(s"Got $numRatings ratings from $numUsers users on $numMovies movies.") + +// Build the model +val numIterations = 10 +val rank = 10 +val lambda = 0.01 +val model = ALS.train(ratings, rank, numIterations, lambda) + +// Define a function to scale ratings from 0 to 1 +def scaledRating(r: Rating): Rating = { + val scaledRating = math.max(math.min(r.rating, 1.0), 0.0) + Rating(r.user, r.product, scaledRating) +} + +// Get sorted top ten predictions for each user and then scale from [0, 1] +val userRecommended = model.recommendProductsForUsers(10).map{ case (user, recs) => + (user, recs.map(scaledRating)) +} + +// Assume that any movie a user rated 3 or higher (which maps to a 1) is a relevant document +// Compare with top ten most relevant documents +val userMovies = binarizedRatings.groupBy(_.user) +val relevantDocuments = userMovies.join(userRecommended).map{ case (user, (actual, predictions)) => + (predictions.map(_.product), actual.filter(_.rating > 0.0).map(_.product).toArray) +} + +// Instantiate metrics object +val metrics = new RankingMetrics(relevantDocuments) + +// Precision at K +Array(1, 3, 5).foreach{ k => + println(s"Precision at $k = ${metrics.precisionAt(k)}") +} + +// Mean average precision +println(s"Mean average precision = ${metrics.meanAveragePrecision}") + +// Normalized discounted cumulative gain +Array(1, 3, 5).foreach{ k => + println(s"NDCG at $k = ${metrics.ndcgAt(k)}") +} + +// Get predictions for each data point +val allPredictions = model.predict(ratings.map(r => (r.user, r.product))).map(r => ((r.user, r.product), r.rating)) +val allRatings = ratings.map(r => ((r.user, r.product), r.rating)) +val predictionsAndLabels = allPredictions.join(allRatings).map{ case ((user, product), (predicted, actual)) => + (predicted, actual) +} + +// Get the RMSE using regression metrics +val regressionMetrics = new RegressionMetrics(predictionsAndLabels) +println(s"RMSE = ${regressionMetrics.rootMeanSquaredError}") + +// R-squared +println(s"R-squared = ${regressionMetrics.r2}") + +{% endhighlight %} + +
+ +
+
+{% highlight java %}
+import scala.Tuple2;
+
+import org.apache.spark.api.java.*;
+import org.apache.spark.rdd.RDD;
+import org.apache.spark.mllib.recommendation.MatrixFactorizationModel;
+import org.apache.spark.SparkConf;
+import org.apache.spark.api.java.function.Function;
+import java.util.*;
+import org.apache.spark.mllib.evaluation.RegressionMetrics;
+import org.apache.spark.mllib.evaluation.RankingMetrics;
+import org.apache.spark.mllib.recommendation.ALS;
+import org.apache.spark.mllib.recommendation.Rating;
+
+// Read in the ratings data
+public class Ranking {
+  public static void main(String[] args) {
+    SparkConf conf = new SparkConf().setAppName("Ranking Metrics");
+    JavaSparkContext sc = new JavaSparkContext(conf);
+    String path = "data/mllib/sample_movielens_data.txt";
+    JavaRDD<String> data = sc.textFile(path);
+    JavaRDD<Rating> ratings = data.map(
+      new Function<String, Rating>() {
+        public Rating call(String line) {
+          String[] parts = line.split("::");
+          return new Rating(Integer.parseInt(parts[0]), Integer.parseInt(parts[1]), Double.parseDouble(parts[2]) - 2.5);
+        }
+      }
+    );
+    ratings.cache();
+
+    // Train an ALS model
+    final MatrixFactorizationModel model = ALS.train(JavaRDD.toRDD(ratings), 10, 10, 0.01);
+
+    // Get top 10 recommendations for every user and scale ratings from 0 to 1
+    JavaRDD<Tuple2<Object, Rating[]>> userRecs = model.recommendProductsForUsers(10).toJavaRDD();
+    JavaRDD<Tuple2<Object, Rating[]>> userRecsScaled = userRecs.map(
+      new Function<Tuple2<Object, Rating[]>, Tuple2<Object, Rating[]>>() {
+        public Tuple2<Object, Rating[]> call(Tuple2<Object, Rating[]> t) {
+          Rating[] scaledRatings = new Rating[t._2().length];
+          for (int i = 0; i < scaledRatings.length; i++) {
+            double newRating = Math.max(Math.min(t._2()[i].rating(), 1.0), 0.0);
+            scaledRatings[i] = new Rating(t._2()[i].user(), t._2()[i].product(), newRating);
+          }
+          return new Tuple2<Object, Rating[]>(t._1(), scaledRatings);
+        }
+      }
+    );
+    JavaPairRDD<Object, Rating[]> userRecommended = JavaPairRDD.fromJavaRDD(userRecsScaled);
+
+    // Map ratings to 1 or 0, 1 indicating a movie that should be recommended
+    JavaRDD<Rating> binarizedRatings = ratings.map(
+      new Function<Rating, Rating>() {
+        public Rating call(Rating r) {
+          double binaryRating;
+          if (r.rating() > 0.0) {
+            binaryRating = 1.0;
+          }
+          else {
+            binaryRating = 0.0;
+          }
+          return new Rating(r.user(), r.product(), binaryRating);
+        }
+      }
+    );
+
+    // Group ratings by common user
+    JavaPairRDD<Object, Iterable<Rating>> userMovies = binarizedRatings.groupBy(
+      new Function<Rating, Object>() {
+        public Object call(Rating r) {
+          return r.user();
+        }
+      }
+    );
+
+    // Get true relevant documents from all user ratings
+    JavaPairRDD<Object, List<Integer>> userMoviesList = userMovies.mapValues(
+      new Function<Iterable<Rating>, List<Integer>>() {
+        public List<Integer> call(Iterable<Rating> docs) {
+          List<Integer> products = new ArrayList<Integer>();
+          for (Rating r : docs) {
+            if (r.rating() > 0.0) {
+              products.add(r.product());
+            }
+          }
+          return products;
+        }
+      }
+    );
+
+    // Extract the product id from each recommendation
+    JavaPairRDD<Object, List<Integer>> userRecommendedList = userRecommended.mapValues(
+      new Function<Rating[], List<Integer>>() {
+        public List<Integer> call(Rating[] docs) {
+          List<Integer> products = new ArrayList<Integer>();
+          for (Rating r : docs) {
+            products.add(r.product());
+          }
+          return products;
+        }
+      }
+    );
+    JavaRDD<Tuple2<List<Integer>, List<Integer>>> relevantDocs = userMoviesList.join(userRecommendedList).values();
+
+    // Instantiate the metrics object
+    RankingMetrics<Integer> metrics = RankingMetrics.of(relevantDocs);
+
+    // Precision and NDCG at k
+    Integer[] kVector = {1, 3, 5};
+    for (Integer k : kVector) {
+      System.out.format("Precision at %d = %f\n", k, metrics.precisionAt(k));
+      System.out.format("NDCG at %d = %f\n", k, metrics.ndcgAt(k));
+    }
+
+    // Mean average precision
+    System.out.format("Mean average precision = %f\n", metrics.meanAveragePrecision());
+
+    // Evaluate the model using numerical ratings and regression metrics
+    JavaRDD<Tuple2<Object, Object>> userProducts = ratings.map(
+      new Function<Rating, Tuple2<Object, Object>>() {
+        public Tuple2<Object, Object> call(Rating r) {
+          return new Tuple2<Object, Object>(r.user(), r.product());
+        }
+      }
+    );
+    JavaPairRDD<Tuple2<Integer, Integer>, Object> predictions = JavaPairRDD.fromJavaRDD(
+      model.predict(JavaRDD.toRDD(userProducts)).toJavaRDD().map(
+        new Function<Rating, Tuple2<Tuple2<Integer, Integer>, Object>>() {
+          public Tuple2<Tuple2<Integer, Integer>, Object> call(Rating r){
+            return new Tuple2<Tuple2<Integer, Integer>, Object>(
+              new Tuple2<Integer, Integer>(r.user(), r.product()), r.rating());
+          }
+        }
+      ));
+    JavaRDD<Tuple2<Object, Object>> ratesAndPreds =
+      JavaPairRDD.fromJavaRDD(ratings.map(
+        new Function<Rating, Tuple2<Tuple2<Integer, Integer>, Object>>() {
+          public Tuple2<Tuple2<Integer, Integer>, Object> call(Rating r){
+            return new Tuple2<Tuple2<Integer, Integer>, Object>(
+              new Tuple2<Integer, Integer>(r.user(), r.product()), r.rating());
+          }
+        }
+      )).join(predictions).values();
+
+    // Create regression metrics object
+    RegressionMetrics regressionMetrics = new RegressionMetrics(ratesAndPreds.rdd());
+
+    // Root mean squared error
+    System.out.format("RMSE = %f\n", regressionMetrics.rootMeanSquaredError());
+
+    // R-squared
+    System.out.format("R-squared = %f\n", regressionMetrics.r2());
+  }
+}
+
+{% endhighlight %}
+
+ +
+
+{% highlight python %}
+from pyspark.mllib.recommendation import ALS, Rating
+from pyspark.mllib.evaluation import RegressionMetrics, RankingMetrics
+
+# Read in the ratings data
+lines = sc.textFile("data/mllib/sample_movielens_data.txt")
+
+def parseLine(line):
+    fields = line.split("::")
+    return Rating(int(fields[0]), int(fields[1]), float(fields[2]) - 2.5)
+ratings = lines.map(lambda r: parseLine(r))
+
+# Train a model to predict user-product ratings
+model = ALS.train(ratings, 10, 10, 0.01)
+
+# Get predicted ratings on all existing user-product pairs
+testData = ratings.map(lambda p: (p.user, p.product))
+predictions = model.predictAll(testData).map(lambda r: ((r.user, r.product), r.rating))
+
+ratingsTuple = ratings.map(lambda r: ((r.user, r.product), r.rating))
+scoreAndLabels = predictions.join(ratingsTuple).map(lambda tup: tup[1])
+
+# Instantiate regression metrics to compare predicted and actual ratings
+metrics = RegressionMetrics(scoreAndLabels)
+
+# Root mean squared error
+print "RMSE = %s" % metrics.rootMeanSquaredError
+
+# R-squared
+print "R-squared = %s" % metrics.r2
+
+{% endhighlight %}
+
+
+
+## Regression model evaluation
+
+[Regression analysis](https://en.wikipedia.org/wiki/Regression_analysis) is used when predicting a continuous output
+variable from a number of independent variables.
+
+**Available metrics**
+
+<table class="table">
+  <thead>
+    <tr><th>Metric</th><th>Definition</th></tr>
+  </thead>
+  <tbody>
+    <tr><td>Mean Squared Error (MSE)</td><td>$MSE = \frac{\sum_{i=0}^{N-1} (\mathbf{y}_i - \hat{\mathbf{y}}_i)^2}{N}$</td></tr>
+    <tr><td>Root Mean Squared Error (RMSE)</td><td>$RMSE = \sqrt{\frac{\sum_{i=0}^{N-1} (\mathbf{y}_i - \hat{\mathbf{y}}_i)^2}{N}}$</td></tr>
+    <tr><td>Mean Absolute Error (MAE)</td><td>$MAE=\frac{1}{N}\sum_{i=0}^{N-1} \left|\mathbf{y}_i - \hat{\mathbf{y}}_i\right|$</td></tr>
+    <tr><td>Coefficient of Determination $(R^2)$</td><td>$R^2=1 - \frac{N \cdot MSE}{\text{VAR}(\mathbf{y}) \cdot (N-1)}=1-\frac{\sum_{i=0}^{N-1} (\mathbf{y}_i - \hat{\mathbf{y}}_i)^2}{\sum_{i=0}^{N-1}(\mathbf{y}_i-\bar{\mathbf{y}})^2}$</td></tr>
+    <tr><td>Explained Variance</td><td>$1 - \frac{\text{VAR}(\mathbf{y} - \mathbf{\hat{y}})}{\text{VAR}(\mathbf{y})}$</td></tr>
+  </tbody>
+</table>
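As a quick sanity check on the definitions above, the following sketch (an editorial addition, not part of the patch) scores a handful of hand-written (prediction, observation) pairs with `RegressionMetrics` so the printed values can be compared against the formulas directly. The numbers are made up, and `sc` is assumed to be an existing `SparkContext`.

{% highlight scala %}
import org.apache.spark.mllib.evaluation.RegressionMetrics

// Hypothetical (prediction, observation) pairs, chosen only to keep the arithmetic easy.
val predictionAndObservations = sc.parallelize(Seq(
  (2.5, 3.0), (0.0, -0.5), (2.0, 2.0), (8.0, 7.0)
))

val metrics = new RegressionMetrics(predictionAndObservations)

// MSE = ((2.5-3.0)^2 + (0.0+0.5)^2 + 0^2 + (8.0-7.0)^2) / 4 = 1.5 / 4 = 0.375
println(s"MSE = ${metrics.meanSquaredError}")
// RMSE = sqrt(0.375), roughly 0.612
println(s"RMSE = ${metrics.rootMeanSquaredError}")
// MAE = (0.5 + 0.5 + 0.0 + 1.0) / 4 = 0.5
println(s"MAE = ${metrics.meanAbsoluteError}")
// R-squared = 1 - (sum of squared residuals / total sum of squares)
println(s"R-squared = ${metrics.r2}")
println(s"Explained variance = ${metrics.explainedVariance}")
{% endhighlight %}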
+ +**Examples** + +
+The following code snippets illustrate how to load a sample dataset, train a linear regression algorithm on the data, +and evaluate the performance of the algorithm by several regression metrics. + +
+ +{% highlight scala %} +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.regression.LinearRegressionModel +import org.apache.spark.mllib.regression.LinearRegressionWithSGD +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.evaluation.RegressionMetrics +import org.apache.spark.mllib.util.MLUtils + +// Load the data +val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_linear_regression_data.txt").cache() + +// Build the model +val numIterations = 100 +val model = LinearRegressionWithSGD.train(data, numIterations) + +// Get predictions +val valuesAndPreds = data.map{ point => + val prediction = model.predict(point.features) + (prediction, point.label) +} + +// Instantiate metrics object +val metrics = new RegressionMetrics(valuesAndPreds) + +// Squared error +println(s"MSE = ${metrics.meanSquaredError}") +println(s"RMSE = ${metrics.rootMeanSquaredError}") + +// R-squared +println(s"R-squared = ${metrics.r2}") + +// Mean absolute error +println(s"MAE = ${metrics.meanAbsoluteError}") + +// Explained variance +println(s"Explained variance = ${metrics.explainedVariance}") + +{% endhighlight %} + +
+ +
+
+{% highlight java %}
+import scala.Tuple2;
+
+import org.apache.spark.api.java.*;
+import org.apache.spark.api.java.function.Function;
+import org.apache.spark.mllib.linalg.Vectors;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.mllib.regression.LinearRegressionModel;
+import org.apache.spark.mllib.regression.LinearRegressionWithSGD;
+import org.apache.spark.mllib.evaluation.RegressionMetrics;
+import org.apache.spark.SparkConf;
+
+public class LinearRegression {
+  public static void main(String[] args) {
+    SparkConf conf = new SparkConf().setAppName("Linear Regression Example");
+    JavaSparkContext sc = new JavaSparkContext(conf);
+
+    // Load and parse the data
+    String path = "data/mllib/sample_linear_regression_data.txt";
+    JavaRDD<String> data = sc.textFile(path);
+    JavaRDD<LabeledPoint> parsedData = data.map(
+      new Function<String, LabeledPoint>() {
+        public LabeledPoint call(String line) {
+          String[] parts = line.split(" ");
+          double[] v = new double[parts.length - 1];
+          for (int i = 1; i < parts.length; i++)
+            v[i - 1] = Double.parseDouble(parts[i].split(":")[1]);
+          return new LabeledPoint(Double.parseDouble(parts[0]), Vectors.dense(v));
+        }
+      }
+    );
+    parsedData.cache();
+
+    // Building the model
+    int numIterations = 100;
+    final LinearRegressionModel model =
+      LinearRegressionWithSGD.train(JavaRDD.toRDD(parsedData), numIterations);
+
+    // Evaluate model on training examples and compute training error
+    JavaRDD<Tuple2<Object, Object>> valuesAndPreds = parsedData.map(
+      new Function<LabeledPoint, Tuple2<Object, Object>>() {
+        public Tuple2<Object, Object> call(LabeledPoint point) {
+          double prediction = model.predict(point.features());
+          return new Tuple2<Object, Object>(prediction, point.label());
+        }
+      }
+    );
+
+    // Instantiate metrics object
+    RegressionMetrics metrics = new RegressionMetrics(valuesAndPreds.rdd());
+
+    // Squared error
+    System.out.format("MSE = %f\n", metrics.meanSquaredError());
+    System.out.format("RMSE = %f\n", metrics.rootMeanSquaredError());
+
+    // R-squared
+    System.out.format("R Squared = %f\n", metrics.r2());
+
+    // Mean absolute error
+    System.out.format("MAE = %f\n", metrics.meanAbsoluteError());
+
+    // Explained variance
+    System.out.format("Explained Variance = %f\n", metrics.explainedVariance());
+
+    // Save and load model
+    model.save(sc.sc(), "myModelPath");
+    LinearRegressionModel sameModel = LinearRegressionModel.load(sc.sc(), "myModelPath");
+  }
+}
+
+{% endhighlight %}
+
+ +
+ +{% highlight python %} +from pyspark.mllib.regression import LabeledPoint, LinearRegressionWithSGD +from pyspark.mllib.evaluation import RegressionMetrics +from pyspark.mllib.linalg import DenseVector + +# Load and parse the data +def parsePoint(line): + values = line.split() + return LabeledPoint(float(values[0]), DenseVector([float(x.split(':')[1]) for x in values[1:]])) + +data = sc.textFile("data/mllib/sample_linear_regression_data.txt") +parsedData = data.map(parsePoint) + +# Build the model +model = LinearRegressionWithSGD.train(parsedData) + +# Get predictions +valuesAndPreds = parsedData.map(lambda p: (float(model.predict(p.features)), p.label)) + +# Instantiate metrics object +metrics = RegressionMetrics(valuesAndPreds) + +# Squared Error +print "MSE = %s" % metrics.meanSquaredError +print "RMSE = %s" % metrics.rootMeanSquaredError + +# R-squared +print "R-squared = %s" % metrics.r2 + +# Mean absolute error +print "MAE = %s" % metrics.meanAbsoluteError + +# Explained variance +print "Explained variance = %s" % metrics.explainedVariance + +{% endhighlight %} + +
+
\ No newline at end of file diff --git a/docs/mllib-guide.md b/docs/mllib-guide.md index d2d1cc93fe006..eea864eacf7c4 100644 --- a/docs/mllib-guide.md +++ b/docs/mllib-guide.md @@ -48,6 +48,7 @@ This lists functionality included in `spark.mllib`, the main MLlib API. * [Feature extraction and transformation](mllib-feature-extraction.html) * [Frequent pattern mining](mllib-frequent-pattern-mining.html) * FP-growth +* [Evaluation Metrics](mllib-evaluation-metrics.html) * [Optimization (developer)](mllib-optimization.html) * stochastic gradient descent * limited-memory BFGS (L-BFGS) From a200e64561c8803731578267df16906f6773cbea Mon Sep 17 00:00:00 2001 From: Feynman Liang Date: Wed, 29 Jul 2015 19:02:15 -0700 Subject: [PATCH 04/50] [SPARK-9440] [MLLIB] Add hyperparameters to LocalLDAModel save/load jkbradley MechCoder Resolves blocking issue for SPARK-6793. Please review after #7705 is merged. Author: Feynman Liang Closes #7757 from feynmanliang/SPARK-9940-localSaveLoad and squashes the following commits: d0d8cf4 [Feynman Liang] Fix thisClassName 0f30109 [Feynman Liang] Fix tests after changing LDAModel public API dc61981 [Feynman Liang] Add hyperparams to LocalLDAModel save/load --- .../spark/mllib/clustering/LDAModel.scala | 40 +++++++++++++------ .../spark/mllib/clustering/LDASuite.scala | 6 ++- 2 files changed, 33 insertions(+), 13 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala index 059b52ef20a98..ece28848aa02c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala @@ -215,7 +215,8 @@ class LocalLDAModel private[clustering] ( override protected def formatVersion = "1.0" override def save(sc: SparkContext, path: String): Unit = { - LocalLDAModel.SaveLoadV1_0.save(sc, path, topicsMatrix) + LocalLDAModel.SaveLoadV1_0.save(sc, path, topicsMatrix, docConcentration, topicConcentration, + gammaShape) } // TODO // override def logLikelihood(documents: RDD[(Long, Vector)]): Double = ??? @@ -312,16 +313,23 @@ object LocalLDAModel extends Loader[LocalLDAModel] { // as a Row in data. 
case class Data(topic: Vector, index: Int) - // TODO: explicitly save docConcentration, topicConcentration, and gammaShape for use in - // model.predict() - def save(sc: SparkContext, path: String, topicsMatrix: Matrix): Unit = { + def save( + sc: SparkContext, + path: String, + topicsMatrix: Matrix, + docConcentration: Vector, + topicConcentration: Double, + gammaShape: Double): Unit = { val sqlContext = SQLContext.getOrCreate(sc) import sqlContext.implicits._ val k = topicsMatrix.numCols val metadata = compact(render (("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ - ("k" -> k) ~ ("vocabSize" -> topicsMatrix.numRows))) + ("k" -> k) ~ ("vocabSize" -> topicsMatrix.numRows) ~ + ("docConcentration" -> docConcentration.toArray.toSeq) ~ + ("topicConcentration" -> topicConcentration) ~ + ("gammaShape" -> gammaShape))) sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path)) val topicsDenseMatrix = topicsMatrix.toBreeze.toDenseMatrix @@ -331,7 +339,12 @@ object LocalLDAModel extends Loader[LocalLDAModel] { sc.parallelize(topics, 1).toDF().write.parquet(Loader.dataPath(path)) } - def load(sc: SparkContext, path: String): LocalLDAModel = { + def load( + sc: SparkContext, + path: String, + docConcentration: Vector, + topicConcentration: Double, + gammaShape: Double): LocalLDAModel = { val dataPath = Loader.dataPath(path) val sqlContext = SQLContext.getOrCreate(sc) val dataFrame = sqlContext.read.parquet(dataPath) @@ -348,8 +361,7 @@ object LocalLDAModel extends Loader[LocalLDAModel] { val topicsMat = Matrices.fromBreeze(brzTopics) // TODO: initialize with docConcentration, topicConcentration, and gammaShape after SPARK-9940 - new LocalLDAModel(topicsMat, - Vectors.dense(Array.fill(topicsMat.numRows)(1.0 / topicsMat.numRows)), 1D, 100D) + new LocalLDAModel(topicsMat, docConcentration, topicConcentration, gammaShape) } } @@ -358,11 +370,15 @@ object LocalLDAModel extends Loader[LocalLDAModel] { implicit val formats = DefaultFormats val expectedK = (metadata \ "k").extract[Int] val expectedVocabSize = (metadata \ "vocabSize").extract[Int] + val docConcentration = + Vectors.dense((metadata \ "docConcentration").extract[Seq[Double]].toArray) + val topicConcentration = (metadata \ "topicConcentration").extract[Double] + val gammaShape = (metadata \ "gammaShape").extract[Double] val classNameV1_0 = SaveLoadV1_0.thisClassName val model = (loadedClassName, loadedVersion) match { case (className, "1.0") if className == classNameV1_0 => - SaveLoadV1_0.load(sc, path) + SaveLoadV1_0.load(sc, path, docConcentration, topicConcentration, gammaShape) case _ => throw new Exception( s"LocalLDAModel.load did not recognize model with (className, format version):" + s"($loadedClassName, $loadedVersion). Supported:\n" + @@ -565,7 +581,7 @@ object DistributedLDAModel extends Loader[DistributedLDAModel] { val thisFormatVersion = "1.0" - val classNameV1_0 = "org.apache.spark.mllib.clustering.DistributedLDAModel" + val thisClassName = "org.apache.spark.mllib.clustering.DistributedLDAModel" // Store globalTopicTotals as a Vector. 
case class Data(globalTopicTotals: Vector) @@ -591,7 +607,7 @@ object DistributedLDAModel extends Loader[DistributedLDAModel] { import sqlContext.implicits._ val metadata = compact(render - (("class" -> classNameV1_0) ~ ("version" -> thisFormatVersion) ~ + (("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ ("k" -> k) ~ ("vocabSize" -> vocabSize) ~ ("docConcentration" -> docConcentration.toArray.toSeq) ~ ("topicConcentration" -> topicConcentration) ~ @@ -660,7 +676,7 @@ object DistributedLDAModel extends Loader[DistributedLDAModel] { val topicConcentration = (metadata \ "topicConcentration").extract[Double] val iterationTimes = (metadata \ "iterationTimes").extract[Seq[Double]] val gammaShape = (metadata \ "gammaShape").extract[Double] - val classNameV1_0 = SaveLoadV1_0.classNameV1_0 + val classNameV1_0 = SaveLoadV1_0.thisClassName val model = (loadedClassName, loadedVersion) match { case (className, "1.0") if className == classNameV1_0 => { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala index aa36336ebbee6..b91c7cefed22e 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala @@ -334,7 +334,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext { test("model save/load") { // Test for LocalLDAModel. val localModel = new LocalLDAModel(tinyTopics, - Vectors.dense(Array.fill(tinyTopics.numRows)(1.0 / tinyTopics.numRows)), 1D, 100D) + Vectors.dense(Array.fill(tinyTopics.numRows)(0.01)), 0.5D, 10D) val tempDir1 = Utils.createTempDir() val path1 = tempDir1.toURI.toString @@ -360,6 +360,9 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext { assert(samelocalModel.topicsMatrix === localModel.topicsMatrix) assert(samelocalModel.k === localModel.k) assert(samelocalModel.vocabSize === localModel.vocabSize) + assert(samelocalModel.docConcentration === localModel.docConcentration) + assert(samelocalModel.topicConcentration === localModel.topicConcentration) + assert(samelocalModel.gammaShape === localModel.gammaShape) val sameDistributedModel = DistributedLDAModel.load(sc, path2) assert(distributedModel.topicsMatrix === sameDistributedModel.topicsMatrix) @@ -368,6 +371,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext { assert(distributedModel.iterationTimes === sameDistributedModel.iterationTimes) assert(distributedModel.docConcentration === sameDistributedModel.docConcentration) assert(distributedModel.topicConcentration === sameDistributedModel.topicConcentration) + assert(distributedModel.gammaShape === sameDistributedModel.gammaShape) assert(distributedModel.globalTopicTotals === sameDistributedModel.globalTopicTotals) val graph = distributedModel.graph From 9514d874f0cf61f1eb4ec4f5f66e053119f769c9 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 29 Jul 2015 20:46:03 -0700 Subject: [PATCH 05/50] [SPARK-9458] Avoid object allocation in prefix generation. In our existing sort prefix generation code, we use expression's eval method to generate the prefix, which results in object allocation for every prefix. We can use the specialized getters available on InternalRow directly to avoid the object allocation. I also removed the FLOAT prefix, opting for converting float directly to double. 
Author: Reynold Xin Closes #7763 from rxin/sort-prefix and squashes the following commits: 5dc2f06 [Reynold Xin] [SPARK-9458] Avoid object allocation in prefix generation. --- .../unsafe/sort/PrefixComparators.java | 16 ------ .../unsafe/sort/PrefixComparatorsSuite.scala | 12 ----- .../execution/UnsafeExternalRowSorter.java | 2 +- .../spark/sql/execution/SortPrefixUtils.scala | 51 +++++++++---------- .../spark/sql/execution/SparkStrategies.scala | 4 +- .../org/apache/spark/sql/execution/sort.scala | 5 +- .../execution/RowFormatConvertersSuite.scala | 2 +- .../execution/UnsafeExternalSortSuite.scala | 10 ++-- 8 files changed, 35 insertions(+), 67 deletions(-) diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java index bf1bc5dffba78..5624e067da2cc 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java @@ -31,7 +31,6 @@ private PrefixComparators() {} public static final StringPrefixComparator STRING = new StringPrefixComparator(); public static final IntegralPrefixComparator INTEGRAL = new IntegralPrefixComparator(); - public static final FloatPrefixComparator FLOAT = new FloatPrefixComparator(); public static final DoublePrefixComparator DOUBLE = new DoublePrefixComparator(); public static final class StringPrefixComparator extends PrefixComparator { @@ -78,21 +77,6 @@ public int compare(long a, long b) { public final long NULL_PREFIX = Long.MIN_VALUE; } - public static final class FloatPrefixComparator extends PrefixComparator { - @Override - public int compare(long aPrefix, long bPrefix) { - float a = Float.intBitsToFloat((int) aPrefix); - float b = Float.intBitsToFloat((int) bPrefix); - return Utils.nanSafeCompareFloats(a, b); - } - - public long computePrefix(float value) { - return Float.floatToIntBits(value) & 0xffffffffL; - } - - public final long NULL_PREFIX = computePrefix(Float.NEGATIVE_INFINITY); - } - public static final class DoublePrefixComparator extends PrefixComparator { @Override public int compare(long aPrefix, long bPrefix) { diff --git a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala index dc03e374b51db..28fe9259453a6 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala @@ -48,18 +48,6 @@ class PrefixComparatorsSuite extends SparkFunSuite with PropertyChecks { forAll { (s1: String, s2: String) => testPrefixComparison(s1, s2) } } - test("float prefix comparator handles NaN properly") { - val nan1: Float = java.lang.Float.intBitsToFloat(0x7f800001) - val nan2: Float = java.lang.Float.intBitsToFloat(0x7fffffff) - assert(nan1.isNaN) - assert(nan2.isNaN) - val nan1Prefix = PrefixComparators.FLOAT.computePrefix(nan1) - val nan2Prefix = PrefixComparators.FLOAT.computePrefix(nan2) - assert(nan1Prefix === nan2Prefix) - val floatMaxPrefix = PrefixComparators.FLOAT.computePrefix(Float.MaxValue) - assert(PrefixComparators.FLOAT.compare(nan1Prefix, floatMaxPrefix) === 1) - } - test("double prefix comparator handles NaNs properly") { val nan1: Double = java.lang.Double.longBitsToDouble(0x7ff0000000000001L) val nan2: Double = 
java.lang.Double.longBitsToDouble(0x7fffffffffffffffL) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java index 4c3f2c6557140..8342833246f7d 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java @@ -121,7 +121,7 @@ Iterator sort() throws IOException { // here in order to prevent memory leaks. cleanupResources(); } - return new AbstractScalaRowIterator() { + return new AbstractScalaRowIterator() { private final int numFields = schema.length(); private UnsafeRow row = new UnsafeRow(); diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala index 2dee3542d6101..050d27f1460fb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.SortOrder +import org.apache.spark.sql.catalyst.expressions.{BoundReference, SortOrder} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.collection.unsafe.sort.{PrefixComparators, PrefixComparator} @@ -39,57 +39,54 @@ object SortPrefixUtils { sortOrder.dataType match { case StringType => PrefixComparators.STRING case BooleanType | ByteType | ShortType | IntegerType | LongType => PrefixComparators.INTEGRAL - case FloatType => PrefixComparators.FLOAT - case DoubleType => PrefixComparators.DOUBLE + case FloatType | DoubleType => PrefixComparators.DOUBLE case _ => NoOpPrefixComparator } } def getPrefixComputer(sortOrder: SortOrder): InternalRow => Long = { + val bound = sortOrder.child.asInstanceOf[BoundReference] + val pos = bound.ordinal sortOrder.dataType match { - case StringType => (row: InternalRow) => { - PrefixComparators.STRING.computePrefix(sortOrder.child.eval(row).asInstanceOf[UTF8String]) - } + case StringType => + (row: InternalRow) => { + PrefixComparators.STRING.computePrefix(row.getUTF8String(pos)) + } case BooleanType => (row: InternalRow) => { - val exprVal = sortOrder.child.eval(row) - if (exprVal == null) PrefixComparators.INTEGRAL.NULL_PREFIX - else if (sortOrder.child.eval(row).asInstanceOf[Boolean]) 1 + if (row.isNullAt(pos)) PrefixComparators.INTEGRAL.NULL_PREFIX + else if (row.getBoolean(pos)) 1 else 0 } case ByteType => (row: InternalRow) => { - val exprVal = sortOrder.child.eval(row) - if (exprVal == null) PrefixComparators.INTEGRAL.NULL_PREFIX - else sortOrder.child.eval(row).asInstanceOf[Byte] + if (row.isNullAt(pos)) PrefixComparators.INTEGRAL.NULL_PREFIX else row.getByte(pos) } case ShortType => (row: InternalRow) => { - val exprVal = sortOrder.child.eval(row) - if (exprVal == null) PrefixComparators.INTEGRAL.NULL_PREFIX - else sortOrder.child.eval(row).asInstanceOf[Short] + if (row.isNullAt(pos)) PrefixComparators.INTEGRAL.NULL_PREFIX else row.getShort(pos) } case IntegerType => (row: InternalRow) => { - val exprVal = sortOrder.child.eval(row) - if (exprVal == null) PrefixComparators.INTEGRAL.NULL_PREFIX - else sortOrder.child.eval(row).asInstanceOf[Int] + if (row.isNullAt(pos)) PrefixComparators.INTEGRAL.NULL_PREFIX else 
row.getInt(pos) } case LongType => (row: InternalRow) => { - val exprVal = sortOrder.child.eval(row) - if (exprVal == null) PrefixComparators.INTEGRAL.NULL_PREFIX - else sortOrder.child.eval(row).asInstanceOf[Long] + if (row.isNullAt(pos)) PrefixComparators.INTEGRAL.NULL_PREFIX else row.getLong(pos) } case FloatType => (row: InternalRow) => { - val exprVal = sortOrder.child.eval(row) - if (exprVal == null) PrefixComparators.FLOAT.NULL_PREFIX - else PrefixComparators.FLOAT.computePrefix(sortOrder.child.eval(row).asInstanceOf[Float]) + if (row.isNullAt(pos)) { + PrefixComparators.DOUBLE.NULL_PREFIX + } else { + PrefixComparators.DOUBLE.computePrefix(row.getFloat(pos).toDouble) + } } case DoubleType => (row: InternalRow) => { - val exprVal = sortOrder.child.eval(row) - if (exprVal == null) PrefixComparators.DOUBLE.NULL_PREFIX - else PrefixComparators.DOUBLE.computePrefix(sortOrder.child.eval(row).asInstanceOf[Double]) + if (row.isNullAt(pos)) { + PrefixComparators.DOUBLE.NULL_PREFIX + } else { + PrefixComparators.DOUBLE.computePrefix(row.getDouble(pos)) + } } case _ => (row: InternalRow) => 0L } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index f3ef066528ff8..4ab2c41f1b339 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -340,8 +340,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { */ def getSortOperator(sortExprs: Seq[SortOrder], global: Boolean, child: SparkPlan): SparkPlan = { if (sqlContext.conf.unsafeEnabled && sqlContext.conf.codegenEnabled && - UnsafeExternalSort.supportsSchema(child.schema)) { - execution.UnsafeExternalSort(sortExprs, global, child) + TungstenSort.supportsSchema(child.schema)) { + execution.TungstenSort(sortExprs, global, child) } else if (sqlContext.conf.externalSortEnabled) { execution.ExternalSort(sortExprs, global, child) } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala index f82208868c3e3..d0ad310062853 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala @@ -97,7 +97,7 @@ case class ExternalSort( * @param testSpillFrequency Method for configuring periodic spilling in unit tests. If set, will * spill every `frequency` records. */ -case class UnsafeExternalSort( +case class TungstenSort( sortOrder: Seq[SortOrder], global: Boolean, child: SparkPlan, @@ -110,7 +110,6 @@ case class UnsafeExternalSort( if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil protected override def doExecute(): RDD[InternalRow] = attachTree(this, "sort") { - assert(codegenEnabled, "UnsafeExternalSort requires code generation to be enabled") def doSort(iterator: Iterator[InternalRow]): Iterator[InternalRow] = { val ordering = newOrdering(sortOrder, child.output) val boundSortExpression = BindReferences.bindReference(sortOrder.head, child.output) @@ -149,7 +148,7 @@ case class UnsafeExternalSort( } @DeveloperApi -object UnsafeExternalSort { +object TungstenSort { /** * Return true if UnsafeExternalSort can sort rows with the given schema, false otherwise. 
*/ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala index 7b75f755918c1..c458f95ca1ab3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala @@ -31,7 +31,7 @@ class RowFormatConvertersSuite extends SparkPlanTest { private val outputsSafe = ExternalSort(Nil, false, PhysicalRDD(Seq.empty, null)) assert(!outputsSafe.outputsUnsafeRows) - private val outputsUnsafe = UnsafeExternalSort(Nil, false, PhysicalRDD(Seq.empty, null)) + private val outputsUnsafe = TungstenSort(Nil, false, PhysicalRDD(Seq.empty, null)) assert(outputsUnsafe.outputsUnsafeRows) test("planner should insert unsafe->safe conversions when required") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala index 7a4baa9e4a49d..9cabc4b90bf8e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala @@ -42,7 +42,7 @@ class UnsafeExternalSortSuite extends SparkPlanTest with BeforeAndAfterAll { TestSQLContext.sparkContext.conf.set("spark.unsafe.exceptionOnMemoryLeak", "false") checkThatPlansAgree( (1 to 100).map(v => Tuple1(v)).toDF("a"), - (child: SparkPlan) => Limit(10, UnsafeExternalSort('a.asc :: Nil, true, child)), + (child: SparkPlan) => Limit(10, TungstenSort('a.asc :: Nil, true, child)), (child: SparkPlan) => Limit(10, Sort('a.asc :: Nil, global = true, child)), sortAnswers = false ) @@ -53,7 +53,7 @@ class UnsafeExternalSortSuite extends SparkPlanTest with BeforeAndAfterAll { try { checkThatPlansAgree( (1 to 100).map(v => Tuple1(v)).toDF("a"), - (child: SparkPlan) => Limit(10, UnsafeExternalSort('a.asc :: Nil, true, child)), + (child: SparkPlan) => Limit(10, TungstenSort('a.asc :: Nil, true, child)), (child: SparkPlan) => Limit(10, Sort('a.asc :: Nil, global = true, child)), sortAnswers = false ) @@ -68,7 +68,7 @@ class UnsafeExternalSortSuite extends SparkPlanTest with BeforeAndAfterAll { val stringLength = 1024 * 1024 * 2 checkThatPlansAgree( Seq(Tuple1("a" * stringLength), Tuple1("b" * stringLength)).toDF("a").repartition(1), - UnsafeExternalSort(sortOrder, global = true, _: SparkPlan, testSpillFrequency = 1), + TungstenSort(sortOrder, global = true, _: SparkPlan, testSpillFrequency = 1), Sort(sortOrder, global = true, _: SparkPlan), sortAnswers = false ) @@ -88,11 +88,11 @@ class UnsafeExternalSortSuite extends SparkPlanTest with BeforeAndAfterAll { TestSQLContext.sparkContext.parallelize(Random.shuffle(inputData).map(v => Row(v))), StructType(StructField("a", dataType, nullable = true) :: Nil) ) - assert(UnsafeExternalSort.supportsSchema(inputDf.schema)) + assert(TungstenSort.supportsSchema(inputDf.schema)) checkThatPlansAgree( inputDf, plan => ConvertToSafe( - UnsafeExternalSort(sortOrder, global = true, plan: SparkPlan, testSpillFrequency = 23)), + TungstenSort(sortOrder, global = true, plan: SparkPlan, testSpillFrequency = 23)), Sort(sortOrder, global = true, _: SparkPlan), sortAnswers = false ) From 07fd7d36471dfb823c1ce3e3a18464043affde18 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 29 Jul 2015 21:18:43 -0700 Subject: [PATCH 06/50] [SPARK-9460] Avoid byte array allocation in 
StringPrefixComparator. As of today, StringPrefixComparator converts the long values back to byte arrays in order to compare them. This patch optimizes this to compare the longs directly, rather than turning the longs into byte arrays and comparing them byte by byte (unsigned). This only works on little-endian architecture right now. Author: Reynold Xin Closes #7765 from rxin/SPARK-9460 and squashes the following commits: e4908cc [Reynold Xin] Stricter randomized tests. 4c8d094 [Reynold Xin] [SPARK-9460] Avoid byte array allocation in StringPrefixComparator. --- .../unsafe/sort/PrefixComparators.java | 29 ++----------------- .../unsafe/sort/PrefixComparatorsSuite.scala | 19 ++++++++---- .../apache/spark/unsafe/types/UTF8String.java | 9 ++++++ .../spark/unsafe/types/UTF8StringSuite.java | 11 +++++++ 4 files changed, 36 insertions(+), 32 deletions(-) diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java index 5624e067da2cc..a9ee6042fec74 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java @@ -17,9 +17,7 @@ package org.apache.spark.util.collection.unsafe.sort; -import com.google.common.base.Charsets; -import com.google.common.primitives.Longs; -import com.google.common.primitives.UnsignedBytes; +import com.google.common.primitives.UnsignedLongs; import org.apache.spark.annotation.Private; import org.apache.spark.unsafe.types.UTF8String; @@ -36,32 +34,11 @@ private PrefixComparators() {} public static final class StringPrefixComparator extends PrefixComparator { @Override public int compare(long aPrefix, long bPrefix) { - // TODO: can done more efficiently - byte[] a = Longs.toByteArray(aPrefix); - byte[] b = Longs.toByteArray(bPrefix); - for (int i = 0; i < 8; i++) { - int c = UnsignedBytes.compare(a[i], b[i]); - if (c != 0) return c; - } - return 0; - } - - public long computePrefix(byte[] bytes) { - if (bytes == null) { - return 0L; - } else { - byte[] padded = new byte[8]; - System.arraycopy(bytes, 0, padded, 0, Math.min(bytes.length, 8)); - return Longs.fromByteArray(padded); - } - } - - public long computePrefix(String value) { - return value == null ? 0L : computePrefix(value.getBytes(Charsets.UTF_8)); + return UnsignedLongs.compare(aPrefix, bPrefix); } public long computePrefix(UTF8String value) { - return value == null ? 0L : computePrefix(value.getBytes()); + return value == null ? 
0L : value.getPrefix(); } } diff --git a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala index 28fe9259453a6..26b7a9e816d1e 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala @@ -17,22 +17,29 @@ package org.apache.spark.util.collection.unsafe.sort +import com.google.common.primitives.UnsignedBytes import org.scalatest.prop.PropertyChecks - import org.apache.spark.SparkFunSuite +import org.apache.spark.unsafe.types.UTF8String class PrefixComparatorsSuite extends SparkFunSuite with PropertyChecks { test("String prefix comparator") { def testPrefixComparison(s1: String, s2: String): Unit = { - val s1Prefix = PrefixComparators.STRING.computePrefix(s1) - val s2Prefix = PrefixComparators.STRING.computePrefix(s2) + val utf8string1 = UTF8String.fromString(s1) + val utf8string2 = UTF8String.fromString(s2) + val s1Prefix = PrefixComparators.STRING.computePrefix(utf8string1) + val s2Prefix = PrefixComparators.STRING.computePrefix(utf8string2) val prefixComparisonResult = PrefixComparators.STRING.compare(s1Prefix, s2Prefix) + + val cmp = UnsignedBytes.lexicographicalComparator().compare( + utf8string1.getBytes.take(8), utf8string2.getBytes.take(8)) + assert( - (prefixComparisonResult == 0) || - (prefixComparisonResult < 0 && s1 < s2) || - (prefixComparisonResult > 0 && s1 > s2)) + (prefixComparisonResult == 0 && cmp == 0) || + (prefixComparisonResult < 0 && s1.compareTo(s2) < 0) || + (prefixComparisonResult > 0 && s1.compareTo(s2) > 0)) } // scalastyle:off diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index 3e1cc67dbf337..57522003ba2ba 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -137,6 +137,15 @@ public int numChars() { return len; } + /** + * Returns a 64-bit integer that can be used as the prefix used in sorting. + */ + public long getPrefix() { + long p = PlatformDependent.UNSAFE.getLong(base, offset); + p = java.lang.Long.reverseBytes(p); + return p; + } + /** * Returns the underline bytes, will be a copy of it if it's part of another array. 
*/ diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java index e2a5628ff4d93..42e09e435a412 100644 --- a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java +++ b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java @@ -63,8 +63,19 @@ public void emptyStringTest() { assertEquals(0, EMPTY_UTF8.numBytes()); } + @Test + public void prefix() { + assertTrue(fromString("a").getPrefix() - fromString("b").getPrefix() < 0); + assertTrue(fromString("ab").getPrefix() - fromString("b").getPrefix() < 0); + assertTrue( + fromString("abbbbbbbbbbbasdf").getPrefix() - fromString("bbbbbbbbbbbbasdf").getPrefix() < 0); + assertTrue(fromString("").getPrefix() - fromString("a").getPrefix() < 0); + assertTrue(fromString("你好").getPrefix() - fromString("世界").getPrefix() > 0); + } + @Test public void compareTo() { + assertTrue(fromString("").compareTo(fromString("a")) < 0); assertTrue(fromString("abc").compareTo(fromString("ABC")) > 0); assertTrue(fromString("abc0").compareTo(fromString("abc")) > 0); assertTrue(fromString("abcabcabc").compareTo(fromString("abcabcabc")) == 0); From 27850af5255352cebd933ed3cc3d82c9ff6e9b62 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 29 Jul 2015 21:24:47 -0700 Subject: [PATCH 07/50] [SPARK-9462][SQL] Initialize nondeterministic expressions in code gen fallback mode. Author: Reynold Xin Closes #7767 from rxin/SPARK-9462 and squashes the following commits: ef3e2d9 [Reynold Xin] Removed println 713ac3a [Reynold Xin] More unit tests. bb5c334 [Reynold Xin] [SPARK-9462][SQL] Initialize nondeterministic expressions in code gen fallback mode. --- .../expressions/codegen/CodegenFallback.scala | 7 ++- .../CodegenExpressionCachingSuite.scala | 46 +++++++++++++++++-- 2 files changed, 47 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala index 6b187f05604fd..3492d2c6189ed 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen -import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.{Nondeterministic, Expression} /** * A trait that can be used to provide a fallback mode for expression code generation. 
@@ -25,6 +25,11 @@ import org.apache.spark.sql.catalyst.expressions.Expression trait CodegenFallback extends Expression { protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + foreach { + case n: Nondeterministic => n.setInitialValues() + case _ => + } + ctx.references += this val objectTerm = ctx.freshName("obj") s""" diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenExpressionCachingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenExpressionCachingSuite.scala index 866bf904e4a4c..2d3f98dbbd3d1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenExpressionCachingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenExpressionCachingSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, LeafExpression} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types.{BooleanType, DataType} /** @@ -27,7 +27,32 @@ import org.apache.spark.sql.types.{BooleanType, DataType} */ class CodegenExpressionCachingSuite extends SparkFunSuite { - test("GenerateUnsafeProjection") { + test("GenerateUnsafeProjection should initialize expressions") { + // Use an Add to wrap two of them together in case we only initialize the top level expressions. + val expr = And(NondeterministicExpression(), NondeterministicExpression()) + val instance = UnsafeProjection.create(Seq(expr)) + assert(instance.apply(null).getBoolean(0) === false) + } + + test("GenerateProjection should initialize expressions") { + val expr = And(NondeterministicExpression(), NondeterministicExpression()) + val instance = GenerateProjection.generate(Seq(expr)) + assert(instance.apply(null).getBoolean(0) === false) + } + + test("GenerateMutableProjection should initialize expressions") { + val expr = And(NondeterministicExpression(), NondeterministicExpression()) + val instance = GenerateMutableProjection.generate(Seq(expr))() + assert(instance.apply(null).getBoolean(0) === false) + } + + test("GeneratePredicate should initialize expressions") { + val expr = And(NondeterministicExpression(), NondeterministicExpression()) + val instance = GeneratePredicate.generate(expr) + assert(instance.apply(null) === false) + } + + test("GenerateUnsafeProjection should not share expression instances") { val expr1 = MutableExpression() val instance1 = UnsafeProjection.create(Seq(expr1)) assert(instance1.apply(null).getBoolean(0) === false) @@ -39,7 +64,7 @@ class CodegenExpressionCachingSuite extends SparkFunSuite { assert(instance2.apply(null).getBoolean(0) === true) } - test("GenerateProjection") { + test("GenerateProjection should not share expression instances") { val expr1 = MutableExpression() val instance1 = GenerateProjection.generate(Seq(expr1)) assert(instance1.apply(null).getBoolean(0) === false) @@ -51,7 +76,7 @@ class CodegenExpressionCachingSuite extends SparkFunSuite { assert(instance2.apply(null).getBoolean(0) === true) } - test("GenerateMutableProjection") { + test("GenerateMutableProjection should not share expression instances") { val expr1 = MutableExpression() val instance1 = GenerateMutableProjection.generate(Seq(expr1))() assert(instance1.apply(null).getBoolean(0) === false) @@ -63,7 +88,7 @@ class 
CodegenExpressionCachingSuite extends SparkFunSuite { assert(instance2.apply(null).getBoolean(0) === true) } - test("GeneratePredicate") { + test("GeneratePredicate should not share expression instances") { val expr1 = MutableExpression() val instance1 = GeneratePredicate.generate(expr1) assert(instance1.apply(null) === false) @@ -77,6 +102,17 @@ class CodegenExpressionCachingSuite extends SparkFunSuite { } +/** + * An expression that's non-deterministic and doesn't support codegen. + */ +case class NondeterministicExpression() + extends LeafExpression with Nondeterministic with CodegenFallback { + override protected def initInternal(): Unit = { } + override protected def evalInternal(input: InternalRow): Any = false + override def nullable: Boolean = false + override def dataType: DataType = BooleanType +} + /** * An expression with mutable state so we can change it freely in our test suite. From f5dd11339fc9a6d11350f63beeca7c14aec169b1 Mon Sep 17 00:00:00 2001 From: Alex Angelini Date: Wed, 29 Jul 2015 22:25:38 -0700 Subject: [PATCH 08/50] Fix reference to self.names in StructType `names` is not defined in this context, I think you meant `self.names`. davies Author: Alex Angelini Closes #7766 from angelini/fix_struct_type_names and squashes the following commits: 01543a1 [Alex Angelini] Fix reference to self.names in StructType --- python/pyspark/sql/types.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index b97d50c945f24..8859308d66027 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -531,7 +531,7 @@ def toInternal(self, obj): if self._needSerializeFields: if isinstance(obj, dict): - return tuple(f.toInternal(obj.get(n)) for n, f in zip(names, self.fields)) + return tuple(f.toInternal(obj.get(n)) for n, f in zip(self.names, self.fields)) elif isinstance(obj, (tuple, list)): return tuple(f.toInternal(v) for f, v in zip(self.fields, obj)) else: From e044705b4402f86d0557ecd146f3565388c7eeb4 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 29 Jul 2015 22:30:49 -0700 Subject: [PATCH 09/50] [SPARK-9116] [SQL] [PYSPARK] support Python only UDT in __main__ Also we could create a Python UDT without having a Scala one, it's important for Python users. 
cc mengxr JoshRosen Author: Davies Liu Closes #7453 from davies/class_in_main and squashes the following commits: 4dfd5e1 [Davies Liu] add tests for Python and Scala UDT 793d9b2 [Davies Liu] Merge branch 'master' of github.com:apache/spark into class_in_main dc65f19 [Davies Liu] address comment a9a3c40 [Davies Liu] Merge branch 'master' of github.com:apache/spark into class_in_main a86e1fc [Davies Liu] fix serialization ad528ba [Davies Liu] Merge branch 'master' of github.com:apache/spark into class_in_main 63f52ef [Davies Liu] fix pylint check 655b8a9 [Davies Liu] Merge branch 'master' of github.com:apache/spark into class_in_main 316a394 [Davies Liu] support Python UDT with UTF 0bcb3ef [Davies Liu] fix bug in mllib de986d6 [Davies Liu] fix test 83d65ac [Davies Liu] fix bug in StructType 55bb86e [Davies Liu] support Python UDT in __main__ (without Scala one) --- pylintrc | 2 +- python/pyspark/cloudpickle.py | 38 +++++- python/pyspark/shuffle.py | 2 +- python/pyspark/sql/context.py | 108 ++++++++++------- python/pyspark/sql/tests.py | 112 ++++++++++++++++-- python/pyspark/sql/types.py | 78 ++++++------ .../org/apache/spark/sql/types/DataType.scala | 9 ++ .../spark/sql/types/UserDefinedType.scala | 29 +++++ .../spark/sql/execution/pythonUDFs.scala | 1 - 9 files changed, 286 insertions(+), 93 deletions(-) diff --git a/pylintrc b/pylintrc index 061775960393b..6a675770da69a 100644 --- a/pylintrc +++ b/pylintrc @@ -84,7 +84,7 @@ enable= # If you would like to improve the code quality of pyspark, remove any of these disabled errors # run ./dev/lint-python and see if the errors raised by pylint can be fixed. -disable=invalid-name,missing-docstring,protected-access,unused-argument,no-member,unused-wildcard-import,redefined-builtin,too-many-arguments,unused-variable,too-few-public-methods,bad-continuation,duplicate-code,redefined-outer-name,too-many-ancestors,import-error,superfluous-parens,unused-import,line-too-long,no-name-in-module,unnecessary-lambda,import-self,no-self-use,unidiomatic-typecheck,fixme,too-many-locals,cyclic-import,too-many-branches,bare-except,wildcard-import,dangerous-default-value,broad-except,too-many-public-methods,deprecated-lambda,anomalous-backslash-in-string,too-many-lines,reimported,too-many-statements,bad-whitespace,unpacking-non-sequence,too-many-instance-attributes,abstract-method,old-style-class,global-statement,attribute-defined-outside-init,arguments-differ,undefined-all-variable,no-init,useless-else-on-loop,super-init-not-called,notimplemented-raised,too-many-return-statements,pointless-string-statement,global-variable-undefined,bad-classmethod-argument,too-many-format-args,parse-error,no-self-argument,pointless-statement,undefined-variable 
+disable=invalid-name,missing-docstring,protected-access,unused-argument,no-member,unused-wildcard-import,redefined-builtin,too-many-arguments,unused-variable,too-few-public-methods,bad-continuation,duplicate-code,redefined-outer-name,too-many-ancestors,import-error,superfluous-parens,unused-import,line-too-long,no-name-in-module,unnecessary-lambda,import-self,no-self-use,unidiomatic-typecheck,fixme,too-many-locals,cyclic-import,too-many-branches,bare-except,wildcard-import,dangerous-default-value,broad-except,too-many-public-methods,deprecated-lambda,anomalous-backslash-in-string,too-many-lines,reimported,too-many-statements,bad-whitespace,unpacking-non-sequence,too-many-instance-attributes,abstract-method,old-style-class,global-statement,attribute-defined-outside-init,arguments-differ,undefined-all-variable,no-init,useless-else-on-loop,super-init-not-called,notimplemented-raised,too-many-return-statements,pointless-string-statement,global-variable-undefined,bad-classmethod-argument,too-many-format-args,parse-error,no-self-argument,pointless-statement,undefined-variable,undefined-loop-variable [REPORTS] diff --git a/python/pyspark/cloudpickle.py b/python/pyspark/cloudpickle.py index 9ef93071d2e77..3b647985801b7 100644 --- a/python/pyspark/cloudpickle.py +++ b/python/pyspark/cloudpickle.py @@ -350,7 +350,26 @@ def save_global(self, obj, name=None, pack=struct.pack): if new_override: d['__new__'] = obj.__new__ - self.save_reduce(typ, (obj.__name__, obj.__bases__, d), obj=obj) + self.save(_load_class) + self.save_reduce(typ, (obj.__name__, obj.__bases__, {"__doc__": obj.__doc__}), obj=obj) + d.pop('__doc__', None) + # handle property and staticmethod + dd = {} + for k, v in d.items(): + if isinstance(v, property): + k = ('property', k) + v = (v.fget, v.fset, v.fdel, v.__doc__) + elif isinstance(v, staticmethod) and hasattr(v, '__func__'): + k = ('staticmethod', k) + v = v.__func__ + elif isinstance(v, classmethod) and hasattr(v, '__func__'): + k = ('classmethod', k) + v = v.__func__ + dd[k] = v + self.save(dd) + self.write(pickle.TUPLE2) + self.write(pickle.REDUCE) + else: raise pickle.PicklingError("Can't pickle %r" % obj) @@ -708,6 +727,23 @@ def _make_skel_func(code, closures, base_globals = None): None, None, closure) +def _load_class(cls, d): + """ + Loads additional properties into class `cls`. 
+ """ + for k, v in d.items(): + if isinstance(k, tuple): + typ, k = k + if typ == 'property': + v = property(*v) + elif typ == 'staticmethod': + v = staticmethod(v) + elif typ == 'classmethod': + v = classmethod(v) + setattr(cls, k, v) + return cls + + """Constructors for 3rd party libraries Note: These can never be renamed due to client compatibility issues""" diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py index 8fb71bac64a5e..b8118bdb7ca76 100644 --- a/python/pyspark/shuffle.py +++ b/python/pyspark/shuffle.py @@ -606,7 +606,7 @@ def _open_file(self): if not os.path.exists(d): os.makedirs(d) p = os.path.join(d, str(id(self))) - self._file = open(p, "wb+", 65536) + self._file = open(p, "w+b", 65536) self._ser = BatchedSerializer(CompressedSerializer(PickleSerializer()), 1024) os.unlink(p) diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index abb6522dde7b0..917de24f3536b 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -277,6 +277,66 @@ def applySchema(self, rdd, schema): return self.createDataFrame(rdd, schema) + def _createFromRDD(self, rdd, schema, samplingRatio): + """ + Create an RDD for DataFrame from an existing RDD, returns the RDD and schema. + """ + if schema is None or isinstance(schema, (list, tuple)): + struct = self._inferSchema(rdd, samplingRatio) + converter = _create_converter(struct) + rdd = rdd.map(converter) + if isinstance(schema, (list, tuple)): + for i, name in enumerate(schema): + struct.fields[i].name = name + struct.names[i] = name + schema = struct + + elif isinstance(schema, StructType): + # take the first few rows to verify schema + rows = rdd.take(10) + for row in rows: + _verify_type(row, schema) + + else: + raise TypeError("schema should be StructType or list or None, but got: %s" % schema) + + # convert python objects to sql data + rdd = rdd.map(schema.toInternal) + return rdd, schema + + def _createFromLocal(self, data, schema): + """ + Create an RDD for DataFrame from an list or pandas.DataFrame, returns + the RDD and schema. + """ + if has_pandas and isinstance(data, pandas.DataFrame): + if schema is None: + schema = [str(x) for x in data.columns] + data = [r.tolist() for r in data.to_records(index=False)] + + # make sure data could consumed multiple times + if not isinstance(data, list): + data = list(data) + + if schema is None or isinstance(schema, (list, tuple)): + struct = self._inferSchemaFromList(data) + if isinstance(schema, (list, tuple)): + for i, name in enumerate(schema): + struct.fields[i].name = name + struct.names[i] = name + schema = struct + + elif isinstance(schema, StructType): + for row in data: + _verify_type(row, schema) + + else: + raise TypeError("schema should be StructType or list or None, but got: %s" % schema) + + # convert python objects to sql data + data = [schema.toInternal(row) for row in data] + return self._sc.parallelize(data), schema + @since(1.3) @ignore_unicode_prefix def createDataFrame(self, data, schema=None, samplingRatio=None): @@ -340,49 +400,15 @@ def createDataFrame(self, data, schema=None, samplingRatio=None): if isinstance(data, DataFrame): raise TypeError("data is already a DataFrame") - if has_pandas and isinstance(data, pandas.DataFrame): - if schema is None: - schema = [str(x) for x in data.columns] - data = [r.tolist() for r in data.to_records(index=False)] - - if not isinstance(data, RDD): - if not isinstance(data, list): - data = list(data) - try: - # data could be list, tuple, generator ... 
- rdd = self._sc.parallelize(data) - except Exception: - raise TypeError("cannot create an RDD from type: %s" % type(data)) + if isinstance(data, RDD): + rdd, schema = self._createFromRDD(data, schema, samplingRatio) else: - rdd = data - - if schema is None or isinstance(schema, (list, tuple)): - if isinstance(data, RDD): - struct = self._inferSchema(rdd, samplingRatio) - else: - struct = self._inferSchemaFromList(data) - if isinstance(schema, (list, tuple)): - for i, name in enumerate(schema): - struct.fields[i].name = name - schema = struct - converter = _create_converter(schema) - rdd = rdd.map(converter) - - elif isinstance(schema, StructType): - # take the first few rows to verify schema - rows = rdd.take(10) - for row in rows: - _verify_type(row, schema) - - else: - raise TypeError("schema should be StructType or list or None") - - # convert python objects to sql data - rdd = rdd.map(schema.toInternal) - + rdd, schema = self._createFromLocal(data, schema) jrdd = self._jvm.SerDeUtil.toJavaArray(rdd._to_java_object_rdd()) - df = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json()) - return DataFrame(df, self) + jdf = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json()) + df = DataFrame(jdf, self) + df._schema = schema + return df @since(1.3) def registerDataFrameAsTable(self, df, tableName): diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 5aa6135dc1ee7..ebd3ea8db6a43 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -75,7 +75,7 @@ def sqlType(self): @classmethod def module(cls): - return 'pyspark.tests' + return 'pyspark.sql.tests' @classmethod def scalaUDT(cls): @@ -106,10 +106,45 @@ def __str__(self): return "(%s,%s)" % (self.x, self.y) def __eq__(self, other): - return isinstance(other, ExamplePoint) and \ + return isinstance(other, self.__class__) and \ other.x == self.x and other.y == self.y +class PythonOnlyUDT(UserDefinedType): + """ + User-defined type (UDT) for ExamplePoint. 
+ """ + + @classmethod + def sqlType(self): + return ArrayType(DoubleType(), False) + + @classmethod + def module(cls): + return '__main__' + + def serialize(self, obj): + return [obj.x, obj.y] + + def deserialize(self, datum): + return PythonOnlyPoint(datum[0], datum[1]) + + @staticmethod + def foo(): + pass + + @property + def props(self): + return {} + + +class PythonOnlyPoint(ExamplePoint): + """ + An example class to demonstrate UDT in only Python + """ + __UDT__ = PythonOnlyUDT() + + class DataTypeTests(unittest.TestCase): # regression test for SPARK-6055 def test_data_type_eq(self): @@ -395,10 +430,39 @@ def test_convert_row_to_dict(self): self.assertEqual(1, row.asDict()["l"][0].a) self.assertEqual(1.0, row.asDict()['d']['key'].c) + def test_udt(self): + from pyspark.sql.types import _parse_datatype_json_string, _infer_type, _verify_type + from pyspark.sql.tests import ExamplePointUDT, ExamplePoint + + def check_datatype(datatype): + pickled = pickle.loads(pickle.dumps(datatype)) + assert datatype == pickled + scala_datatype = self.sqlCtx._ssql_ctx.parseDataType(datatype.json()) + python_datatype = _parse_datatype_json_string(scala_datatype.json()) + assert datatype == python_datatype + + check_datatype(ExamplePointUDT()) + structtype_with_udt = StructType([StructField("label", DoubleType(), False), + StructField("point", ExamplePointUDT(), False)]) + check_datatype(structtype_with_udt) + p = ExamplePoint(1.0, 2.0) + self.assertEqual(_infer_type(p), ExamplePointUDT()) + _verify_type(ExamplePoint(1.0, 2.0), ExamplePointUDT()) + self.assertRaises(ValueError, lambda: _verify_type([1.0, 2.0], ExamplePointUDT())) + + check_datatype(PythonOnlyUDT()) + structtype_with_udt = StructType([StructField("label", DoubleType(), False), + StructField("point", PythonOnlyUDT(), False)]) + check_datatype(structtype_with_udt) + p = PythonOnlyPoint(1.0, 2.0) + self.assertEqual(_infer_type(p), PythonOnlyUDT()) + _verify_type(PythonOnlyPoint(1.0, 2.0), PythonOnlyUDT()) + self.assertRaises(ValueError, lambda: _verify_type([1.0, 2.0], PythonOnlyUDT())) + def test_infer_schema_with_udt(self): from pyspark.sql.tests import ExamplePoint, ExamplePointUDT row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) - df = self.sc.parallelize([row]).toDF() + df = self.sqlCtx.createDataFrame([row]) schema = df.schema field = [f for f in schema.fields if f.name == "point"][0] self.assertEqual(type(field.dataType), ExamplePointUDT) @@ -406,36 +470,66 @@ def test_infer_schema_with_udt(self): point = self.sqlCtx.sql("SELECT point FROM labeled_point").head().point self.assertEqual(point, ExamplePoint(1.0, 2.0)) + row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0)) + df = self.sqlCtx.createDataFrame([row]) + schema = df.schema + field = [f for f in schema.fields if f.name == "point"][0] + self.assertEqual(type(field.dataType), PythonOnlyUDT) + df.registerTempTable("labeled_point") + point = self.sqlCtx.sql("SELECT point FROM labeled_point").head().point + self.assertEqual(point, PythonOnlyPoint(1.0, 2.0)) + def test_apply_schema_with_udt(self): from pyspark.sql.tests import ExamplePoint, ExamplePointUDT row = (1.0, ExamplePoint(1.0, 2.0)) - rdd = self.sc.parallelize([row]) schema = StructType([StructField("label", DoubleType(), False), StructField("point", ExamplePointUDT(), False)]) - df = rdd.toDF(schema) + df = self.sqlCtx.createDataFrame([row], schema) point = df.head().point self.assertEquals(point, ExamplePoint(1.0, 2.0)) + row = (1.0, PythonOnlyPoint(1.0, 2.0)) + schema = StructType([StructField("label", DoubleType(), 
False), + StructField("point", PythonOnlyUDT(), False)]) + df = self.sqlCtx.createDataFrame([row], schema) + point = df.head().point + self.assertEquals(point, PythonOnlyPoint(1.0, 2.0)) + def test_udf_with_udt(self): from pyspark.sql.tests import ExamplePoint, ExamplePointUDT row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) - df = self.sc.parallelize([row]).toDF() + df = self.sqlCtx.createDataFrame([row]) self.assertEqual(1.0, df.map(lambda r: r.point.x).first()) udf = UserDefinedFunction(lambda p: p.y, DoubleType()) self.assertEqual(2.0, df.select(udf(df.point)).first()[0]) udf2 = UserDefinedFunction(lambda p: ExamplePoint(p.x + 1, p.y + 1), ExamplePointUDT()) self.assertEqual(ExamplePoint(2.0, 3.0), df.select(udf2(df.point)).first()[0]) + row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0)) + df = self.sqlCtx.createDataFrame([row]) + self.assertEqual(1.0, df.map(lambda r: r.point.x).first()) + udf = UserDefinedFunction(lambda p: p.y, DoubleType()) + self.assertEqual(2.0, df.select(udf(df.point)).first()[0]) + udf2 = UserDefinedFunction(lambda p: PythonOnlyPoint(p.x + 1, p.y + 1), PythonOnlyUDT()) + self.assertEqual(PythonOnlyPoint(2.0, 3.0), df.select(udf2(df.point)).first()[0]) + def test_parquet_with_udt(self): - from pyspark.sql.tests import ExamplePoint + from pyspark.sql.tests import ExamplePoint, ExamplePointUDT row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) - df0 = self.sc.parallelize([row]).toDF() + df0 = self.sqlCtx.createDataFrame([row]) output_dir = os.path.join(self.tempdir.name, "labeled_point") - df0.saveAsParquetFile(output_dir) + df0.write.parquet(output_dir) df1 = self.sqlCtx.parquetFile(output_dir) point = df1.head().point self.assertEquals(point, ExamplePoint(1.0, 2.0)) + row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0)) + df0 = self.sqlCtx.createDataFrame([row]) + df0.write.parquet(output_dir, mode='overwrite') + df1 = self.sqlCtx.parquetFile(output_dir) + point = df1.head().point + self.assertEquals(point, PythonOnlyPoint(1.0, 2.0)) + def test_column_operators(self): ci = self.df.key cs = self.df.value diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 8859308d66027..0976aea72c034 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -22,6 +22,7 @@ import calendar import json import re +import base64 from array import array if sys.version >= "3": @@ -31,6 +32,8 @@ from py4j.protocol import register_input_converter from py4j.java_gateway import JavaClass +from pyspark.serializers import CloudPickleSerializer + __all__ = [ "DataType", "NullType", "StringType", "BinaryType", "BooleanType", "DateType", "TimestampType", "DecimalType", "DoubleType", "FloatType", "ByteType", "IntegerType", @@ -458,7 +461,7 @@ def __init__(self, fields=None): self.names = [f.name for f in fields] assert all(isinstance(f, StructField) for f in fields),\ "fields should be a list of StructField" - self._needSerializeFields = None + self._needSerializeAnyField = any(f.needConversion() for f in self.fields) def add(self, field, data_type=None, nullable=True, metadata=None): """ @@ -501,6 +504,7 @@ def add(self, field, data_type=None, nullable=True, metadata=None): data_type_f = data_type self.fields.append(StructField(field, data_type_f, nullable, metadata)) self.names.append(field) + self._needSerializeAnyField = any(f.needConversion() for f in self.fields) return self def simpleString(self): @@ -526,10 +530,7 @@ def toInternal(self, obj): if obj is None: return - if self._needSerializeFields is None: - self._needSerializeFields = 
any(f.needConversion() for f in self.fields) - - if self._needSerializeFields: + if self._needSerializeAnyField: if isinstance(obj, dict): return tuple(f.toInternal(obj.get(n)) for n, f in zip(self.names, self.fields)) elif isinstance(obj, (tuple, list)): @@ -550,7 +551,10 @@ def fromInternal(self, obj): if isinstance(obj, Row): # it's already converted by pickler return obj - values = [f.dataType.fromInternal(v) for f, v in zip(self.fields, obj)] + if self._needSerializeAnyField: + values = [f.fromInternal(v) for f, v in zip(self.fields, obj)] + else: + values = obj return _create_row(self.names, values) @@ -581,9 +585,10 @@ def module(cls): @classmethod def scalaUDT(cls): """ - The class name of the paired Scala UDT. + The class name of the paired Scala UDT (could be '', if there + is no corresponding one). """ - raise NotImplementedError("UDT must have a paired Scala UDT.") + return '' def needConversion(self): return True @@ -622,22 +627,37 @@ def json(self): return json.dumps(self.jsonValue(), separators=(',', ':'), sort_keys=True) def jsonValue(self): - schema = { - "type": "udt", - "class": self.scalaUDT(), - "pyClass": "%s.%s" % (self.module(), type(self).__name__), - "sqlType": self.sqlType().jsonValue() - } + if self.scalaUDT(): + assert self.module() != '__main__', 'UDT in __main__ cannot work with ScalaUDT' + schema = { + "type": "udt", + "class": self.scalaUDT(), + "pyClass": "%s.%s" % (self.module(), type(self).__name__), + "sqlType": self.sqlType().jsonValue() + } + else: + ser = CloudPickleSerializer() + b = ser.dumps(type(self)) + schema = { + "type": "udt", + "pyClass": "%s.%s" % (self.module(), type(self).__name__), + "serializedClass": base64.b64encode(b).decode('utf8'), + "sqlType": self.sqlType().jsonValue() + } return schema @classmethod def fromJson(cls, json): - pyUDT = json["pyClass"] + pyUDT = str(json["pyClass"]) split = pyUDT.rfind(".") pyModule = pyUDT[:split] pyClass = pyUDT[split+1:] m = __import__(pyModule, globals(), locals(), [pyClass]) - UDT = getattr(m, pyClass) + if not hasattr(m, pyClass): + s = base64.b64decode(json['serializedClass'].encode('utf-8')) + UDT = CloudPickleSerializer().loads(s) + else: + UDT = getattr(m, pyClass) return UDT() def __eq__(self, other): @@ -696,11 +716,6 @@ def _parse_datatype_json_string(json_string): >>> complex_maptype = MapType(complex_structtype, ... complex_arraytype, False) >>> check_datatype(complex_maptype) - - >>> check_datatype(ExamplePointUDT()) - >>> structtype_with_udt = StructType([StructField("label", DoubleType(), False), - ... StructField("point", ExamplePointUDT(), False)]) - >>> check_datatype(structtype_with_udt) """ return _parse_datatype_json_value(json.loads(json_string)) @@ -752,10 +767,6 @@ def _parse_datatype_json_value(json_value): def _infer_type(obj): """Infer the DataType from obj - - >>> p = ExamplePoint(1.0, 2.0) - >>> _infer_type(p) - ExamplePointUDT """ if obj is None: return NullType() @@ -1090,11 +1101,6 @@ def _verify_type(obj, dataType): Traceback (most recent call last): ... ValueError:... - >>> _verify_type(ExamplePoint(1.0, 2.0), ExamplePointUDT()) - >>> _verify_type([1.0, 2.0], ExamplePointUDT()) # doctest: +IGNORE_EXCEPTION_DETAIL - Traceback (most recent call last): - ... - ValueError:... 
""" # all objects are nullable if obj is None: @@ -1259,18 +1265,12 @@ def convert(self, obj, gateway_client): def _test(): import doctest from pyspark.context import SparkContext - # let doctest run in pyspark.sql.types, so DataTypes can be picklable - import pyspark.sql.types - from pyspark.sql import Row, SQLContext - from pyspark.sql.tests import ExamplePoint, ExamplePointUDT - globs = pyspark.sql.types.__dict__.copy() + from pyspark.sql import SQLContext + globs = globals() sc = SparkContext('local[4]', 'PythonTest') globs['sc'] = sc globs['sqlContext'] = SQLContext(sc) - globs['ExamplePoint'] = ExamplePoint - globs['ExamplePointUDT'] = ExamplePointUDT - (failure_count, test_count) = doctest.testmod( - pyspark.sql.types, globs=globs, optionflags=doctest.ELLIPSIS) + (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) globs['sc'].stop() if failure_count: exit(-1) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index 591fb26e67c4a..f4428c2e8b202 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -142,12 +142,21 @@ object DataType { ("type", JString("struct"))) => StructType(fields.map(parseStructField)) + // Scala/Java UDT case JSortedObject( ("class", JString(udtClass)), ("pyClass", _), ("sqlType", _), ("type", JString("udt"))) => Utils.classForName(udtClass).newInstance().asInstanceOf[UserDefinedType[_]] + + // Python UDT + case JSortedObject( + ("pyClass", JString(pyClass)), + ("serializedClass", JString(serialized)), + ("sqlType", v: JValue), + ("type", JString("udt"))) => + new PythonUserDefinedType(parseDataType(v), pyClass, serialized) } private def parseStructField(json: JValue): StructField = json match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala index e47cfb4833bd8..4305903616bd9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala @@ -45,6 +45,9 @@ abstract class UserDefinedType[UserType] extends DataType with Serializable { /** Paired Python UDT class, if exists. */ def pyUDT: String = null + /** Serialized Python UDT class, if exists. */ + def serializedPyClass: String = null + /** * Convert the user type to a SQL datum * @@ -82,3 +85,29 @@ abstract class UserDefinedType[UserType] extends DataType with Serializable { override private[sql] def acceptsType(dataType: DataType) = this.getClass == dataType.getClass } + +/** + * ::DeveloperApi:: + * The user defined type in Python. + * + * Note: This can only be accessed via Python UDF, or accessed as serialized object. 
+ */ +private[sql] class PythonUserDefinedType( + val sqlType: DataType, + override val pyUDT: String, + override val serializedPyClass: String) extends UserDefinedType[Any] { + + /* The serialization is handled by UDT class in Python */ + override def serialize(obj: Any): Any = obj + override def deserialize(datam: Any): Any = datam + + /* There is no Java class for Python UDT */ + override def userClass: java.lang.Class[Any] = null + + override private[sql] def jsonValue: JValue = { + ("type" -> "udt") ~ + ("pyClass" -> pyUDT) ~ + ("serializedClass" -> serializedPyClass) ~ + ("sqlType" -> sqlType.jsonValue) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala index ec084a299649e..3c38916fd7504 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala @@ -267,7 +267,6 @@ object EvaluatePython { pickler.save(row.values(i)) i += 1 } - row.values.foreach(pickler.save) out.write(Opcodes.TUPLE) out.write(Opcodes.REDUCE) } From 712465b68e50df7a2050b27528acda9f0d95ba1f Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 29 Jul 2015 22:51:06 -0700 Subject: [PATCH 10/50] HOTFIX: disable HashedRelationSuite. --- .../spark/sql/execution/joins/HashedRelationSuite.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala index 8b1a9b21a96b9..941f6d4f6a450 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala @@ -33,7 +33,7 @@ class HashedRelationSuite extends SparkFunSuite { override def apply(row: InternalRow): InternalRow = row } - test("GeneralHashedRelation") { + ignore("GeneralHashedRelation") { val data = Array(InternalRow(0), InternalRow(1), InternalRow(2), InternalRow(2)) val hashed = HashedRelation(data.iterator, keyProjection) assert(hashed.isInstanceOf[GeneralHashedRelation]) @@ -47,7 +47,7 @@ class HashedRelationSuite extends SparkFunSuite { assert(hashed.get(data(2)) === data2) } - test("UniqueKeyHashedRelation") { + ignore("UniqueKeyHashedRelation") { val data = Array(InternalRow(0), InternalRow(1), InternalRow(2)) val hashed = HashedRelation(data.iterator, keyProjection) assert(hashed.isInstanceOf[UniqueKeyHashedRelation]) @@ -64,7 +64,7 @@ class HashedRelationSuite extends SparkFunSuite { assert(uniqHashed.getValue(InternalRow(10)) === null) } - test("UnsafeHashedRelation") { + ignore("UnsafeHashedRelation") { val schema = StructType(StructField("a", IntegerType, true) :: Nil) val data = Array(InternalRow(0), InternalRow(1), InternalRow(2), InternalRow(2)) val toUnsafe = UnsafeProjection.create(schema) From e127ec34d58ceb0a9d45748c2f2918786ba0a83d Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Wed, 29 Jul 2015 23:24:20 -0700 Subject: [PATCH 11/50] [SPARK-9428] [SQL] Add test cases for null inputs for expression unit tests JIRA: https://issues.apache.org/jira/browse/SPARK-9428 Author: Yijie Shen Closes #7748 from yjshen/string_cleanup and squashes the following commits: e0c2b3d [Yijie Shen] update codegen in RegExpExtract and RegExpReplace 26614d2 [Yijie Shen] MathFunctionSuite a402859 [Yijie Shen] complex_create, conditional and cast 
6e4e608 [Yijie Shen] arithmetic and cast 52593c1 [Yijie Shen] null input test cases for StringExpressionSuite --- .../spark/sql/catalyst/expressions/Cast.scala | 12 ++-- .../expressions/complexTypeCreator.scala | 16 +++-- .../catalyst/expressions/conditionals.scala | 10 +-- .../spark/sql/catalyst/expressions/math.scala | 14 ++--- .../expressions/stringOperations.scala | 11 ++-- .../ExpressionTypeCheckingSuite.scala | 7 ++- .../ArithmeticExpressionSuite.scala | 3 + .../sql/catalyst/expressions/CastSuite.scala | 52 ++++++++++++++- .../expressions/ComplexTypeSuite.scala | 23 +++---- .../ConditionalExpressionSuite.scala | 4 ++ .../expressions/MathFunctionsSuite.scala | 63 ++++++++++--------- .../catalyst/expressions/RandomSuite.scala | 1 - .../expressions/StringExpressionsSuite.scala | 26 ++++++++ .../org/apache/spark/sql/functions.scala | 6 +- 14 files changed, 167 insertions(+), 81 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index c6e8af27667ee..8c01c13c9ccd5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -599,7 +599,7 @@ case class Cast(child: Expression, dataType: DataType) } """ case BooleanType => - (c, evPrim, evNull) => s"$evPrim = $c ? 1L : 0;" + (c, evPrim, evNull) => s"$evPrim = $c ? 1L : 0L;" case _: IntegralType => (c, evPrim, evNull) => s"$evPrim = ${longToTimeStampCode(c)};" case DateType => @@ -665,7 +665,7 @@ case class Cast(child: Expression, dataType: DataType) } """ case BooleanType => - (c, evPrim, evNull) => s"$evPrim = $c ? 1 : 0;" + (c, evPrim, evNull) => s"$evPrim = $c ? (byte) 1 : (byte) 0;" case DateType => (c, evPrim, evNull) => s"$evNull = true;" case TimestampType => @@ -687,7 +687,7 @@ case class Cast(child: Expression, dataType: DataType) } """ case BooleanType => - (c, evPrim, evNull) => s"$evPrim = $c ? 1 : 0;" + (c, evPrim, evNull) => s"$evPrim = $c ? (short) 1 : (short) 0;" case DateType => (c, evPrim, evNull) => s"$evNull = true;" case TimestampType => @@ -731,7 +731,7 @@ case class Cast(child: Expression, dataType: DataType) } """ case BooleanType => - (c, evPrim, evNull) => s"$evPrim = $c ? 1 : 0;" + (c, evPrim, evNull) => s"$evPrim = $c ? 1L : 0L;" case DateType => (c, evPrim, evNull) => s"$evNull = true;" case TimestampType => @@ -753,7 +753,7 @@ case class Cast(child: Expression, dataType: DataType) } """ case BooleanType => - (c, evPrim, evNull) => s"$evPrim = $c ? 1 : 0;" + (c, evPrim, evNull) => s"$evPrim = $c ? 1.0f : 0.0f;" case DateType => (c, evPrim, evNull) => s"$evNull = true;" case TimestampType => @@ -775,7 +775,7 @@ case class Cast(child: Expression, dataType: DataType) } """ case BooleanType => - (c, evPrim, evNull) => s"$evPrim = $c ? 1 : 0;" + (c, evPrim, evNull) => s"$evPrim = $c ? 
1.0d : 0.0d;" case DateType => (c, evPrim, evNull) => s"$evNull = true;" case TimestampType => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index d8c9087ff5380..0517050a45109 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.expressions +import org.apache.spark.unsafe.types.UTF8String + import scala.collection.mutable import org.apache.spark.sql.catalyst.InternalRow @@ -127,11 +129,12 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression { private lazy val (nameExprs, valExprs) = children.grouped(2).map { case Seq(name, value) => (name, value) }.toList.unzip - private lazy val names = nameExprs.map(_.eval(EmptyRow).toString) + private lazy val names = nameExprs.map(_.eval(EmptyRow)) override lazy val dataType: StructType = { val fields = names.zip(valExprs).map { case (name, valExpr) => - StructField(name, valExpr.dataType, valExpr.nullable, Metadata.empty) + StructField(name.asInstanceOf[UTF8String].toString, + valExpr.dataType, valExpr.nullable, Metadata.empty) } StructType(fields) } @@ -144,14 +147,15 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression { if (children.size % 2 != 0) { TypeCheckResult.TypeCheckFailure(s"$prettyName expects an even number of arguments.") } else { - val invalidNames = - nameExprs.filterNot(e => e.foldable && e.dataType == StringType && !nullable) + val invalidNames = nameExprs.filterNot(e => e.foldable && e.dataType == StringType) if (invalidNames.nonEmpty) { TypeCheckResult.TypeCheckFailure( - s"Odd position only allow foldable and not-null StringType expressions, got :" + + s"Only foldable StringType expressions are allowed to appear at odd position , got :" + s" ${invalidNames.mkString(",")}") - } else { + } else if (names.forall(_ != null)){ TypeCheckResult.TypeCheckSuccess + } else { + TypeCheckResult.TypeCheckFailure("Field name should not be null") } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala index 15b33da884dcb..961b1d8616801 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala @@ -315,7 +315,6 @@ case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseW * It takes at least 2 parameters, and returns null iff all parameters are null. 
*/ case class Least(children: Seq[Expression]) extends Expression { - require(children.length > 1, "LEAST requires at least 2 arguments, got " + children.length) override def nullable: Boolean = children.forall(_.nullable) override def foldable: Boolean = children.forall(_.foldable) @@ -323,7 +322,9 @@ case class Least(children: Seq[Expression]) extends Expression { private lazy val ordering = TypeUtils.getOrdering(dataType) override def checkInputDataTypes(): TypeCheckResult = { - if (children.map(_.dataType).distinct.count(_ != NullType) > 1) { + if (children.length <= 1) { + TypeCheckResult.TypeCheckFailure(s"LEAST requires at least 2 arguments") + } else if (children.map(_.dataType).distinct.count(_ != NullType) > 1) { TypeCheckResult.TypeCheckFailure( s"The expressions should all have the same type," + s" got LEAST (${children.map(_.dataType)}).") @@ -369,7 +370,6 @@ case class Least(children: Seq[Expression]) extends Expression { * It takes at least 2 parameters, and returns null iff all parameters are null. */ case class Greatest(children: Seq[Expression]) extends Expression { - require(children.length > 1, "GREATEST requires at least 2 arguments, got " + children.length) override def nullable: Boolean = children.forall(_.nullable) override def foldable: Boolean = children.forall(_.foldable) @@ -377,7 +377,9 @@ case class Greatest(children: Seq[Expression]) extends Expression { private lazy val ordering = TypeUtils.getOrdering(dataType) override def checkInputDataTypes(): TypeCheckResult = { - if (children.map(_.dataType).distinct.count(_ != NullType) > 1) { + if (children.length <= 1) { + TypeCheckResult.TypeCheckFailure(s"GREATEST requires at least 2 arguments") + } else if (children.map(_.dataType).distinct.count(_ != NullType) > 1) { TypeCheckResult.TypeCheckFailure( s"The expressions should all have the same type," + s" got GREATEST (${children.map(_.dataType)}).") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index 68cca0ad3d067..e6d807f6d897b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -646,19 +646,19 @@ case class Logarithm(left: Expression, right: Expression) /** * Round the `child`'s result to `scale` decimal place when `scale` >= 0 * or round at integral part when `scale` < 0. - * For example, round(31.415, 2) would eval to 31.42 and round(31.415, -1) would eval to 30. + * For example, round(31.415, 2) = 31.42 and round(31.415, -1) = 30. * - * Child of IntegralType would eval to itself when `scale` >= 0. - * Child of FractionalType whose value is NaN or Infinite would always eval to itself. + * Child of IntegralType would round to itself when `scale` >= 0. + * Child of FractionalType whose value is NaN or Infinite would always round to itself. * - * Round's dataType would always equal to `child`'s dataType except for [[DecimalType.Fixed]], - * which leads to scale update in DecimalType's [[PrecisionInfo]] + * Round's dataType would always equal to `child`'s dataType except for DecimalType, + * which would lead scale decrease from the origin DecimalType. 
* * @param child expr to be round, all [[NumericType]] is allowed as Input * @param scale new scale to be round to, this should be a constant int at runtime */ case class Round(child: Expression, scale: Expression) - extends BinaryExpression with ExpectsInputTypes { + extends BinaryExpression with ImplicitCastInputTypes { import BigDecimal.RoundingMode.HALF_UP @@ -838,6 +838,4 @@ case class Round(child: Expression, scale: Expression) """ } } - - override def prettyName: String = "round" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 6db4e19c24ed5..5b3a64a09679c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -22,7 +22,6 @@ import java.util.Locale import java.util.regex.{MatchResult, Pattern} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.UnresolvedException import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -52,7 +51,7 @@ case class Concat(children: Seq[Expression]) extends Expression with ImplicitCas override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val evals = children.map(_.gen(ctx)) val inputs = evals.map { eval => - s"${eval.isNull} ? (UTF8String)null : ${eval.primitive}" + s"${eval.isNull} ? null : ${eval.primitive}" }.mkString(", ") evals.map(_.code).mkString("\n") + s""" boolean ${ev.isNull} = false; @@ -1008,7 +1007,7 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio s""" ${evalSubject.code} - boolean ${ev.isNull} = ${evalSubject.isNull}; + boolean ${ev.isNull} = true; ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; if (!${evalSubject.isNull}) { ${evalRegexp.code} @@ -1103,9 +1102,9 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio val evalIdx = idx.gen(ctx) s""" - ${ctx.javaType(dataType)} ${ev.primitive} = null; - boolean ${ev.isNull} = true; ${evalSubject.code} + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + boolean ${ev.isNull} = true; if (!${evalSubject.isNull}) { ${evalRegexp.code} if (!${evalRegexp.isNull}) { @@ -1117,7 +1116,7 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio ${termPattern} = ${classNamePattern}.compile(${termLastRegex}.toString()); } ${classOf[java.util.regex.Matcher].getCanonicalName} m = - ${termPattern}.matcher(${evalSubject.primitive}.toString()); + ${termPattern}.matcher(${evalSubject.primitive}.toString()); if (m.find()) { ${classOf[java.util.regex.MatchResult].getCanonicalName} mr = m.toMatchResult(); ${ev.primitive} = ${classNameUTF8String}.fromString(mr.group(${evalIdx.primitive})); diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala index 8acd4c685e2bc..a52e4cb4dfd9f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala @@ -167,10 +167,13 @@ class 
ExpressionTypeCheckingSuite extends SparkFunSuite { CreateNamedStruct(Seq("a", "b", 2.0)), "even number of arguments") assertError( CreateNamedStruct(Seq(1, "a", "b", 2.0)), - "Odd position only allow foldable and not-null StringType expressions") + "Only foldable StringType expressions are allowed to appear at odd position") assertError( CreateNamedStruct(Seq('a.string.at(0), "a", "b", 2.0)), - "Odd position only allow foldable and not-null StringType expressions") + "Only foldable StringType expressions are allowed to appear at odd position") + assertError( + CreateNamedStruct(Seq(Literal.create(null, StringType), "a")), + "Field name should not be null") } test("check types for ROUND") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala index 7773e098e0caa..d03b0fbbfb2b2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala @@ -116,9 +116,12 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper test("Abs") { testNumericDataTypes { convert => + val input = Literal(convert(1)) + val dataType = input.dataType checkEvaluation(Abs(Literal(convert(0))), convert(0)) checkEvaluation(Abs(Literal(convert(1))), convert(1)) checkEvaluation(Abs(Literal(convert(-1))), convert(1)) + checkEvaluation(Abs(Literal.create(null, dataType)), null) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index 0e0213be0f57b..a517da9872852 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -43,6 +43,42 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(cast(v, Literal(expected).dataType), expected) } + private def checkNullCast(from: DataType, to: DataType): Unit = { + checkEvaluation(Cast(Literal.create(null, from), to), null) + } + + test("null cast") { + import DataTypeTestUtils._ + + // follow [[org.apache.spark.sql.catalyst.expressions.Cast.canCast]] logic + // to ensure we test every possible cast situation here + atomicTypes.zip(atomicTypes).foreach { case (from, to) => + checkNullCast(from, to) + } + + atomicTypes.foreach(dt => checkNullCast(NullType, dt)) + atomicTypes.foreach(dt => checkNullCast(dt, StringType)) + checkNullCast(StringType, BinaryType) + checkNullCast(StringType, BooleanType) + checkNullCast(DateType, BooleanType) + checkNullCast(TimestampType, BooleanType) + numericTypes.foreach(dt => checkNullCast(dt, BooleanType)) + + checkNullCast(StringType, TimestampType) + checkNullCast(BooleanType, TimestampType) + checkNullCast(DateType, TimestampType) + numericTypes.foreach(dt => checkNullCast(dt, TimestampType)) + + atomicTypes.foreach(dt => checkNullCast(dt, DateType)) + + checkNullCast(StringType, CalendarIntervalType) + numericTypes.foreach(dt => checkNullCast(StringType, dt)) + numericTypes.foreach(dt => checkNullCast(BooleanType, dt)) + numericTypes.foreach(dt => checkNullCast(DateType, dt)) + numericTypes.foreach(dt => checkNullCast(TimestampType, dt)) + for (from <- numericTypes; to <- numericTypes) checkNullCast(from, 
to) + } + test("cast string to date") { var c = Calendar.getInstance() c.set(2015, 0, 1, 0, 0, 0) @@ -69,8 +105,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { } test("cast string to timestamp") { - checkEvaluation(Cast(Literal("123"), TimestampType), - null) + checkEvaluation(Cast(Literal("123"), TimestampType), null) var c = Calendar.getInstance() c.set(2015, 0, 1, 0, 0, 0) @@ -473,6 +508,8 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { val array_notNull = Literal.create(Seq("123", "abc", ""), ArrayType(StringType, containsNull = false)) + checkNullCast(ArrayType(StringType), ArrayType(IntegerType)) + { val ret = cast(array, ArrayType(IntegerType, containsNull = true)) assert(ret.resolved === true) @@ -526,6 +563,8 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { Map("a" -> "123", "b" -> "abc", "c" -> ""), MapType(StringType, StringType, valueContainsNull = false)) + checkNullCast(MapType(StringType, IntegerType), MapType(StringType, StringType)) + { val ret = cast(map, MapType(StringType, IntegerType, valueContainsNull = true)) assert(ret.resolved === true) @@ -580,6 +619,14 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { } test("cast from struct") { + checkNullCast( + StructType(Seq( + StructField("a", StringType), + StructField("b", IntegerType))), + StructType(Seq( + StructField("a", StringType), + StructField("b", StringType)))) + val struct = Literal.create( InternalRow( UTF8String.fromString("123"), @@ -728,5 +775,4 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { StringType), "interval 1 years 3 months -3 days") } - } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index fc842772f3480..5de5ddce975d8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -132,6 +132,7 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(CreateArray(intWithNull), intSeq :+ null, EmptyRow) checkEvaluation(CreateArray(longWithNull), longSeq :+ null, EmptyRow) checkEvaluation(CreateArray(strWithNull), strSeq :+ null, EmptyRow) + checkEvaluation(CreateArray(Literal.create(null, IntegerType) :: Nil), null :: Nil) } test("CreateStruct") { @@ -139,26 +140,20 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { val c1 = 'a.int.at(0) val c3 = 'c.int.at(2) checkEvaluation(CreateStruct(Seq(c1, c3)), create_row(1, 3), row) + checkEvaluation(CreateStruct(Literal.create(null, LongType) :: Nil), create_row(null)) } test("CreateNamedStruct") { - val row = InternalRow(1, 2, 3) + val row = create_row(1, 2, 3) val c1 = 'a.int.at(0) val c3 = 'c.int.at(2) - checkEvaluation(CreateNamedStruct(Seq("a", c1, "b", c3)), InternalRow(1, 3), row) - } - - test("CreateNamedStruct with literal field") { - val row = InternalRow(1, 2, 3) - val c1 = 'a.int.at(0) + checkEvaluation(CreateNamedStruct(Seq("a", c1, "b", c3)), create_row(1, 3), row) checkEvaluation(CreateNamedStruct(Seq("a", c1, "b", "y")), - InternalRow(1, UTF8String.fromString("y")), row) - } - - test("CreateNamedStruct from all literal fields") { - checkEvaluation( - CreateNamedStruct(Seq("a", "x", "b", 2.0)), - InternalRow(UTF8String.fromString("x"), 2.0), InternalRow.empty) + create_row(1, 
UTF8String.fromString("y")), row) + checkEvaluation(CreateNamedStruct(Seq("a", "x", "b", 2.0)), + create_row(UTF8String.fromString("x"), 2.0)) + checkEvaluation(CreateNamedStruct(Seq("a", Literal.create(null, IntegerType))), + create_row(null)) } test("test dsl for complex type") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala index b31d6661c8c1c..d26bcdb2902ab 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala @@ -149,6 +149,8 @@ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Least(Seq(c1, c2, Literal(-1))), -1, row) checkEvaluation(Least(Seq(c4, c5, c3, c3, Literal("a"))), "a", row) + val nullLiteral = Literal.create(null, IntegerType) + checkEvaluation(Least(Seq(nullLiteral, nullLiteral)), null) checkEvaluation(Least(Seq(Literal(null), Literal(null))), null, InternalRow.empty) checkEvaluation(Least(Seq(Literal(-1.0), Literal(2.5))), -1.0, InternalRow.empty) checkEvaluation(Least(Seq(Literal(-1), Literal(2))), -1, InternalRow.empty) @@ -188,6 +190,8 @@ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Greatest(Seq(c1, c2, Literal(2))), 2, row) checkEvaluation(Greatest(Seq(c4, c5, c3, Literal("ccc"))), "ccc", row) + val nullLiteral = Literal.create(null, IntegerType) + checkEvaluation(Greatest(Seq(nullLiteral, nullLiteral)), null) checkEvaluation(Greatest(Seq(Literal(null), Literal(null))), null, InternalRow.empty) checkEvaluation(Greatest(Seq(Literal(-1.0), Literal(2.5))), 2.5, InternalRow.empty) checkEvaluation(Greatest(Seq(Literal(-1), Literal(2))), 2, InternalRow.empty) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala index 21459a7c69838..9fcb548af6bbb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala @@ -110,35 +110,17 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(c(Literal(1.0), Literal.create(null, DoubleType)), null, create_row(null)) } - test("conv") { - checkEvaluation(Conv(Literal("3"), Literal(10), Literal(2)), "11") - checkEvaluation(Conv(Literal("-15"), Literal(10), Literal(-16)), "-F") - checkEvaluation(Conv(Literal("-15"), Literal(10), Literal(16)), "FFFFFFFFFFFFFFF1") - checkEvaluation(Conv(Literal("big"), Literal(36), Literal(16)), "3A48") - checkEvaluation(Conv(Literal.create(null, StringType), Literal(36), Literal(16)), null) - checkEvaluation(Conv(Literal("3"), Literal.create(null, IntegerType), Literal(16)), null) - checkEvaluation( - Conv(Literal("1234"), Literal(10), Literal(37)), null) - checkEvaluation( - Conv(Literal(""), Literal(10), Literal(16)), null) - checkEvaluation( - Conv(Literal("9223372036854775807"), Literal(36), Literal(16)), "FFFFFFFFFFFFFFFF") - // If there is an invalid digit in the number, the longest valid prefix should be converted. 
- checkEvaluation( - Conv(Literal("11abc"), Literal(10), Literal(16)), "B") - } - private def checkNaN( - expression: Expression, inputRow: InternalRow = EmptyRow): Unit = { + expression: Expression, inputRow: InternalRow = EmptyRow): Unit = { checkNaNWithoutCodegen(expression, inputRow) checkNaNWithGeneratedProjection(expression, inputRow) checkNaNWithOptimization(expression, inputRow) } private def checkNaNWithoutCodegen( - expression: Expression, - expected: Any, - inputRow: InternalRow = EmptyRow): Unit = { + expression: Expression, + expected: Any, + inputRow: InternalRow = EmptyRow): Unit = { val actual = try evaluate(expression, inputRow) catch { case e: Exception => fail(s"Exception evaluating $expression", e) } @@ -149,7 +131,6 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { } } - private def checkNaNWithGeneratedProjection( expression: Expression, inputRow: InternalRow = EmptyRow): Unit = { @@ -172,6 +153,25 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkNaNWithoutCodegen(optimizedPlan.expressions.head, inputRow) } + test("conv") { + checkEvaluation(Conv(Literal("3"), Literal(10), Literal(2)), "11") + checkEvaluation(Conv(Literal("-15"), Literal(10), Literal(-16)), "-F") + checkEvaluation(Conv(Literal("-15"), Literal(10), Literal(16)), "FFFFFFFFFFFFFFF1") + checkEvaluation(Conv(Literal("big"), Literal(36), Literal(16)), "3A48") + checkEvaluation(Conv(Literal.create(null, StringType), Literal(36), Literal(16)), null) + checkEvaluation(Conv(Literal("3"), Literal.create(null, IntegerType), Literal(16)), null) + checkEvaluation(Conv(Literal("3"), Literal(16), Literal.create(null, IntegerType)), null) + checkEvaluation( + Conv(Literal("1234"), Literal(10), Literal(37)), null) + checkEvaluation( + Conv(Literal(""), Literal(10), Literal(16)), null) + checkEvaluation( + Conv(Literal("9223372036854775807"), Literal(36), Literal(16)), "FFFFFFFFFFFFFFFF") + // If there is an invalid digit in the number, the longest valid prefix should be converted. 
+ checkEvaluation( + Conv(Literal("11abc"), Literal(10), Literal(16)), "B") + } + test("e") { testLeaf(EulerNumber, math.E) } @@ -417,7 +417,7 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("round") { - val domain = -6 to 6 + val scales = -6 to 6 val doublePi: Double = math.Pi val shortPi: Short = 31415 val intPi: Int = 314159265 @@ -437,17 +437,16 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { 31415926535900000L, 31415926535898000L, 31415926535897900L, 31415926535897930L) ++ Seq.fill(7)(31415926535897932L) - val bdResults: Seq[BigDecimal] = Seq(BigDecimal(3.0), BigDecimal(3.1), BigDecimal(3.14), - BigDecimal(3.142), BigDecimal(3.1416), BigDecimal(3.14159), - BigDecimal(3.141593), BigDecimal(3.1415927)) - - domain.zipWithIndex.foreach { case (scale, i) => + scales.zipWithIndex.foreach { case (scale, i) => checkEvaluation(Round(doublePi, scale), doubleResults(i), EmptyRow) checkEvaluation(Round(shortPi, scale), shortResults(i), EmptyRow) checkEvaluation(Round(intPi, scale), intResults(i), EmptyRow) checkEvaluation(Round(longPi, scale), longResults(i), EmptyRow) } + val bdResults: Seq[BigDecimal] = Seq(BigDecimal(3.0), BigDecimal(3.1), BigDecimal(3.14), + BigDecimal(3.142), BigDecimal(3.1416), BigDecimal(3.14159), + BigDecimal(3.141593), BigDecimal(3.1415927)) // round_scale > current_scale would result in precision increase // and not allowed by o.a.s.s.types.Decimal.changePrecision, therefore null (0 to 7).foreach { i => @@ -456,5 +455,11 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { (8 to 10).foreach { scale => checkEvaluation(Round(bdPi, scale), null, EmptyRow) } + + DataTypeTestUtils.numericTypes.foreach { dataType => + checkEvaluation(Round(Literal.create(null, dataType), Literal(2)), null) + checkEvaluation(Round(Literal.create(null, dataType), + Literal.create(null, IntegerType)), null) + } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala index 5db992654811a..4a644d136f09c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala @@ -21,7 +21,6 @@ import org.scalatest.Matchers._ import org.apache.spark.SparkFunSuite - class RandomSuite extends SparkFunSuite with ExpressionEvalHelper { test("random") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index 3d294fda5d103..07b952531ec2e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -348,6 +348,9 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(StringTrimLeft(s), "花花世界 ", create_row(" 花花世界 ")) checkEvaluation(StringTrim(s), "花花世界", create_row(" 花花世界 ")) // scalastyle:on + checkEvaluation(StringTrim(Literal.create(null, StringType)), null) + checkEvaluation(StringTrimLeft(Literal.create(null, StringType)), null) + checkEvaluation(StringTrimRight(Literal.create(null, StringType)), null) } test("FORMAT") { @@ -391,6 +394,9 @@ class StringExpressionsSuite extends SparkFunSuite with 
ExpressionEvalHelper { val s3 = 'c.string.at(2) val s4 = 'd.int.at(3) val row1 = create_row("aaads", "aa", "zz", 1) + val row2 = create_row(null, "aa", "zz", 0) + val row3 = create_row("aaads", null, "zz", 0) + val row4 = create_row(null, null, null, 0) checkEvaluation(new StringLocate(Literal("aa"), Literal("aaads")), 1, row1) checkEvaluation(StringLocate(Literal("aa"), Literal("aaads"), Literal(1)), 2, row1) @@ -402,6 +408,9 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(StringLocate(s2, s1, s4), 2, row1) checkEvaluation(new StringLocate(s3, s1), 0, row1) checkEvaluation(StringLocate(s3, s1, Literal.create(null, IntegerType)), 0, row1) + checkEvaluation(new StringLocate(s2, s1), null, row2) + checkEvaluation(new StringLocate(s2, s1), null, row3) + checkEvaluation(new StringLocate(s2, s1, Literal.create(null, IntegerType)), 0, row4) } test("LPAD/RPAD") { @@ -448,6 +457,7 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val row1 = create_row("abccc") checkEvaluation(StringReverse(Literal("abccc")), "cccba", row1) checkEvaluation(StringReverse(s), "cccba", row1) + checkEvaluation(StringReverse(Literal.create(null, StringType)), null, row1) } test("SPACE") { @@ -466,6 +476,9 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val row1 = create_row("100-200", "(\\d+)", "num") val row2 = create_row("100-200", "(\\d+)", "###") val row3 = create_row("100-200", "(-)", "###") + val row4 = create_row(null, "(\\d+)", "###") + val row5 = create_row("100-200", null, "###") + val row6 = create_row("100-200", "(-)", null) val s = 's.string.at(0) val p = 'p.string.at(1) @@ -475,6 +488,9 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(expr, "num-num", row1) checkEvaluation(expr, "###-###", row2) checkEvaluation(expr, "100###200", row3) + checkEvaluation(expr, null, row4) + checkEvaluation(expr, null, row5) + checkEvaluation(expr, null, row6) } test("RegexExtract") { @@ -482,6 +498,9 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val row2 = create_row("100-200", "(\\d+)-(\\d+)", 2) val row3 = create_row("100-200", "(\\d+).*", 1) val row4 = create_row("100-200", "([a-z])", 1) + val row5 = create_row(null, "([a-z])", 1) + val row6 = create_row("100-200", null, 1) + val row7 = create_row("100-200", "([a-z])", null) val s = 's.string.at(0) val p = 'p.string.at(1) @@ -492,6 +511,9 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(expr, "200", row2) checkEvaluation(expr, "100", row3) checkEvaluation(expr, "", row4) // will not match anything, empty string get + checkEvaluation(expr, null, row5) + checkEvaluation(expr, null, row6) + checkEvaluation(expr, null, row7) val expr1 = new RegExpExtract(s, p) checkEvaluation(expr1, "100", row1) @@ -501,11 +523,15 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val s1 = 'a.string.at(0) val s2 = 'b.string.at(1) val row1 = create_row("aa2bb3cc", "[1-9]+") + val row2 = create_row(null, "[1-9]+") + val row3 = create_row("aa2bb3cc", null) checkEvaluation( StringSplit(Literal("aa2bb3cc"), Literal("[1-9]+")), Seq("aa", "bb", "cc"), row1) checkEvaluation( StringSplit(s1, s2), Seq("aa", "bb", "cc"), row1) + checkEvaluation(StringSplit(s1, s2), null, row2) + checkEvaluation(StringSplit(s1, s2), null, row3) } test("length for string / binary") { diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 4261a5e7cbeb5..4e68a88e7cda6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1423,7 +1423,8 @@ object functions { def round(columnName: String): Column = round(Column(columnName), 0) /** - * Returns the value of `e` rounded to `scale` decimal places. + * Round the value of `e` to `scale` decimal places if `scale` >= 0 + * or at integral part when `scale` < 0. * * @group math_funcs * @since 1.5.0 @@ -1431,7 +1432,8 @@ object functions { def round(e: Column, scale: Int): Column = Round(e.expr, Literal(scale)) /** - * Returns the value of the given column rounded to `scale` decimal places. + * Round the value of the given column to `scale` decimal places if `scale` >= 0 + * or at integral part when `scale` < 0. * * @group math_funcs * @since 1.5.0 From 1221849f91739454b8e495889cba7498ba8beea7 Mon Sep 17 00:00:00 2001 From: Joseph Batchik Date: Wed, 29 Jul 2015 23:35:55 -0700 Subject: [PATCH 12/50] [SPARK-8005][SQL] Input file name Users can now get the file name of the partition being read in. A thread local variable is in `SQLNewHadoopRDD` and is set when the partition is computed. `SQLNewHadoopRDD` is moved to core so that the catalyst package can reach it. This supports: `df.select(inputFileName())` and `sqlContext.sql("select input_file_name() from table")` Author: Joseph Batchik Closes #7743 from JDrit/input_file_name and squashes the following commits: abb8609 [Joseph Batchik] fixed failing test and changed the default value to be an empty string d2f323d [Joseph Batchik] updates per review 102061f [Joseph Batchik] updates per review 75313f5 [Joseph Batchik] small fixes c7f7b5a [Joseph Batchik] addeding input file name to Spark SQL --- .../apache/spark/rdd}/SqlNewHadoopRDD.scala | 34 +++++++++++-- .../catalyst/analysis/FunctionRegistry.scala | 3 +- .../catalyst/expressions/InputFileName.scala | 49 +++++++++++++++++++ .../expressions/SparkPartitionID.scala | 2 + .../expressions/NondeterministicSuite.scala | 4 ++ .../org/apache/spark/sql/functions.scala | 9 ++++ .../spark/sql/parquet/ParquetRelation.scala | 3 +- .../spark/sql/ColumnExpressionSuite.scala | 17 ++++++- .../scala/org/apache/spark/sql/UDFSuite.scala | 17 ++++++- .../org/apache/spark/sql/hive/UDFSuite.scala | 6 --- 10 files changed, 128 insertions(+), 16 deletions(-) rename {sql/core/src/main/scala/org/apache/spark/sql/execution => core/src/main/scala/org/apache/spark/rdd}/SqlNewHadoopRDD.scala (91%) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SqlNewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala similarity index 91% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/SqlNewHadoopRDD.scala rename to core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala index 3d75b6a91def6..35e44cb59c1be 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SqlNewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala @@ -15,12 +15,13 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.execution +package org.apache.spark.rdd import java.text.SimpleDateFormat import java.util.Date -import org.apache.spark.{Partition => SparkPartition, _} +import scala.reflect.ClassTag + import org.apache.hadoop.conf.{Configurable, Configuration} import org.apache.hadoop.io.Writable import org.apache.hadoop.mapreduce._ @@ -30,12 +31,12 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.executor.DataReadMethod import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil +import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.{Partition => SparkPartition, _} import org.apache.spark.rdd.NewHadoopRDD.NewHadoopMapPartitionsWithSplitRDD -import org.apache.spark.rdd.{HadoopRDD, RDD} import org.apache.spark.storage.StorageLevel import org.apache.spark.util.{SerializableConfiguration, Utils} -import scala.reflect.ClassTag private[spark] class SqlNewHadoopPartition( rddId: Int, @@ -62,7 +63,7 @@ private[spark] class SqlNewHadoopPartition( * changes based on [[org.apache.spark.rdd.HadoopRDD]]. In future, this functionality will be * folded into core. */ -private[sql] class SqlNewHadoopRDD[K, V]( +private[spark] class SqlNewHadoopRDD[K, V]( @transient sc : SparkContext, broadcastedConf: Broadcast[SerializableConfiguration], @transient initDriverSideJobFuncOpt: Option[Job => Unit], @@ -128,6 +129,12 @@ private[sql] class SqlNewHadoopRDD[K, V]( val inputMetrics = context.taskMetrics .getInputMetricsForReadMethod(DataReadMethod.Hadoop) + // Sets the thread local variable for the file's name + split.serializableHadoopSplit.value match { + case fs: FileSplit => SqlNewHadoopRDD.setInputFileName(fs.getPath.toString) + case _ => SqlNewHadoopRDD.unsetInputFileName() + } + // Find a function that will return the FileSystem bytes read by this thread. Do this before // creating RecordReader, because RecordReader's constructor might read some bytes val bytesReadCallback = inputMetrics.bytesReadCallback.orElse { @@ -188,6 +195,8 @@ private[sql] class SqlNewHadoopRDD[K, V]( reader.close() reader = null + SqlNewHadoopRDD.unsetInputFileName() + if (bytesReadCallback.isDefined) { inputMetrics.updateBytesRead() } else if (split.serializableHadoopSplit.value.isInstanceOf[FileSplit] || @@ -250,6 +259,21 @@ private[sql] class SqlNewHadoopRDD[K, V]( } private[spark] object SqlNewHadoopRDD { + + /** + * The thread variable for the name of the current file being read. This is used by + * the InputFileName function in Spark SQL. + */ + private[this] val inputFileName: ThreadLocal[UTF8String] = new ThreadLocal[UTF8String] { + override protected def initialValue(): UTF8String = UTF8String.fromString("") + } + + def getInputFileName(): UTF8String = inputFileName.get() + + private[spark] def setInputFileName(file: String) = inputFileName.set(UTF8String.fromString(file)) + + private[spark] def unsetInputFileName(): Unit = inputFileName.remove() + /** * Analogous to [[org.apache.spark.rdd.MapPartitionsRDD]], but passes in an InputSplit to * the given function rather than the index of the partition. 
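As a quick illustration of the feature this patch adds (usage is described in the commit message above), here is a minimal sketch of both entry points; the Parquet path and temp-table name are made up for the example, and a `sqlContext` is assumed to be in scope:

{{{
import org.apache.spark.sql.functions.inputFileName

// DataFrame API: returns the file backing the partition each row was read from.
val df = sqlContext.read.parquet("/tmp/events")            // illustrative path
df.select(inputFileName()).show()

// SQL API: the same expression is registered as input_file_name().
df.registerTempTable("events")                             // illustrative table name
sqlContext.sql("select input_file_name() from events").show()
}}}

Rows that do not come from a Hadoop read fall back to the empty-string default of the thread-local shown above.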
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 372f80d4a8b16..378df4f57d9e2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -230,7 +230,8 @@ object FunctionRegistry { expression[Sha1]("sha"), expression[Sha1]("sha1"), expression[Sha2]("sha2"), - expression[SparkPartitionID]("spark_partition_id") + expression[SparkPartitionID]("spark_partition_id"), + expression[InputFileName]("input_file_name") ) val builtin: FunctionRegistry = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala new file mode 100644 index 0000000000000..1e74f716955e3 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.rdd.SqlNewHadoopRDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} +import org.apache.spark.sql.types.{DataType, StringType} +import org.apache.spark.unsafe.types.UTF8String + +/** + * Expression that returns the name of the current file being read in using [[SqlNewHadoopRDD]] + */ +case class InputFileName() extends LeafExpression with Nondeterministic { + + override def nullable: Boolean = true + + override def dataType: DataType = StringType + + override val prettyName = "INPUT_FILE_NAME" + + override protected def initInternal(): Unit = {} + + override protected def evalInternal(input: InternalRow): UTF8String = { + SqlNewHadoopRDD.getInputFileName() + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + ev.isNull = "false" + s"final ${ctx.javaType(dataType)} ${ev.primitive} = " + + "org.apache.spark.rdd.SqlNewHadoopRDD.getInputFileName();" + } + +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala index 3f6480bbf0114..4b1772a2deed5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala @@ -34,6 +34,8 @@ private[sql] case class SparkPartitionID() extends LeafExpression with Nondeterm @transient private[this] var partitionId: Int = _ + override val prettyName = "SPARK_PARTITION_ID" + override protected def initInternal(): Unit = { partitionId = TaskContext.getPartitionId() } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NondeterministicSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NondeterministicSuite.scala index 82894822ab0f4..bf1c930c0bd0b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NondeterministicSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NondeterministicSuite.scala @@ -27,4 +27,8 @@ class NondeterministicSuite extends SparkFunSuite with ExpressionEvalHelper { test("SparkPartitionID") { checkEvaluation(SparkPartitionID(), 0) } + + test("InputFileName") { + checkEvaluation(InputFileName(), "") + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 4e68a88e7cda6..a2fece62f61f9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -743,6 +743,15 @@ object functions { */ def sparkPartitionId(): Column = SparkPartitionID() + /** + * The file name of the current Spark task + * + * Note that this is non-deterministic because it depends on what is currently being read in. + * + * @group normal_funcs + */ + def inputFileName(): Column = InputFileName() + /** * Computes the square root of the specified float value.
* diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala index cc6fa2b88663f..1a8176d8a80ab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala @@ -39,11 +39,10 @@ import org.apache.parquet.{Log => ParquetLog} import org.apache.spark.{Logging, Partition => SparkPartition, SparkException} import org.apache.spark.broadcast.Broadcast -import org.apache.spark.rdd.RDD +import org.apache.spark.rdd.{SqlNewHadoopPartition, SqlNewHadoopRDD, RDD} import org.apache.spark.rdd.RDD._ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.execution.{SqlNewHadoopPartition, SqlNewHadoopRDD} import org.apache.spark.sql.execution.datasources.PartitionSpec import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.{DataType, StructType} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 1f9f7118c3f04..5c1102410879a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -22,13 +22,16 @@ import org.scalatest.Matchers._ import org.apache.spark.sql.execution.Project import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ +import org.apache.spark.sql.test.SQLTestUtils -class ColumnExpressionSuite extends QueryTest { +class ColumnExpressionSuite extends QueryTest with SQLTestUtils { import org.apache.spark.sql.TestData._ private lazy val ctx = org.apache.spark.sql.test.TestSQLContext import ctx.implicits._ + override def sqlContext(): SQLContext = ctx + test("alias") { val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList") assert(df.select(df("a").as("b")).columns.head === "b") @@ -489,6 +492,18 @@ class ColumnExpressionSuite extends QueryTest { ) } + test("InputFileName") { + withTempPath { dir => + val data = sqlContext.sparkContext.parallelize(0 to 10).toDF("id") + data.write.parquet(dir.getCanonicalPath) + val answer = sqlContext.read.parquet(dir.getCanonicalPath).select(inputFileName()) + .head.getString(0) + assert(answer.contains(dir.getCanonicalPath)) + + checkAnswer(data.select(inputFileName()).limit(1), Row("")) + } + } + test("lift alias out of cast") { compareExpressions( col("1234").as("name").cast("int").expr, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index d9c8b380ef146..183dc3407b3ab 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -17,14 +17,17 @@ package org.apache.spark.sql +import org.apache.spark.sql.test.SQLTestUtils case class FunctionResult(f1: String, f2: String) -class UDFSuite extends QueryTest { +class UDFSuite extends QueryTest with SQLTestUtils { private lazy val ctx = org.apache.spark.sql.test.TestSQLContext import ctx.implicits._ + override def sqlContext(): SQLContext = ctx + test("built-in fixed arity expressions") { val df = ctx.emptyDataFrame df.selectExpr("rand()", "randn()", "rand(5)", "randn(50)") @@ -58,6 +61,18 @@ class UDFSuite extends QueryTest { ctx.dropTempTable("tmp_table") } + test("SPARK-8005 input_file_name") { + withTempPath { dir => + val data = 
ctx.sparkContext.parallelize(0 to 10, 2).toDF("id") + data.write.parquet(dir.getCanonicalPath) + ctx.read.parquet(dir.getCanonicalPath).registerTempTable("test_table") + val answer = ctx.sql("select input_file_name() from test_table").head().getString(0) + assert(answer.contains(dir.getCanonicalPath)) + assert(ctx.sql("select input_file_name() from test_table").distinct().collect().length >= 2) + ctx.dropTempTable("test_table") + } + } + test("error reporting for incorrect number of arguments") { val df = ctx.emptyDataFrame val e = intercept[AnalysisException] { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala index 37afc2142abf7..9b3ede43ee2d1 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala @@ -34,10 +34,4 @@ class UDFSuite extends QueryTest { assert(ctx.sql("SELECT RANDOm1() FROM src LIMIT 1").head().getDouble(0) >= 0.0) assert(ctx.sql("SELECT strlenscala('test', 1) FROM src LIMIT 1").head().getInt(0) === 5) } - - test("SPARK-8003 spark_partition_id") { - val df = Seq((1, "Two Fiiiiive")).toDF("id", "saying") - ctx.registerDataFrameAsTable(df, "test_table") - checkAnswer(ctx.sql("select spark_partition_id() from test_table LIMIT 1").toDF(), Row(0)) - } } From 76f2e393a5fad0db8b56c4b8dad5ef686bf140a4 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Thu, 30 Jul 2015 00:46:36 -0700 Subject: [PATCH 13/50] [SPARK-9335] [TESTS] Enable Kinesis tests only when files in extras/kinesis-asl are changed Author: zsxwing Closes #7711 from zsxwing/SPARK-9335-test and squashes the following commits: c13ec2f [zsxwing] environs -> environ 69c2865 [zsxwing] Merge remote-tracking branch 'origin/master' into SPARK-9335-test ef84a08 [zsxwing] Revert "Modify the Kinesis project to trigger ENABLE_KINESIS_TESTS" f691028 [zsxwing] Modify the Kinesis project to trigger ENABLE_KINESIS_TESTS 7618205 [zsxwing] Enable Kinesis tests only when files in extras/kinesis-asl are changed --- dev/run-tests.py | 16 ++++++++++++++++ dev/sparktestsupport/modules.py | 14 ++++++++++++-- 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/dev/run-tests.py b/dev/run-tests.py index 1f0d218514f92..29420da9aa956 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -85,6 +85,13 @@ def identify_changed_files_from_git_commits(patch_sha, target_branch=None, targe return [f for f in raw_output.split('\n') if f] +def setup_test_environ(environ): + print("[info] Setup the following environment variables for tests: ") + for (k, v) in environ.items(): + print("%s=%s" % (k, v)) + os.environ[k] = v + + def determine_modules_to_test(changed_modules): """ Given a set of modules that have changed, compute the transitive closure of those modules' @@ -455,6 +462,15 @@ def main(): print("[info] Found the following changed modules:", ", ".join(x.name for x in changed_modules)) + # setup environment variables + # note - the 'root' module doesn't collect environment variables for all modules. Because the + # environment variables should not be set if a module is not changed, even if running the 'root' + # module. So here we should use changed_modules rather than test_modules. 
+ test_environ = {} + for m in changed_modules: + test_environ.update(m.environ) + setup_test_environ(test_environ) + test_modules = determine_modules_to_test(changed_modules) # license checks diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 3073d489bad4a..030d982e99106 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -29,7 +29,7 @@ class Module(object): changed. """ - def __init__(self, name, dependencies, source_file_regexes, build_profile_flags=(), + def __init__(self, name, dependencies, source_file_regexes, build_profile_flags=(), environ={}, sbt_test_goals=(), python_test_goals=(), blacklisted_python_implementations=(), should_run_r_tests=False): """ @@ -43,6 +43,8 @@ def __init__(self, name, dependencies, source_file_regexes, build_profile_flags= filename strings. :param build_profile_flags: A set of profile flags that should be passed to Maven or SBT in order to build and test this module (e.g. '-PprofileName'). + :param environ: A dict of environment variables that should be set when files in this + module are changed. :param sbt_test_goals: A set of SBT test goals for testing this module. :param python_test_goals: A set of Python test goals for testing this module. :param blacklisted_python_implementations: A set of Python implementations that are not @@ -55,6 +57,7 @@ def __init__(self, name, dependencies, source_file_regexes, build_profile_flags= self.source_file_prefixes = source_file_regexes self.sbt_test_goals = sbt_test_goals self.build_profile_flags = build_profile_flags + self.environ = environ self.python_test_goals = python_test_goals self.blacklisted_python_implementations = blacklisted_python_implementations self.should_run_r_tests = should_run_r_tests @@ -126,15 +129,22 @@ def contains_file(self, filename): ) +# Don't set the dependencies because changes in other modules should not trigger Kinesis tests. +# Kinesis tests depends on external Amazon kinesis service. We should run these tests only when +# files in streaming_kinesis_asl are changed, so that if Kinesis experiences an outage, we don't +# fail other PRs. streaming_kinesis_asl = Module( name="kinesis-asl", - dependencies=[streaming], + dependencies=[], source_file_regexes=[ "extras/kinesis-asl/", ], build_profile_flags=[ "-Pkinesis-asl", ], + environ={ + "ENABLE_KINESIS_TESTS": "1" + }, sbt_test_goals=[ "kinesis-asl/test", ] From 4a8bb9d00d8181aff5f5183194d9aa2a65deacdf Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 30 Jul 2015 01:04:24 -0700 Subject: [PATCH 14/50] Revert "[SPARK-9458] Avoid object allocation in prefix generation." This reverts commit 9514d874f0cf61f1eb4ec4f5f66e053119f769c9. 
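As context for the revert below: it restores a dedicated FLOAT prefix comparator whose prefix is the raw float bit pattern widened to a long, and whose comparison decodes the prefixes back to floats and compares them NaN-safely. A small sketch mirroring the re-added test (the values here are only illustrative):

{{{
import org.apache.spark.util.collection.unsafe.sort.PrefixComparators

// Ordinary floats keep their natural order after prefix encoding.
val small = PrefixComparators.FLOAT.computePrefix(1.0f)
val large = PrefixComparators.FLOAT.computePrefix(2.0f)
assert(PrefixComparators.FLOAT.compare(small, large) < 0)

// NaN is treated as larger than any other float, including Float.MaxValue.
val nanPrefix = PrefixComparators.FLOAT.computePrefix(Float.NaN)
val maxPrefix = PrefixComparators.FLOAT.computePrefix(Float.MaxValue)
assert(PrefixComparators.FLOAT.compare(nanPrefix, maxPrefix) > 0)
}}}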
--- .../unsafe/sort/PrefixComparators.java | 16 ++++++ .../unsafe/sort/PrefixComparatorsSuite.scala | 12 +++++ .../execution/UnsafeExternalRowSorter.java | 2 +- .../spark/sql/execution/SortPrefixUtils.scala | 51 ++++++++++--------- .../spark/sql/execution/SparkStrategies.scala | 4 +- .../org/apache/spark/sql/execution/sort.scala | 5 +- .../execution/RowFormatConvertersSuite.scala | 2 +- .../execution/UnsafeExternalSortSuite.scala | 10 ++-- 8 files changed, 67 insertions(+), 35 deletions(-) diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java index a9ee6042fec74..600aff7d15d8a 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java @@ -29,6 +29,7 @@ private PrefixComparators() {} public static final StringPrefixComparator STRING = new StringPrefixComparator(); public static final IntegralPrefixComparator INTEGRAL = new IntegralPrefixComparator(); + public static final FloatPrefixComparator FLOAT = new FloatPrefixComparator(); public static final DoublePrefixComparator DOUBLE = new DoublePrefixComparator(); public static final class StringPrefixComparator extends PrefixComparator { @@ -54,6 +55,21 @@ public int compare(long a, long b) { public final long NULL_PREFIX = Long.MIN_VALUE; } + public static final class FloatPrefixComparator extends PrefixComparator { + @Override + public int compare(long aPrefix, long bPrefix) { + float a = Float.intBitsToFloat((int) aPrefix); + float b = Float.intBitsToFloat((int) bPrefix); + return Utils.nanSafeCompareFloats(a, b); + } + + public long computePrefix(float value) { + return Float.floatToIntBits(value) & 0xffffffffL; + } + + public final long NULL_PREFIX = computePrefix(Float.NEGATIVE_INFINITY); + } + public static final class DoublePrefixComparator extends PrefixComparator { @Override public int compare(long aPrefix, long bPrefix) { diff --git a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala index 26b7a9e816d1e..cf53a8ad21c60 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala @@ -55,6 +55,18 @@ class PrefixComparatorsSuite extends SparkFunSuite with PropertyChecks { forAll { (s1: String, s2: String) => testPrefixComparison(s1, s2) } } + test("float prefix comparator handles NaN properly") { + val nan1: Float = java.lang.Float.intBitsToFloat(0x7f800001) + val nan2: Float = java.lang.Float.intBitsToFloat(0x7fffffff) + assert(nan1.isNaN) + assert(nan2.isNaN) + val nan1Prefix = PrefixComparators.FLOAT.computePrefix(nan1) + val nan2Prefix = PrefixComparators.FLOAT.computePrefix(nan2) + assert(nan1Prefix === nan2Prefix) + val floatMaxPrefix = PrefixComparators.FLOAT.computePrefix(Float.MaxValue) + assert(PrefixComparators.FLOAT.compare(nan1Prefix, floatMaxPrefix) === 1) + } + test("double prefix comparator handles NaNs properly") { val nan1: Double = java.lang.Double.longBitsToDouble(0x7ff0000000000001L) val nan2: Double = java.lang.Double.longBitsToDouble(0x7fffffffffffffffL) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java 
b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java index 8342833246f7d..4c3f2c6557140 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java @@ -121,7 +121,7 @@ Iterator sort() throws IOException { // here in order to prevent memory leaks. cleanupResources(); } - return new AbstractScalaRowIterator() { + return new AbstractScalaRowIterator() { private final int numFields = schema.length(); private UnsafeRow row = new UnsafeRow(); diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala index 050d27f1460fb..2dee3542d6101 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{BoundReference, SortOrder} +import org.apache.spark.sql.catalyst.expressions.SortOrder import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.collection.unsafe.sort.{PrefixComparators, PrefixComparator} @@ -39,54 +39,57 @@ object SortPrefixUtils { sortOrder.dataType match { case StringType => PrefixComparators.STRING case BooleanType | ByteType | ShortType | IntegerType | LongType => PrefixComparators.INTEGRAL - case FloatType | DoubleType => PrefixComparators.DOUBLE + case FloatType => PrefixComparators.FLOAT + case DoubleType => PrefixComparators.DOUBLE case _ => NoOpPrefixComparator } } def getPrefixComputer(sortOrder: SortOrder): InternalRow => Long = { - val bound = sortOrder.child.asInstanceOf[BoundReference] - val pos = bound.ordinal sortOrder.dataType match { - case StringType => - (row: InternalRow) => { - PrefixComparators.STRING.computePrefix(row.getUTF8String(pos)) - } + case StringType => (row: InternalRow) => { + PrefixComparators.STRING.computePrefix(sortOrder.child.eval(row).asInstanceOf[UTF8String]) + } case BooleanType => (row: InternalRow) => { - if (row.isNullAt(pos)) PrefixComparators.INTEGRAL.NULL_PREFIX - else if (row.getBoolean(pos)) 1 + val exprVal = sortOrder.child.eval(row) + if (exprVal == null) PrefixComparators.INTEGRAL.NULL_PREFIX + else if (sortOrder.child.eval(row).asInstanceOf[Boolean]) 1 else 0 } case ByteType => (row: InternalRow) => { - if (row.isNullAt(pos)) PrefixComparators.INTEGRAL.NULL_PREFIX else row.getByte(pos) + val exprVal = sortOrder.child.eval(row) + if (exprVal == null) PrefixComparators.INTEGRAL.NULL_PREFIX + else sortOrder.child.eval(row).asInstanceOf[Byte] } case ShortType => (row: InternalRow) => { - if (row.isNullAt(pos)) PrefixComparators.INTEGRAL.NULL_PREFIX else row.getShort(pos) + val exprVal = sortOrder.child.eval(row) + if (exprVal == null) PrefixComparators.INTEGRAL.NULL_PREFIX + else sortOrder.child.eval(row).asInstanceOf[Short] } case IntegerType => (row: InternalRow) => { - if (row.isNullAt(pos)) PrefixComparators.INTEGRAL.NULL_PREFIX else row.getInt(pos) + val exprVal = sortOrder.child.eval(row) + if (exprVal == null) PrefixComparators.INTEGRAL.NULL_PREFIX + else sortOrder.child.eval(row).asInstanceOf[Int] } case LongType => (row: InternalRow) => { - if (row.isNullAt(pos)) PrefixComparators.INTEGRAL.NULL_PREFIX else row.getLong(pos) + val exprVal 
= sortOrder.child.eval(row) + if (exprVal == null) PrefixComparators.INTEGRAL.NULL_PREFIX + else sortOrder.child.eval(row).asInstanceOf[Long] } case FloatType => (row: InternalRow) => { - if (row.isNullAt(pos)) { - PrefixComparators.DOUBLE.NULL_PREFIX - } else { - PrefixComparators.DOUBLE.computePrefix(row.getFloat(pos).toDouble) - } + val exprVal = sortOrder.child.eval(row) + if (exprVal == null) PrefixComparators.FLOAT.NULL_PREFIX + else PrefixComparators.FLOAT.computePrefix(sortOrder.child.eval(row).asInstanceOf[Float]) } case DoubleType => (row: InternalRow) => { - if (row.isNullAt(pos)) { - PrefixComparators.DOUBLE.NULL_PREFIX - } else { - PrefixComparators.DOUBLE.computePrefix(row.getDouble(pos)) - } + val exprVal = sortOrder.child.eval(row) + if (exprVal == null) PrefixComparators.DOUBLE.NULL_PREFIX + else PrefixComparators.DOUBLE.computePrefix(sortOrder.child.eval(row).asInstanceOf[Double]) } case _ => (row: InternalRow) => 0L } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 4ab2c41f1b339..f3ef066528ff8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -340,8 +340,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { */ def getSortOperator(sortExprs: Seq[SortOrder], global: Boolean, child: SparkPlan): SparkPlan = { if (sqlContext.conf.unsafeEnabled && sqlContext.conf.codegenEnabled && - TungstenSort.supportsSchema(child.schema)) { - execution.TungstenSort(sortExprs, global, child) + UnsafeExternalSort.supportsSchema(child.schema)) { + execution.UnsafeExternalSort(sortExprs, global, child) } else if (sqlContext.conf.externalSortEnabled) { execution.ExternalSort(sortExprs, global, child) } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala index d0ad310062853..f82208868c3e3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala @@ -97,7 +97,7 @@ case class ExternalSort( * @param testSpillFrequency Method for configuring periodic spilling in unit tests. If set, will * spill every `frequency` records. */ -case class TungstenSort( +case class UnsafeExternalSort( sortOrder: Seq[SortOrder], global: Boolean, child: SparkPlan, @@ -110,6 +110,7 @@ case class TungstenSort( if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil protected override def doExecute(): RDD[InternalRow] = attachTree(this, "sort") { + assert(codegenEnabled, "UnsafeExternalSort requires code generation to be enabled") def doSort(iterator: Iterator[InternalRow]): Iterator[InternalRow] = { val ordering = newOrdering(sortOrder, child.output) val boundSortExpression = BindReferences.bindReference(sortOrder.head, child.output) @@ -148,7 +149,7 @@ case class TungstenSort( } @DeveloperApi -object TungstenSort { +object UnsafeExternalSort { /** * Return true if UnsafeExternalSort can sort rows with the given schema, false otherwise. 
*/ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala index c458f95ca1ab3..7b75f755918c1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala @@ -31,7 +31,7 @@ class RowFormatConvertersSuite extends SparkPlanTest { private val outputsSafe = ExternalSort(Nil, false, PhysicalRDD(Seq.empty, null)) assert(!outputsSafe.outputsUnsafeRows) - private val outputsUnsafe = TungstenSort(Nil, false, PhysicalRDD(Seq.empty, null)) + private val outputsUnsafe = UnsafeExternalSort(Nil, false, PhysicalRDD(Seq.empty, null)) assert(outputsUnsafe.outputsUnsafeRows) test("planner should insert unsafe->safe conversions when required") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala index 9cabc4b90bf8e..7a4baa9e4a49d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala @@ -42,7 +42,7 @@ class UnsafeExternalSortSuite extends SparkPlanTest with BeforeAndAfterAll { TestSQLContext.sparkContext.conf.set("spark.unsafe.exceptionOnMemoryLeak", "false") checkThatPlansAgree( (1 to 100).map(v => Tuple1(v)).toDF("a"), - (child: SparkPlan) => Limit(10, TungstenSort('a.asc :: Nil, true, child)), + (child: SparkPlan) => Limit(10, UnsafeExternalSort('a.asc :: Nil, true, child)), (child: SparkPlan) => Limit(10, Sort('a.asc :: Nil, global = true, child)), sortAnswers = false ) @@ -53,7 +53,7 @@ class UnsafeExternalSortSuite extends SparkPlanTest with BeforeAndAfterAll { try { checkThatPlansAgree( (1 to 100).map(v => Tuple1(v)).toDF("a"), - (child: SparkPlan) => Limit(10, TungstenSort('a.asc :: Nil, true, child)), + (child: SparkPlan) => Limit(10, UnsafeExternalSort('a.asc :: Nil, true, child)), (child: SparkPlan) => Limit(10, Sort('a.asc :: Nil, global = true, child)), sortAnswers = false ) @@ -68,7 +68,7 @@ class UnsafeExternalSortSuite extends SparkPlanTest with BeforeAndAfterAll { val stringLength = 1024 * 1024 * 2 checkThatPlansAgree( Seq(Tuple1("a" * stringLength), Tuple1("b" * stringLength)).toDF("a").repartition(1), - TungstenSort(sortOrder, global = true, _: SparkPlan, testSpillFrequency = 1), + UnsafeExternalSort(sortOrder, global = true, _: SparkPlan, testSpillFrequency = 1), Sort(sortOrder, global = true, _: SparkPlan), sortAnswers = false ) @@ -88,11 +88,11 @@ class UnsafeExternalSortSuite extends SparkPlanTest with BeforeAndAfterAll { TestSQLContext.sparkContext.parallelize(Random.shuffle(inputData).map(v => Row(v))), StructType(StructField("a", dataType, nullable = true) :: Nil) ) - assert(TungstenSort.supportsSchema(inputDf.schema)) + assert(UnsafeExternalSort.supportsSchema(inputDf.schema)) checkThatPlansAgree( inputDf, plan => ConvertToSafe( - TungstenSort(sortOrder, global = true, plan: SparkPlan, testSpillFrequency = 23)), + UnsafeExternalSort(sortOrder, global = true, plan: SparkPlan, testSpillFrequency = 23)), Sort(sortOrder, global = true, _: SparkPlan), sortAnswers = false ) From 5ba2d44068b89fd8e81cfd24f49bf20d373f81b9 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 30 Jul 2015 01:21:39 -0700 Subject: [PATCH 15/50] Fix flaky HashedRelationSuite SparkEnv might not 
have been set in local unit tests. Author: Reynold Xin Closes #7784 from rxin/HashedRelationSuite and squashes the following commits: 435d64b [Reynold Xin] Fix flaky HashedRelationSuite --- .../apache/spark/sql/execution/joins/HashedRelation.scala | 7 +++++-- .../spark/sql/execution/joins/HashedRelationSuite.scala | 6 +++--- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index 7a507391316a9..26dbc911e9521 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -21,7 +21,7 @@ import java.io.{Externalizable, ObjectInput, ObjectOutput} import java.nio.ByteOrder import java.util.{HashMap => JavaHashMap} -import org.apache.spark.{SparkEnv, TaskContext} +import org.apache.spark.{SparkConf, SparkEnv, TaskContext} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.SparkSqlSerializer @@ -260,7 +260,10 @@ private[joins] final class UnsafeHashedRelation( val nKeys = in.readInt() // This is used in Broadcast, shared by multiple tasks, so we use on-heap memory val memoryManager = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)) - val pageSizeBytes = SparkEnv.get.conf.getSizeAsBytes("spark.buffer.pageSize", "64m") + + val pageSizeBytes = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf()) + .getSizeAsBytes("spark.buffer.pageSize", "64m") + binaryMap = new BytesToBytesMap( memoryManager, nKeys * 2, // reduce hash collision diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala index 941f6d4f6a450..8b1a9b21a96b9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala @@ -33,7 +33,7 @@ class HashedRelationSuite extends SparkFunSuite { override def apply(row: InternalRow): InternalRow = row } - ignore("GeneralHashedRelation") { + test("GeneralHashedRelation") { val data = Array(InternalRow(0), InternalRow(1), InternalRow(2), InternalRow(2)) val hashed = HashedRelation(data.iterator, keyProjection) assert(hashed.isInstanceOf[GeneralHashedRelation]) @@ -47,7 +47,7 @@ class HashedRelationSuite extends SparkFunSuite { assert(hashed.get(data(2)) === data2) } - ignore("UniqueKeyHashedRelation") { + test("UniqueKeyHashedRelation") { val data = Array(InternalRow(0), InternalRow(1), InternalRow(2)) val hashed = HashedRelation(data.iterator, keyProjection) assert(hashed.isInstanceOf[UniqueKeyHashedRelation]) @@ -64,7 +64,7 @@ class HashedRelationSuite extends SparkFunSuite { assert(uniqHashed.getValue(InternalRow(10)) === null) } - ignore("UnsafeHashedRelation") { + test("UnsafeHashedRelation") { val schema = StructType(StructField("a", IntegerType, true) :: Nil) val data = Array(InternalRow(0), InternalRow(1), InternalRow(2), InternalRow(2)) val toUnsafe = UnsafeProjection.create(schema) From 6175d6cfe795fbd88e3ee713fac375038a3993a8 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 30 Jul 2015 17:45:30 +0800 Subject: [PATCH 16/50] [SPARK-8838] [SQL] Add config to enable/disable merging part-files when merging parquet 
schema JIRA: https://issues.apache.org/jira/browse/SPARK-8838 Currently all part-files are merged when merging parquet schema. However, in case there are many part-files and we can make sure that all the part-files have the same schema as their summary file. If so, we provide a configuration to disable merging part-files when merging parquet schema. In short, we need to merge parquet schema because different summary files may contain different schema. But the part-files are confirmed to have the same schema with summary files. Author: Liang-Chi Hsieh Closes #7238 from viirya/option_partfile_merge and squashes the following commits: 71d5b5f [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into option_partfile_merge 8816f44 [Liang-Chi Hsieh] For comments. dbc8e6b [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into option_partfile_merge afc2fa1 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into option_partfile_merge d4ed7e6 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into option_partfile_merge df43027 [Liang-Chi Hsieh] Get dataStatuses' partitions based on all paths. 4eb2f00 [Liang-Chi Hsieh] Use given parameter. ea8f6e5 [Liang-Chi Hsieh] Correct the code comments. a57be0e [Liang-Chi Hsieh] Merge part-files if there are no summary files. 47df981 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into option_partfile_merge 4caf293 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into option_partfile_merge 0e734e0 [Liang-Chi Hsieh] Use correct API. 3b6be5b [Liang-Chi Hsieh] Fix key not found. 4bdd7e0 [Liang-Chi Hsieh] Don't read footer files if we can skip them. 8bbebcb [Liang-Chi Hsieh] Figure out how to test the config. bbd4ce7 [Liang-Chi Hsieh] Add config to enable/disable merging part-files when merging parquet schema. --- .../scala/org/apache/spark/sql/SQLConf.scala | 7 +++++ .../spark/sql/parquet/ParquetRelation.scala | 19 ++++++++++++- .../spark/sql/parquet/ParquetQuerySuite.scala | 27 +++++++++++++++++++ 3 files changed, 52 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index cdb0c7a1c07a7..2564bbd2077bf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -247,6 +247,13 @@ private[spark] object SQLConf { "otherwise the schema is picked from the summary file or a random data file " + "if no summary file is available.") + val PARQUET_SCHEMA_RESPECT_SUMMARIES = booleanConf("spark.sql.parquet.respectSummaryFiles", + defaultValue = Some(false), + doc = "When true, we make assumption that all part-files of Parquet are consistent with " + + "summary files and we will ignore them when merging schema. Otherwise, if this is " + + "false, which is the default, we will merge all part-files. 
This should be considered " + + "as expert-only option, and shouldn't be enabled before knowing what it means exactly.") + val PARQUET_BINARY_AS_STRING = booleanConf("spark.sql.parquet.binaryAsString", defaultValue = Some(false), doc = "Some other Parquet-producing systems, in particular Impala and older versions of " + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala index 1a8176d8a80ab..b4337a48dbd80 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala @@ -124,6 +124,9 @@ private[sql] class ParquetRelation( .map(_.toBoolean) .getOrElse(sqlContext.conf.getConf(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED)) + private val mergeRespectSummaries = + sqlContext.conf.getConf(SQLConf.PARQUET_SCHEMA_RESPECT_SUMMARIES) + private val maybeMetastoreSchema = parameters .get(ParquetRelation.METASTORE_SCHEMA) .map(DataType.fromJson(_).asInstanceOf[StructType]) @@ -421,7 +424,21 @@ private[sql] class ParquetRelation( val filesToTouch = if (shouldMergeSchemas) { // Also includes summary files, 'cause there might be empty partition directories. - (metadataStatuses ++ commonMetadataStatuses ++ dataStatuses).toSeq + + // If mergeRespectSummaries config is true, we assume that all part-files are the same for + // their schema with summary files, so we ignore them when merging schema. + // If the config is disabled, which is the default setting, we merge all part-files. + // In this mode, we only need to merge schemas contained in all those summary files. + // You should enable this configuration only if you are very sure that for the parquet + // part-files to read there are corresponding summary files containing correct schema. + + val needMerged: Seq[FileStatus] = + if (mergeRespectSummaries) { + Seq() + } else { + dataStatuses + } + (metadataStatuses ++ commonMetadataStatuses ++ needMerged).toSeq } else { // Tries any "_common_metadata" first. Parquet files written by old versions or Parquet // don't have this. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala index c037faf4cfd92..a95f70f2bba69 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala @@ -17,10 +17,13 @@ package org.apache.spark.sql.parquet +import java.io.File + import org.apache.hadoop.fs.Path import org.apache.spark.sql.types._ import org.apache.spark.sql.{QueryTest, Row, SQLConf} +import org.apache.spark.util.Utils /** * A test suite that tests various Parquet queries. @@ -123,6 +126,30 @@ class ParquetQuerySuite extends QueryTest with ParquetTest { } } + test("Enabling/disabling merging partfiles when merging parquet schema") { + def testSchemaMerging(expectedColumnNumber: Int): Unit = { + withTempDir { dir => + val basePath = dir.getCanonicalPath + sqlContext.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString) + sqlContext.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=2").toString) + // delete summary files, so if we don't merge part-files, one column will not be included. 
+ Utils.deleteRecursively(new File(basePath + "/foo=1/_metadata")) + Utils.deleteRecursively(new File(basePath + "/foo=1/_common_metadata")) + assert(sqlContext.read.parquet(basePath).columns.length === expectedColumnNumber) + } + } + + withSQLConf(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED.key -> "true", + SQLConf.PARQUET_SCHEMA_RESPECT_SUMMARIES.key -> "true") { + testSchemaMerging(2) + } + + withSQLConf(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED.key -> "true", + SQLConf.PARQUET_SCHEMA_RESPECT_SUMMARIES.key -> "false") { + testSchemaMerging(3) + } + } + test("Enabling/disabling schema merging") { def testSchemaMerging(expectedColumnNumber: Int): Unit = { withTempDir { dir => From d31c618e3c8838f8198556876b9dcbbbf835f7b2 Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Thu, 30 Jul 2015 07:49:10 -0700 Subject: [PATCH 17/50] [SPARK-7368] [MLLIB] Add QR decomposition for RowMatrix jira: https://issues.apache.org/jira/browse/SPARK-7368 Add QR decomposition for RowMatrix. I'm not sure what's the blueprint about the distributed Matrix from community and whether this will be a desirable feature , so I sent a prototype for discussion. I'll go on polish the code and provide ut and performance statistics if it's acceptable. The implementation refers to the [paper: https://www.cs.purdue.edu/homes/dgleich/publications/Benson%202013%20-%20direct-tsqr.pdf] Austin R. Benson, David F. Gleich, James Demmel. "Direct QR factorizations for tall-and-skinny matrices in MapReduce architectures", 2013 IEEE International Conference on Big Data, which is a stable algorithm with good scalability. Currently I tried it on a 400000 * 500 rowMatrix (16 partitions) and it can bring down the computation time from 8.8 mins (using breeze.linalg.qr.reduced) to 2.6 mins on a 4 worker cluster. I think there will still be some room for performance improvement. Any trial and suggestion is welcome. Author: Yuhao Yang Closes #5909 from hhbyyh/qrDecomposition and squashes the following commits: cec797b [Yuhao Yang] remove unnecessary qr 0fb1012 [Yuhao Yang] hierarchy R computing 3fbdb61 [Yuhao Yang] update qr to indirect and add ut 0d913d3 [Yuhao Yang] Merge remote-tracking branch 'upstream/master' into qrDecomposition 39213c3 [Yuhao Yang] Merge remote-tracking branch 'upstream/master' into qrDecomposition c0fc0c7 [Yuhao Yang] Merge remote-tracking branch 'upstream/master' into qrDecomposition 39b0b22 [Yuhao Yang] initial draft for discussion --- .../linalg/SingularValueDecomposition.scala | 8 ++++ .../mllib/linalg/distributed/RowMatrix.scala | 46 ++++++++++++++++++- .../linalg/distributed/RowMatrixSuite.scala | 17 +++++++ 3 files changed, 70 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/SingularValueDecomposition.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/SingularValueDecomposition.scala index 9669c364bad8f..b416d50a5631e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/SingularValueDecomposition.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/SingularValueDecomposition.scala @@ -25,3 +25,11 @@ import org.apache.spark.annotation.Experimental */ @Experimental case class SingularValueDecomposition[UType, VType](U: UType, s: Vector, V: VType) + +/** + * :: Experimental :: + * Represents QR factors. 
+ */ +@Experimental +case class QRDecomposition[UType, VType](Q: UType, R: VType) + diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala index 1626da9c3d2ee..bfc90c9ef8527 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala @@ -22,7 +22,7 @@ import java.util.Arrays import scala.collection.mutable.ListBuffer import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, SparseVector => BSV, axpy => brzAxpy, - svd => brzSvd} + svd => brzSvd, MatrixSingularException, inv} import breeze.numerics.{sqrt => brzSqrt} import com.github.fommil.netlib.BLAS.{getInstance => blas} @@ -497,6 +497,50 @@ class RowMatrix( columnSimilaritiesDIMSUM(computeColumnSummaryStatistics().normL2.toArray, gamma) } + /** + * Compute QR decomposition for [[RowMatrix]]. The implementation is designed to optimize the QR + * decomposition (factorization) for the [[RowMatrix]] of a tall and skinny shape. + * Reference: + * Paul G. Constantine, David F. Gleich. "Tall and skinny QR factorizations in MapReduce + * architectures" ([[http://dx.doi.org/10.1145/1996092.1996103]]) + * + * @param computeQ whether to computeQ + * @return QRDecomposition(Q, R), Q = null if computeQ = false. + */ + def tallSkinnyQR(computeQ: Boolean = false): QRDecomposition[RowMatrix, Matrix] = { + val col = numCols().toInt + // split rows horizontally into smaller matrices, and compute QR for each of them + val blockQRs = rows.glom().map { partRows => + val bdm = BDM.zeros[Double](partRows.length, col) + var i = 0 + partRows.foreach { row => + bdm(i, ::) := row.toBreeze.t + i += 1 + } + breeze.linalg.qr.reduced(bdm).r + } + + // combine the R part from previous results vertically into a tall matrix + val combinedR = blockQRs.treeReduce{ (r1, r2) => + val stackedR = BDM.vertcat(r1, r2) + breeze.linalg.qr.reduced(stackedR).r + } + val finalR = Matrices.fromBreeze(combinedR.toDenseMatrix) + val finalQ = if (computeQ) { + try { + val invR = inv(combinedR) + this.multiply(Matrices.fromBreeze(invR)) + } catch { + case err: MatrixSingularException => + logWarning("R is not invertible and return Q as null") + null + } + } else { + null + } + QRDecomposition(finalQ, finalR) + } + /** * Find all similar columns using the DIMSUM sampling algorithm, described in two papers * diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala index b6cb53d0c743e..283ffec1d49d7 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.mllib.linalg.distributed import scala.util.Random +import breeze.numerics.abs import breeze.linalg.{DenseVector => BDV, DenseMatrix => BDM, norm => brzNorm, svd => brzSvd} import org.apache.spark.SparkFunSuite @@ -238,6 +239,22 @@ class RowMatrixSuite extends SparkFunSuite with MLlibTestSparkContext { } } } + + test("QR Decomposition") { + for (mat <- Seq(denseMat, sparseMat)) { + val result = mat.tallSkinnyQR(true) + val expected = breeze.linalg.qr.reduced(mat.toBreeze()) + val calcQ = result.Q + val calcR = result.R + assert(closeToZero(abs(expected.q) - abs(calcQ.toBreeze()))) + 
assert(closeToZero(abs(expected.r) - abs(calcR.toBreeze.asInstanceOf[BDM[Double]]))) + assert(closeToZero(calcQ.multiply(calcR).toBreeze - mat.toBreeze())) + // Decomposition without computing Q + val rOnly = mat.tallSkinnyQR(computeQ = false) + assert(rOnly.Q == null) + assert(closeToZero(abs(expected.r) - abs(rOnly.R.toBreeze.asInstanceOf[BDM[Double]]))) + } + } } class RowMatrixClusterSuite extends SparkFunSuite with LocalClusterSparkContext { From c5815930be46a89469440b7c61b59764fb67a54c Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Thu, 30 Jul 2015 07:56:15 -0700 Subject: [PATCH 18/50] [SPARK-5561] [MLLIB] Generalized PeriodicCheckpointer for RDDs and Graphs PeriodicGraphCheckpointer was introduced for Latent Dirichlet Allocation (LDA), but it was meant to be generalized to work with Graphs, RDDs, and other data structures based on RDDs. This PR generalizes it. For those who are not familiar with the periodic checkpointer, it tries to automatically handle persisting/unpersisting and checkpointing/removing checkpoint files in a lineage of RDD-based objects. I need it generalized to use with GradientBoostedTrees [https://issues.apache.org/jira/browse/SPARK-6684]. It should be useful for other iterative algorithms as well. Changes I made: * Copied PeriodicGraphCheckpointer to PeriodicCheckpointer. * Within PeriodicCheckpointer, I created abstract methods for the basic operations (checkpoint, persist, etc.). * The subclasses for Graphs and RDDs implement those abstract methods. * I copied the test suite for the graph checkpointer and made tiny modifications to make it work for RDDs. To review this PR, I recommend doing 2 diffs: (1) diff between the old PeriodicGraphCheckpointer.scala and the new PeriodicCheckpointer.scala (2) diff between the 2 test suites CCing andrewor14 in case there are relevant changes to checkpointing. CCing feynmanliang in case you're interested in learning about checkpointing. CCing mengxr for final OK. Thanks all! Author: Joseph K. Bradley Closes #7728 from jkbradley/gbt-checkpoint and squashes the following commits: d41902c [Joseph K. Bradley] Oops, forgot to update an extra time in the checkpointer tests, after the last commit. I'll fix that. I'll also make some of the checkpointer methods protected, which I should have done before. 32b23b8 [Joseph K. Bradley] fixed usage of checkpointer in lda 0b3dbc0 [Joseph K. Bradley] Changed checkpointer constructor not to take initial data. 568918c [Joseph K. Bradley] Generalized PeriodicGraphCheckpointer to PeriodicCheckpointer, with subclasses for RDDs and Graphs. 
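To make the intended calling pattern concrete, here is a rough sketch of driving the new RDD flavor of the checkpointer. The method names come from the abstract class in the diff below; the PeriodicRDDCheckpointer constructor is assumed to mirror the graph version, the checkpoint directory is illustrative, and since these classes are private[mllib] the sketch only shows the pattern used inside MLlib:

{{{
import org.apache.spark.SparkContext
import org.apache.spark.mllib.impl.PeriodicRDDCheckpointer
import org.apache.spark.rdd.RDD

val sc = new SparkContext("local[2]", "periodic-checkpointer-sketch")
sc.setCheckpointDir("/tmp/checkpoints")          // illustrative; required for checkpointing

// Checkpoint every 2nd update; persisting/unpersisting is handled automatically.
val checkpointer = new PeriodicRDDCheckpointer[Double](checkpointInterval = 2, sc)

var data: RDD[Double] = sc.parallelize(Seq(1.0, 2.0, 3.0))
for (i <- 0 until 10) {
  data = data.map(_ * 0.5)        // new RDD derived from the previous lineage
  checkpointer.update(data)       // call before materializing the new RDD
  data.count()                    // materialize so persist/checkpoint actually take effect
}
checkpointer.deleteAllCheckpoints()
}}}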
--- .../spark/mllib/clustering/LDAOptimizer.scala | 6 +- .../mllib/impl/PeriodicCheckpointer.scala | 154 ++++++++++++++++ .../impl/PeriodicGraphCheckpointer.scala | 105 ++--------- .../mllib/impl/PeriodicRDDCheckpointer.scala | 97 ++++++++++ .../impl/PeriodicGraphCheckpointerSuite.scala | 16 +- .../impl/PeriodicRDDCheckpointerSuite.scala | 173 ++++++++++++++++++ 6 files changed, 452 insertions(+), 99 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala create mode 100644 mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointerSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala index 7e75e7083acb5..4b90fbdf0ce7e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala @@ -142,8 +142,8 @@ final class EMLDAOptimizer extends LDAOptimizer { this.k = k this.vocabSize = docs.take(1).head._2.size this.checkpointInterval = lda.getCheckpointInterval - this.graphCheckpointer = new - PeriodicGraphCheckpointer[TopicCounts, TokenCount](graph, checkpointInterval) + this.graphCheckpointer = new PeriodicGraphCheckpointer[TopicCounts, TokenCount]( + checkpointInterval, graph.vertices.sparkContext) this.globalTopicTotals = computeGlobalTopicTotals() this } @@ -188,7 +188,7 @@ final class EMLDAOptimizer extends LDAOptimizer { // Update the vertex descriptors with the new counts. val newGraph = GraphImpl.fromExistingRDDs(docTopicDistributions, graph.edges) graph = newGraph - graphCheckpointer.updateGraph(newGraph) + graphCheckpointer.update(newGraph) globalTopicTotals = computeGlobalTopicTotals() this } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala new file mode 100644 index 0000000000000..72d3aabc9b1f4 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala @@ -0,0 +1,154 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.impl + +import scala.collection.mutable + +import org.apache.hadoop.fs.{Path, FileSystem} + +import org.apache.spark.{SparkContext, Logging} +import org.apache.spark.storage.StorageLevel + + +/** + * This abstraction helps with persisting and checkpointing RDDs and types derived from RDDs + * (such as Graphs and DataFrames). In documentation, we use the phrase "Dataset" to refer to + * the distributed data type (RDD, Graph, etc.). 
+ * + * Specifically, this abstraction automatically handles persisting and (optionally) checkpointing, + * as well as unpersisting and removing checkpoint files. + * + * Users should call update() when a new Dataset has been created, + * before the Dataset has been materialized. After updating [[PeriodicCheckpointer]], users are + * responsible for materializing the Dataset to ensure that persisting and checkpointing actually + * occur. + * + * When update() is called, this does the following: + * - Persist new Dataset (if not yet persisted), and put in queue of persisted Datasets. + * - Unpersist Datasets from queue until there are at most 3 persisted Datasets. + * - If using checkpointing and the checkpoint interval has been reached, + * - Checkpoint the new Dataset, and put in a queue of checkpointed Datasets. + * - Remove older checkpoints. + * + * WARNINGS: + * - This class should NOT be copied (since copies may conflict on which Datasets should be + * checkpointed). + * - This class removes checkpoint files once later Datasets have been checkpointed. + * However, references to the older Datasets will still return isCheckpointed = true. + * + * @param checkpointInterval Datasets will be checkpointed at this interval + * @param sc SparkContext for the Datasets given to this checkpointer + * @tparam T Dataset type, such as RDD[Double] + */ +private[mllib] abstract class PeriodicCheckpointer[T]( + val checkpointInterval: Int, + val sc: SparkContext) extends Logging { + + /** FIFO queue of past checkpointed Datasets */ + private val checkpointQueue = mutable.Queue[T]() + + /** FIFO queue of past persisted Datasets */ + private val persistedQueue = mutable.Queue[T]() + + /** Number of times [[update()]] has been called */ + private var updateCount = 0 + + /** + * Update with a new Dataset. Handle persistence and checkpointing as needed. + * Since this handles persistence and checkpointing, this should be called before the Dataset + * has been materialized. + * + * @param newData New Dataset created from previous Datasets in the lineage. + */ + def update(newData: T): Unit = { + persist(newData) + persistedQueue.enqueue(newData) + // We try to maintain 2 Datasets in persistedQueue to support the semantics of this class: + // Users should call [[update()]] when a new Dataset has been created, + // before the Dataset has been materialized. + while (persistedQueue.size > 3) { + val dataToUnpersist = persistedQueue.dequeue() + unpersist(dataToUnpersist) + } + updateCount += 1 + + // Handle checkpointing (after persisting) + if ((updateCount % checkpointInterval) == 0 && sc.getCheckpointDir.nonEmpty) { + // Add new checkpoint before removing old checkpoints. + checkpoint(newData) + checkpointQueue.enqueue(newData) + // Remove checkpoints before the latest one. + var canDelete = true + while (checkpointQueue.size > 1 && canDelete) { + // Delete the oldest checkpoint only if the next checkpoint exists. + if (isCheckpointed(checkpointQueue.head)) { + removeCheckpointFile() + } else { + canDelete = false + } + } + } + } + + /** Checkpoint the Dataset */ + protected def checkpoint(data: T): Unit + + /** Return true iff the Dataset is checkpointed */ + protected def isCheckpointed(data: T): Boolean + + /** + * Persist the Dataset. + * Note: This should handle checking the current [[StorageLevel]] of the Dataset. 
+ */ + protected def persist(data: T): Unit + + /** Unpersist the Dataset */ + protected def unpersist(data: T): Unit + + /** Get list of checkpoint files for this given Dataset */ + protected def getCheckpointFiles(data: T): Iterable[String] + + /** + * Call this at the end to delete any remaining checkpoint files. + */ + def deleteAllCheckpoints(): Unit = { + while (checkpointQueue.nonEmpty) { + removeCheckpointFile() + } + } + + /** + * Dequeue the oldest checkpointed Dataset, and remove its checkpoint files. + * This prints a warning but does not fail if the files cannot be removed. + */ + private def removeCheckpointFile(): Unit = { + val old = checkpointQueue.dequeue() + // Since the old checkpoint is not deleted by Spark, we manually delete it. + val fs = FileSystem.get(sc.hadoopConfiguration) + getCheckpointFiles(old).foreach { checkpointFile => + try { + fs.delete(new Path(checkpointFile), true) + } catch { + case e: Exception => + logWarning("PeriodicCheckpointer could not remove old checkpoint file: " + + checkpointFile) + } + } + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala index 6e5dd119dd653..11a059536c50c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala @@ -17,11 +17,7 @@ package org.apache.spark.mllib.impl -import scala.collection.mutable - -import org.apache.hadoop.fs.{Path, FileSystem} - -import org.apache.spark.Logging +import org.apache.spark.SparkContext import org.apache.spark.graphx.Graph import org.apache.spark.storage.StorageLevel @@ -31,12 +27,12 @@ import org.apache.spark.storage.StorageLevel * Specifically, it automatically handles persisting and (optionally) checkpointing, as well as * unpersisting and removing checkpoint files. * - * Users should call [[PeriodicGraphCheckpointer.updateGraph()]] when a new graph has been created, + * Users should call update() when a new graph has been created, * before the graph has been materialized. After updating [[PeriodicGraphCheckpointer]], users are * responsible for materializing the graph to ensure that persisting and checkpointing actually * occur. * - * When [[PeriodicGraphCheckpointer.updateGraph()]] is called, this does the following: + * When update() is called, this does the following: * - Persist new graph (if not yet persisted), and put in queue of persisted graphs. * - Unpersist graphs from queue until there are at most 3 persisted graphs. * - If using checkpointing and the checkpoint interval has been reached, @@ -52,7 +48,7 @@ import org.apache.spark.storage.StorageLevel * Example usage: * {{{ * val (graph1, graph2, graph3, ...) = ... - * val cp = new PeriodicGraphCheckpointer(graph1, dir, 2) + * val cp = new PeriodicGraphCheckpointer(2, sc) * graph1.vertices.count(); graph1.edges.count() * // persisted: graph1 * cp.updateGraph(graph2) @@ -73,99 +69,30 @@ import org.apache.spark.storage.StorageLevel * // checkpointed: graph4 * }}} * - * @param currentGraph Initial graph * @param checkpointInterval Graphs will be checkpointed at this interval * @tparam VD Vertex descriptor type * @tparam ED Edge descriptor type * - * TODO: Generalize this for Graphs and RDDs, and move it out of MLlib. + * TODO: Move this out of MLlib? 
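Note that both the constructor and the update method change in this patch: the checkpointer now takes the interval and a SparkContext, and updateGraph() becomes the inherited update(). A short sketch mirroring the revised test suite further down in this patch (`checkpointInterval` and `graph1` stand in for the caller's values):

    val checkpointer = new PeriodicGraphCheckpointer[Double, Double](
      checkpointInterval, graph1.vertices.sparkContext)
    checkpointer.update(graph1)
    graph1.vertices.count(); graph1.edges.count()   // materialize to trigger persist/checkpoint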
*/ private[mllib] class PeriodicGraphCheckpointer[VD, ED]( - var currentGraph: Graph[VD, ED], - val checkpointInterval: Int) extends Logging { - - /** FIFO queue of past checkpointed RDDs */ - private val checkpointQueue = mutable.Queue[Graph[VD, ED]]() - - /** FIFO queue of past persisted RDDs */ - private val persistedQueue = mutable.Queue[Graph[VD, ED]]() - - /** Number of times [[updateGraph()]] has been called */ - private var updateCount = 0 - - /** - * Spark Context for the Graphs given to this checkpointer. - * NOTE: This code assumes that only one SparkContext is used for the given graphs. - */ - private val sc = currentGraph.vertices.sparkContext + checkpointInterval: Int, + sc: SparkContext) + extends PeriodicCheckpointer[Graph[VD, ED]](checkpointInterval, sc) { - updateGraph(currentGraph) + override protected def checkpoint(data: Graph[VD, ED]): Unit = data.checkpoint() - /** - * Update [[currentGraph]] with a new graph. Handle persistence and checkpointing as needed. - * Since this handles persistence and checkpointing, this should be called before the graph - * has been materialized. - * - * @param newGraph New graph created from previous graphs in the lineage. - */ - def updateGraph(newGraph: Graph[VD, ED]): Unit = { - if (newGraph.vertices.getStorageLevel == StorageLevel.NONE) { - newGraph.persist() - } - persistedQueue.enqueue(newGraph) - // We try to maintain 2 Graphs in persistedQueue to support the semantics of this class: - // Users should call [[updateGraph()]] when a new graph has been created, - // before the graph has been materialized. - while (persistedQueue.size > 3) { - val graphToUnpersist = persistedQueue.dequeue() - graphToUnpersist.unpersist(blocking = false) - } - updateCount += 1 + override protected def isCheckpointed(data: Graph[VD, ED]): Boolean = data.isCheckpointed - // Handle checkpointing (after persisting) - if ((updateCount % checkpointInterval) == 0 && sc.getCheckpointDir.nonEmpty) { - // Add new checkpoint before removing old checkpoints. - newGraph.checkpoint() - checkpointQueue.enqueue(newGraph) - // Remove checkpoints before the latest one. - var canDelete = true - while (checkpointQueue.size > 1 && canDelete) { - // Delete the oldest checkpoint only if the next checkpoint exists. - if (checkpointQueue.get(1).get.isCheckpointed) { - removeCheckpointFile() - } else { - canDelete = false - } - } + override protected def persist(data: Graph[VD, ED]): Unit = { + if (data.vertices.getStorageLevel == StorageLevel.NONE) { + data.persist() } } - /** - * Call this at the end to delete any remaining checkpoint files. - */ - def deleteAllCheckpoints(): Unit = { - while (checkpointQueue.size > 0) { - removeCheckpointFile() - } - } + override protected def unpersist(data: Graph[VD, ED]): Unit = data.unpersist(blocking = false) - /** - * Dequeue the oldest checkpointed Graph, and remove its checkpoint files. - * This prints a warning but does not fail if the files cannot be removed. - */ - private def removeCheckpointFile(): Unit = { - val old = checkpointQueue.dequeue() - // Since the old checkpoint is not deleted by Spark, we manually delete it. 
- val fs = FileSystem.get(sc.hadoopConfiguration) - old.getCheckpointFiles.foreach { checkpointFile => - try { - fs.delete(new Path(checkpointFile), true) - } catch { - case e: Exception => - logWarning("PeriodicGraphCheckpointer could not remove old checkpoint file: " + - checkpointFile) - } - } + override protected def getCheckpointFiles(data: Graph[VD, ED]): Iterable[String] = { + data.getCheckpointFiles } - } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala new file mode 100644 index 0000000000000..f31ed2aa90a64 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.impl + +import org.apache.spark.SparkContext +import org.apache.spark.rdd.RDD +import org.apache.spark.storage.StorageLevel + + +/** + * This class helps with persisting and checkpointing RDDs. + * Specifically, it automatically handles persisting and (optionally) checkpointing, as well as + * unpersisting and removing checkpoint files. + * + * Users should call update() when a new RDD has been created, + * before the RDD has been materialized. After updating [[PeriodicRDDCheckpointer]], users are + * responsible for materializing the RDD to ensure that persisting and checkpointing actually + * occur. + * + * When update() is called, this does the following: + * - Persist new RDD (if not yet persisted), and put in queue of persisted RDDs. + * - Unpersist RDDs from queue until there are at most 3 persisted RDDs. + * - If using checkpointing and the checkpoint interval has been reached, + * - Checkpoint the new RDD, and put in a queue of checkpointed RDDs. + * - Remove older checkpoints. + * + * WARNINGS: + * - This class should NOT be copied (since copies may conflict on which RDDs should be + * checkpointed). + * - This class removes checkpoint files once later RDDs have been checkpointed. + * However, references to the older RDDs will still return isCheckpointed = true. + * + * Example usage: + * {{{ + * val (rdd1, rdd2, rdd3, ...) = ... 
+ * val cp = new PeriodicRDDCheckpointer(2, sc) + * rdd1.count(); + * // persisted: rdd1 + * cp.update(rdd2) + * rdd2.count(); + * // persisted: rdd1, rdd2 + * // checkpointed: rdd2 + * cp.update(rdd3) + * rdd3.count(); + * // persisted: rdd1, rdd2, rdd3 + * // checkpointed: rdd2 + * cp.update(rdd4) + * rdd4.count(); + * // persisted: rdd2, rdd3, rdd4 + * // checkpointed: rdd4 + * cp.update(rdd5) + * rdd5.count(); + * // persisted: rdd3, rdd4, rdd5 + * // checkpointed: rdd4 + * }}} + * + * @param checkpointInterval RDDs will be checkpointed at this interval + * @tparam T RDD element type + * + * TODO: Move this out of MLlib? + */ +private[mllib] class PeriodicRDDCheckpointer[T]( + checkpointInterval: Int, + sc: SparkContext) + extends PeriodicCheckpointer[RDD[T]](checkpointInterval, sc) { + + override protected def checkpoint(data: RDD[T]): Unit = data.checkpoint() + + override protected def isCheckpointed(data: RDD[T]): Boolean = data.isCheckpointed + + override protected def persist(data: RDD[T]): Unit = { + if (data.getStorageLevel == StorageLevel.NONE) { + data.persist() + } + } + + override protected def unpersist(data: RDD[T]): Unit = data.unpersist(blocking = false) + + override protected def getCheckpointFiles(data: RDD[T]): Iterable[String] = { + data.getCheckpointFile.map(x => x) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala index d34888af2d73b..e331c75989187 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala @@ -30,20 +30,20 @@ class PeriodicGraphCheckpointerSuite extends SparkFunSuite with MLlibTestSparkCo import PeriodicGraphCheckpointerSuite._ - // TODO: Do I need to call count() on the graphs' RDDs? 
- test("Persisting") { var graphsToCheck = Seq.empty[GraphToCheck] val graph1 = createGraph(sc) - val checkpointer = new PeriodicGraphCheckpointer(graph1, 10) + val checkpointer = + new PeriodicGraphCheckpointer[Double, Double](10, graph1.vertices.sparkContext) + checkpointer.update(graph1) graphsToCheck = graphsToCheck :+ GraphToCheck(graph1, 1) checkPersistence(graphsToCheck, 1) var iteration = 2 while (iteration < 9) { val graph = createGraph(sc) - checkpointer.updateGraph(graph) + checkpointer.update(graph) graphsToCheck = graphsToCheck :+ GraphToCheck(graph, iteration) checkPersistence(graphsToCheck, iteration) iteration += 1 @@ -57,7 +57,9 @@ class PeriodicGraphCheckpointerSuite extends SparkFunSuite with MLlibTestSparkCo var graphsToCheck = Seq.empty[GraphToCheck] sc.setCheckpointDir(path) val graph1 = createGraph(sc) - val checkpointer = new PeriodicGraphCheckpointer(graph1, checkpointInterval) + val checkpointer = new PeriodicGraphCheckpointer[Double, Double]( + checkpointInterval, graph1.vertices.sparkContext) + checkpointer.update(graph1) graph1.edges.count() graph1.vertices.count() graphsToCheck = graphsToCheck :+ GraphToCheck(graph1, 1) @@ -66,7 +68,7 @@ class PeriodicGraphCheckpointerSuite extends SparkFunSuite with MLlibTestSparkCo var iteration = 2 while (iteration < 9) { val graph = createGraph(sc) - checkpointer.updateGraph(graph) + checkpointer.update(graph) graph.vertices.count() graph.edges.count() graphsToCheck = graphsToCheck :+ GraphToCheck(graph, iteration) @@ -168,7 +170,7 @@ private object PeriodicGraphCheckpointerSuite { } else { // Graph should never be checkpointed assert(!graph.isCheckpointed, "Graph should never have been checkpointed") - assert(graph.getCheckpointFiles.length == 0, "Graph should not have any checkpoint files") + assert(graph.getCheckpointFiles.isEmpty, "Graph should not have any checkpoint files") } } catch { case e: AssertionError => diff --git a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointerSuite.scala new file mode 100644 index 0000000000000..b2a459a68b5fa --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointerSuite.scala @@ -0,0 +1,173 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.mllib.impl + +import org.apache.hadoop.fs.{FileSystem, Path} + +import org.apache.spark.{SparkContext, SparkFunSuite} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.rdd.RDD +import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.Utils + + +class PeriodicRDDCheckpointerSuite extends SparkFunSuite with MLlibTestSparkContext { + + import PeriodicRDDCheckpointerSuite._ + + test("Persisting") { + var rddsToCheck = Seq.empty[RDDToCheck] + + val rdd1 = createRDD(sc) + val checkpointer = new PeriodicRDDCheckpointer[Double](10, rdd1.sparkContext) + checkpointer.update(rdd1) + rddsToCheck = rddsToCheck :+ RDDToCheck(rdd1, 1) + checkPersistence(rddsToCheck, 1) + + var iteration = 2 + while (iteration < 9) { + val rdd = createRDD(sc) + checkpointer.update(rdd) + rddsToCheck = rddsToCheck :+ RDDToCheck(rdd, iteration) + checkPersistence(rddsToCheck, iteration) + iteration += 1 + } + } + + test("Checkpointing") { + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + val checkpointInterval = 2 + var rddsToCheck = Seq.empty[RDDToCheck] + sc.setCheckpointDir(path) + val rdd1 = createRDD(sc) + val checkpointer = new PeriodicRDDCheckpointer[Double](checkpointInterval, rdd1.sparkContext) + checkpointer.update(rdd1) + rdd1.count() + rddsToCheck = rddsToCheck :+ RDDToCheck(rdd1, 1) + checkCheckpoint(rddsToCheck, 1, checkpointInterval) + + var iteration = 2 + while (iteration < 9) { + val rdd = createRDD(sc) + checkpointer.update(rdd) + rdd.count() + rddsToCheck = rddsToCheck :+ RDDToCheck(rdd, iteration) + checkCheckpoint(rddsToCheck, iteration, checkpointInterval) + iteration += 1 + } + + checkpointer.deleteAllCheckpoints() + rddsToCheck.foreach { rdd => + confirmCheckpointRemoved(rdd.rdd) + } + + Utils.deleteRecursively(tempDir) + } +} + +private object PeriodicRDDCheckpointerSuite { + + case class RDDToCheck(rdd: RDD[Double], gIndex: Int) + + def createRDD(sc: SparkContext): RDD[Double] = { + sc.parallelize(Seq(0.0, 1.0, 2.0, 3.0)) + } + + def checkPersistence(rdds: Seq[RDDToCheck], iteration: Int): Unit = { + rdds.foreach { g => + checkPersistence(g.rdd, g.gIndex, iteration) + } + } + + /** + * Check storage level of rdd. + * @param gIndex Index of rdd in order inserted into checkpointer (from 1). + * @param iteration Total number of rdds inserted into checkpointer. + */ + def checkPersistence(rdd: RDD[_], gIndex: Int, iteration: Int): Unit = { + try { + if (gIndex + 2 < iteration) { + assert(rdd.getStorageLevel == StorageLevel.NONE) + } else { + assert(rdd.getStorageLevel != StorageLevel.NONE) + } + } catch { + case _: AssertionError => + throw new Exception(s"PeriodicRDDCheckpointerSuite.checkPersistence failed with:\n" + + s"\t gIndex = $gIndex\n" + + s"\t iteration = $iteration\n" + + s"\t rdd.getStorageLevel = ${rdd.getStorageLevel}\n") + } + } + + def checkCheckpoint(rdds: Seq[RDDToCheck], iteration: Int, checkpointInterval: Int): Unit = { + rdds.reverse.foreach { g => + checkCheckpoint(g.rdd, g.gIndex, iteration, checkpointInterval) + } + } + + def confirmCheckpointRemoved(rdd: RDD[_]): Unit = { + // Note: We cannot check rdd.isCheckpointed since that value is never updated. + // Instead, we check for the presence of the checkpoint files. + // This test should continue to work even after this rdd.isCheckpointed issue + // is fixed (though it can then be simplified and not look for the files). 
+ val fs = FileSystem.get(rdd.sparkContext.hadoopConfiguration) + rdd.getCheckpointFile.foreach { checkpointFile => + assert(!fs.exists(new Path(checkpointFile)), "RDD checkpoint file should have been removed") + } + } + + /** + * Check checkpointed status of rdd. + * @param gIndex Index of rdd in order inserted into checkpointer (from 1). + * @param iteration Total number of rdds inserted into checkpointer. + */ + def checkCheckpoint( + rdd: RDD[_], + gIndex: Int, + iteration: Int, + checkpointInterval: Int): Unit = { + try { + if (gIndex % checkpointInterval == 0) { + // We allow 2 checkpoint intervals since we perform an action (checkpointing a second rdd) + // only AFTER PeriodicRDDCheckpointer decides whether to remove the previous checkpoint. + if (iteration - 2 * checkpointInterval < gIndex && gIndex <= iteration) { + assert(rdd.isCheckpointed, "RDD should be checkpointed") + assert(rdd.getCheckpointFile.nonEmpty, "RDD should have 2 checkpoint files") + } else { + confirmCheckpointRemoved(rdd) + } + } else { + // RDD should never be checkpointed + assert(!rdd.isCheckpointed, "RDD should never have been checkpointed") + assert(rdd.getCheckpointFile.isEmpty, "RDD should not have any checkpoint files") + } + } catch { + case e: AssertionError => + throw new Exception(s"PeriodicRDDCheckpointerSuite.checkCheckpoint failed with:\n" + + s"\t gIndex = $gIndex\n" + + s"\t iteration = $iteration\n" + + s"\t checkpointInterval = $checkpointInterval\n" + + s"\t rdd.isCheckpointed = ${rdd.isCheckpointed}\n" + + s"\t rdd.getCheckpointFile = ${rdd.getCheckpointFile.mkString(", ")}\n" + + s" AssertionError message: ${e.getMessage}") + } + } + +} From d212a314227dec26c0dbec8ed3422d0ec8f818f9 Mon Sep 17 00:00:00 2001 From: zhangjiajin Date: Thu, 30 Jul 2015 08:14:09 -0700 Subject: [PATCH 19/50] [SPARK-8998] [MLLIB] Distribute PrefixSpan computation for large projected databases Continuation of work by zhangjiajin Closes #7412 Author: zhangjiajin Author: Feynman Liang Author: zhang jiajin Closes #7783 from feynmanliang/SPARK-8998-improve-distributed and squashes the following commits: a61943d [Feynman Liang] Collect small patterns to local 4ddf479 [Feynman Liang] Parallelize freqItemCounts ad23aa9 [zhang jiajin] Merge pull request #1 from feynmanliang/SPARK-8998-collectBeforeLocal 87fa021 [Feynman Liang] Improve extend prefix readability c2caa5c [Feynman Liang] Readability improvements and comments 1235cfc [Feynman Liang] Use Iterable[Array[_]] over Array[Array[_]] for database da0091b [Feynman Liang] Use lists for prefixes to reuse data cb2a4fc [Feynman Liang] Inline code for readability 01c9ae9 [Feynman Liang] Add getters 6e149fa [Feynman Liang] Fix splitPrefixSuffixPairs 64271b3 [zhangjiajin] Modified codes according to comments. d2250b7 [zhangjiajin] remove minPatternsBeforeLocalProcessing, add maxSuffixesBeforeLocalProcessing. b07e20c [zhangjiajin] Merge branch 'master' of https://github.com/apache/spark into CollectEnoughPrefixes 095aa3a [zhangjiajin] Modified the code according to the review comments. baa2885 [zhangjiajin] Modified the code according to the review comments. 6560c69 [zhangjiajin] Add feature: Collect enough frequent prefixes before projection in PrefixeSpan a8fde87 [zhangjiajin] Merge branch 'master' of https://github.com/apache/spark 4dd1c8a [zhangjiajin] initialize file before rebase. 078d410 [zhangjiajin] fix a scala style error. 22b0ef4 [zhangjiajin] Add feature: Collect enough frequent prefixes before projection in PrefixSpan. 
ca9c4c8 [zhangjiajin] Modified the code according to the review comments. 574e56c [zhangjiajin] Add new object LocalPrefixSpan, and do some optimization. ba5df34 [zhangjiajin] Fix a Scala style error. 4c60fb3 [zhangjiajin] Fix some Scala style errors. 1dd33ad [zhangjiajin] Modified the code according to the review comments. 89bc368 [zhangjiajin] Fixed a Scala style error. a2eb14c [zhang jiajin] Delete PrefixspanSuite.scala 951fd42 [zhang jiajin] Delete Prefixspan.scala 575995f [zhangjiajin] Modified the code according to the review comments. 91fd7e6 [zhangjiajin] Add new algorithm PrefixSpan and test file. --- .../spark/mllib/fpm/LocalPrefixSpan.scala | 6 +- .../apache/spark/mllib/fpm/PrefixSpan.scala | 203 +++++++++++++----- .../spark/mllib/fpm/PrefixSpanSuite.scala | 21 +- 3 files changed, 161 insertions(+), 69 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala index 7ead6327486cc..0ea792081086d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala @@ -40,7 +40,7 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable { minCount: Long, maxPatternLength: Int, prefixes: List[Int], - database: Array[Array[Int]]): Iterator[(List[Int], Long)] = { + database: Iterable[Array[Int]]): Iterator[(List[Int], Long)] = { if (prefixes.length == maxPatternLength || database.isEmpty) return Iterator.empty val frequentItemAndCounts = getFreqItemAndCounts(minCount, database) val filteredDatabase = database.map(x => x.filter(frequentItemAndCounts.contains)) @@ -67,7 +67,7 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable { } } - def project(database: Array[Array[Int]], prefix: Int): Array[Array[Int]] = { + def project(database: Iterable[Array[Int]], prefix: Int): Iterable[Array[Int]] = { database .map(getSuffix(prefix, _)) .filter(_.nonEmpty) @@ -81,7 +81,7 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable { */ private def getFreqItemAndCounts( minCount: Long, - database: Array[Array[Int]]): mutable.Map[Int, Long] = { + database: Iterable[Array[Int]]): mutable.Map[Int, Long] = { // TODO: use PrimitiveKeyOpenHashMap val counts = mutable.Map[Int, Long]().withDefaultValue(0L) database.foreach { sequence => diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala index 6f52db7b073ae..e6752332cdeeb 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala @@ -17,6 +17,8 @@ package org.apache.spark.mllib.fpm +import scala.collection.mutable.ArrayBuffer + import org.apache.spark.Logging import org.apache.spark.annotation.Experimental import org.apache.spark.rdd.RDD @@ -43,28 +45,45 @@ class PrefixSpan private ( private var minSupport: Double, private var maxPatternLength: Int) extends Logging with Serializable { + /** + * The maximum number of items allowed in a projected database before local processing. If a + * projected database exceeds this size, another iteration of distributed PrefixSpan is run. + */ + // TODO: make configurable with a better default value, 10000 may be too small + private val maxLocalProjDBSize: Long = 10000 + /** * Constructs a default instance with default parameters * {minSupport: `0.1`, maxPatternLength: `10`}. 
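A brief usage sketch, mirroring the PrefixSpanSuite test further down in this patch; the example sequences are illustrative only:

    val sequences = sc.parallelize(Seq(
      Array(1, 3, 4, 5),
      Array(2, 3, 1),
      Array(2, 4, 1)), 2).cache()
    val prefixSpan = new PrefixSpan()
      .setMinSupport(0.33)
      .setMaxPatternLength(50)
    val frequentPatterns = prefixSpan.run(sequences)   // RDD[(Array[Int], Long)] of (pattern, count)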
*/ def this() = this(0.1, 10) + /** + * Get the minimal support (i.e. the frequency of occurrence before a pattern is considered + * frequent). + */ + def getMinSupport: Double = this.minSupport + /** * Sets the minimal support level (default: `0.1`). */ def setMinSupport(minSupport: Double): this.type = { - require(minSupport >= 0 && minSupport <= 1, - "The minimum support value must be between 0 and 1, including 0 and 1.") + require(minSupport >= 0 && minSupport <= 1, "The minimum support value must be in [0, 1].") this.minSupport = minSupport this } + /** + * Gets the maximal pattern length (i.e. the length of the longest sequential pattern to consider. + */ + def getMaxPatternLength: Double = this.maxPatternLength + /** * Sets maximal pattern length (default: `10`). */ def setMaxPatternLength(maxPatternLength: Int): this.type = { - require(maxPatternLength >= 1, - "The maximum pattern length value must be greater than 0.") + // TODO: support unbounded pattern length when maxPatternLength = 0 + require(maxPatternLength >= 1, "The maximum pattern length value must be greater than 0.") this.maxPatternLength = maxPatternLength this } @@ -78,81 +97,153 @@ class PrefixSpan private ( * the value of pair is the pattern's count. */ def run(sequences: RDD[Array[Int]]): RDD[(Array[Int], Long)] = { + val sc = sequences.sparkContext + if (sequences.getStorageLevel == StorageLevel.NONE) { logWarning("Input data is not cached.") } - val minCount = getMinCount(sequences) - val lengthOnePatternsAndCounts = - getFreqItemAndCounts(minCount, sequences).collect() - val prefixAndProjectedDatabase = getPrefixAndProjectedDatabase( - lengthOnePatternsAndCounts.map(_._1), sequences) - val groupedProjectedDatabase = prefixAndProjectedDatabase - .map(x => (x._1.toSeq, x._2)) - .groupByKey() - .map(x => (x._1.toArray, x._2.toArray)) - val nextPatterns = getPatternsInLocal(minCount, groupedProjectedDatabase) - val lengthOnePatternsAndCountsRdd = - sequences.sparkContext.parallelize( - lengthOnePatternsAndCounts.map(x => (Array(x._1), x._2))) - val allPatterns = lengthOnePatternsAndCountsRdd ++ nextPatterns - allPatterns + + // Convert min support to a min number of transactions for this dataset + val minCount = if (minSupport == 0) 0L else math.ceil(sequences.count() * minSupport).toLong + + // (Frequent items -> number of occurrences, all items here satisfy the `minSupport` threshold + val freqItemCounts = sequences + .flatMap(seq => seq.distinct.map(item => (item, 1L))) + .reduceByKey(_ + _) + .filter(_._2 >= minCount) + .collect() + + // Pairs of (length 1 prefix, suffix consisting of frequent items) + val itemSuffixPairs = { + val freqItems = freqItemCounts.map(_._1).toSet + sequences.flatMap { seq => + val filteredSeq = seq.filter(freqItems.contains(_)) + freqItems.flatMap { item => + val candidateSuffix = LocalPrefixSpan.getSuffix(item, filteredSeq) + candidateSuffix match { + case suffix if !suffix.isEmpty => Some((List(item), suffix)) + case _ => None + } + } + } + } + + // Accumulator for the computed results to be returned, initialized to the frequent items (i.e. + // frequent length-one prefixes) + var resultsAccumulator = freqItemCounts.map(x => (List(x._1), x._2)) + + // Remaining work to be locally and distributively processed respectfully + var (pairsForLocal, pairsForDistributed) = partitionByProjDBSize(itemSuffixPairs) + + // Continue processing until no pairs for distributed processing remain (i.e. 
all prefixes have + // projected database sizes <= `maxLocalProjDBSize`) + while (pairsForDistributed.count() != 0) { + val (nextPatternAndCounts, nextPrefixSuffixPairs) = + extendPrefixes(minCount, pairsForDistributed) + pairsForDistributed.unpersist() + val (smallerPairsPart, largerPairsPart) = partitionByProjDBSize(nextPrefixSuffixPairs) + pairsForDistributed = largerPairsPart + pairsForDistributed.persist(StorageLevel.MEMORY_AND_DISK) + pairsForLocal ++= smallerPairsPart + resultsAccumulator ++= nextPatternAndCounts.collect() + } + + // Process the small projected databases locally + val remainingResults = getPatternsInLocal( + minCount, sc.parallelize(pairsForLocal, 1).groupByKey()) + + (sc.parallelize(resultsAccumulator, 1) ++ remainingResults) + .map { case (pattern, count) => (pattern.toArray, count) } } + /** - * Get the minimum count (sequences count * minSupport). - * @param sequences input data set, contains a set of sequences, - * @return minimum count, + * Partitions the prefix-suffix pairs by projected database size. + * @param prefixSuffixPairs prefix (length n) and suffix pairs, + * @return prefix-suffix pairs partitioned by whether their projected database size is <= or + * greater than [[maxLocalProjDBSize]] */ - private def getMinCount(sequences: RDD[Array[Int]]): Long = { - if (minSupport == 0) 0L else math.ceil(sequences.count() * minSupport).toLong + private def partitionByProjDBSize(prefixSuffixPairs: RDD[(List[Int], Array[Int])]) + : (Array[(List[Int], Array[Int])], RDD[(List[Int], Array[Int])]) = { + val prefixToSuffixSize = prefixSuffixPairs + .aggregateByKey(0)( + seqOp = { case (count, suffix) => count + suffix.length }, + combOp = { _ + _ }) + val smallPrefixes = prefixToSuffixSize + .filter(_._2 <= maxLocalProjDBSize) + .keys + .collect() + .toSet + val small = prefixSuffixPairs.filter { case (prefix, _) => smallPrefixes.contains(prefix) } + val large = prefixSuffixPairs.filter { case (prefix, _) => !smallPrefixes.contains(prefix) } + (small.collect(), large) } /** - * Generates frequent items by filtering the input data using minimal count level. - * @param minCount the absolute minimum count - * @param sequences original sequences data - * @return array of item and count pair + * Extends all prefixes by one item from their suffix and computes the resulting frequent prefixes + * and remaining work. + * @param minCount minimum count + * @param prefixSuffixPairs prefix (length N) and suffix pairs, + * @return (frequent length N+1 extended prefix, count) pairs and (frequent length N+1 extended + * prefix, corresponding suffix) pairs. */ - private def getFreqItemAndCounts( + private def extendPrefixes( minCount: Long, - sequences: RDD[Array[Int]]): RDD[(Int, Long)] = { - sequences.flatMap(_.distinct.map((_, 1L))) + prefixSuffixPairs: RDD[(List[Int], Array[Int])]) + : (RDD[(List[Int], Long)], RDD[(List[Int], Array[Int])]) = { + + // (length N prefix, item from suffix) pairs and their corresponding number of occurrences + // Every (prefix :+ suffix) is guaranteed to have support exceeding `minSupport` + val prefixItemPairAndCounts = prefixSuffixPairs + .flatMap { case (prefix, suffix) => suffix.distinct.map(y => ((prefix, y), 1L)) } .reduceByKey(_ + _) .filter(_._2 >= minCount) - } - /** - * Get the frequent prefixes' projected database. 
- * @param frequentPrefixes frequent prefixes - * @param sequences sequences data - * @return prefixes and projected database - */ - private def getPrefixAndProjectedDatabase( - frequentPrefixes: Array[Int], - sequences: RDD[Array[Int]]): RDD[(Array[Int], Array[Int])] = { - val filteredSequences = sequences.map { p => - p.filter (frequentPrefixes.contains(_) ) - } - filteredSequences.flatMap { x => - frequentPrefixes.map { y => - val sub = LocalPrefixSpan.getSuffix(y, x) - (Array(y), sub) - }.filter(_._2.nonEmpty) - } + // Map from prefix to set of possible next items from suffix + val prefixToNextItems = prefixItemPairAndCounts + .keys + .groupByKey() + .mapValues(_.toSet) + .collect() + .toMap + + + // Frequent patterns with length N+1 and their corresponding counts + val extendedPrefixAndCounts = prefixItemPairAndCounts + .map { case ((prefix, item), count) => (item :: prefix, count) } + + // Remaining work, all prefixes will have length N+1 + val extendedPrefixAndSuffix = prefixSuffixPairs + .filter(x => prefixToNextItems.contains(x._1)) + .flatMap { case (prefix, suffix) => + val frequentNextItems = prefixToNextItems(prefix) + val filteredSuffix = suffix.filter(frequentNextItems.contains(_)) + frequentNextItems.flatMap { item => + LocalPrefixSpan.getSuffix(item, filteredSuffix) match { + case suffix if !suffix.isEmpty => Some(item :: prefix, suffix) + case _ => None + } + } + } + + (extendedPrefixAndCounts, extendedPrefixAndSuffix) } /** - * calculate the patterns in local. + * Calculate the patterns in local. * @param minCount the absolute minimum count - * @param data patterns and projected sequences data data + * @param data prefixes and projected sequences data data * @return patterns */ private def getPatternsInLocal( minCount: Long, - data: RDD[(Array[Int], Array[Array[Int]])]): RDD[(Array[Int], Long)] = { - data.flatMap { case (prefix, projDB) => - LocalPrefixSpan.run(minCount, maxPatternLength, prefix.toList, projDB) - .map { case (pattern: List[Int], count: Long) => (pattern.toArray.reverse, count) } + data: RDD[(List[Int], Iterable[Array[Int]])]): RDD[(List[Int], Long)] = { + data.flatMap { + case (prefix, projDB) => + LocalPrefixSpan.run(minCount, maxPatternLength, prefix.toList.reverse, projDB) + .map { case (pattern: List[Int], count: Long) => + (pattern.reverse, count) + } } } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala index 9f107c89f6d80..6dd2dc926acc5 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala @@ -44,13 +44,6 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext { val rdd = sc.parallelize(sequences, 2).cache() - def compareResult( - expectedValue: Array[(Array[Int], Long)], - actualValue: Array[(Array[Int], Long)]): Boolean = { - expectedValue.map(x => (x._1.toSeq, x._2)).toSet == - actualValue.map(x => (x._1.toSeq, x._2)).toSet - } - val prefixspan = new PrefixSpan() .setMinSupport(0.33) .setMaxPatternLength(50) @@ -76,7 +69,7 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext { (Array(4, 5), 2L), (Array(5), 3L) ) - assert(compareResult(expectedValue1, result1.collect())) + assert(compareResults(expectedValue1, result1.collect())) prefixspan.setMinSupport(0.5).setMaxPatternLength(50) val result2 = prefixspan.run(rdd) @@ -87,7 +80,7 @@ class PrefixSpanSuite extends SparkFunSuite with 
MLlibTestSparkContext { (Array(4), 4L), (Array(5), 3L) ) - assert(compareResult(expectedValue2, result2.collect())) + assert(compareResults(expectedValue2, result2.collect())) prefixspan.setMinSupport(0.33).setMaxPatternLength(2) val result3 = prefixspan.run(rdd) @@ -107,6 +100,14 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext { (Array(4, 5), 2L), (Array(5), 3L) ) - assert(compareResult(expectedValue3, result3.collect())) + assert(compareResults(expectedValue3, result3.collect())) + } + + private def compareResults( + expectedValue: Array[(Array[Int], Long)], + actualValue: Array[(Array[Int], Long)]): Boolean = { + expectedValue.map(x => (x._1.toSeq, x._2)).toSet == + actualValue.map(x => (x._1.toSeq, x._2)).toSet } + } From 9c0501c5d04d83ca25ce433138bf64df6a14dc58 Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Thu, 30 Jul 2015 08:20:52 -0700 Subject: [PATCH 20/50] [SPARK-] [MLLIB] minor fix on tokenizer doc A trivial fix for the comments of RegexTokenizer. Maybe this is too small, yet I just noticed it and think it can be quite misleading. I can create a jira if necessary. Author: Yuhao Yang Closes #7791 from hhbyyh/docFix and squashes the following commits: cdf2542 [Yuhao Yang] minor fix on tokenizer doc --- .../src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala index 0b3af4747e693..248288ca73e99 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala @@ -50,7 +50,7 @@ class Tokenizer(override val uid: String) extends UnaryTransformer[String, Seq[S /** * :: Experimental :: * A regex based tokenizer that extracts tokens either by using the provided regex pattern to split - * the text (default) or repeatedly matching the regex (if `gaps` is true). + * the text (default) or repeatedly matching the regex (if `gaps` is false). * Optional parameters also allow filtering tokens using a minimal length. * It returns an array of strings that can be empty. */ From a6e53a9c8b24326d1b6dca7a0e36ce6c643daa77 Mon Sep 17 00:00:00 2001 From: Meihua Wu Date: Thu, 30 Jul 2015 08:52:01 -0700 Subject: [PATCH 21/50] [SPARK-9225] [MLLIB] LDASuite needs unit tests for empty documents Add unit tests for running LDA with empty documents. Both EMLDAOptimizer and OnlineLDAOptimizer are tested. 
feynmanliang Author: Meihua Wu Closes #7620 from rotationsymmetry/SPARK-9225 and squashes the following commits: 3ed7c88 [Meihua Wu] Incorporate reviewer's further comments f9432e8 [Meihua Wu] Incorporate reviewer's comments 8e1b9ec [Meihua Wu] Merge remote-tracking branch 'upstream/master' into SPARK-9225 ad55665 [Meihua Wu] Add unit tests for running LDA with empty documents --- .../spark/mllib/clustering/LDASuite.scala | 40 +++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala index b91c7cefed22e..61d2edfd9fb5f 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala @@ -390,6 +390,46 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext { } } + test("EMLDAOptimizer with empty docs") { + val vocabSize = 6 + val emptyDocsArray = Array.fill(6)(Vectors.sparse(vocabSize, Array.empty, Array.empty)) + val emptyDocs = emptyDocsArray + .zipWithIndex.map { case (wordCounts, docId) => + (docId.toLong, wordCounts) + } + val distributedEmptyDocs = sc.parallelize(emptyDocs, 2) + + val op = new EMLDAOptimizer() + val lda = new LDA() + .setK(3) + .setMaxIterations(5) + .setSeed(12345) + .setOptimizer(op) + + val model = lda.run(distributedEmptyDocs) + assert(model.vocabSize === vocabSize) + } + + test("OnlineLDAOptimizer with empty docs") { + val vocabSize = 6 + val emptyDocsArray = Array.fill(6)(Vectors.sparse(vocabSize, Array.empty, Array.empty)) + val emptyDocs = emptyDocsArray + .zipWithIndex.map { case (wordCounts, docId) => + (docId.toLong, wordCounts) + } + val distributedEmptyDocs = sc.parallelize(emptyDocs, 2) + + val op = new OnlineLDAOptimizer() + val lda = new LDA() + .setK(3) + .setMaxIterations(5) + .setSeed(12345) + .setOptimizer(op) + + val model = lda.run(distributedEmptyDocs) + assert(model.vocabSize === vocabSize) + } + } private[clustering] object LDASuite { From ed3cb1d21c73645c8f6e6ee08181f876fc192e41 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Thu, 30 Jul 2015 09:19:55 -0700 Subject: [PATCH 22/50] [SPARK-9277] [MLLIB] SparseVector constructor must throw an error when declared number of elements less than array length Check that SparseVector size is at least as big as the number of indices/values provided. And add tests for constructor checks. CC MechCoder jkbradley -- I am not sure if a change needs to also happen in the Python API? I didn't see it had any similar checks to begin with, but I don't know it well. Author: Sean Owen Closes #7794 from srowen/SPARK-9277 and squashes the following commits: e8dc31e [Sean Owen] Fix scalastyle 6ffe34a [Sean Owen] Check that SparseVector size is at least as big as the number of indices/values provided. And add tests for constructor checks. 
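A quick illustration of the new size check; the failing call below is taken from the added test, while the passing one is illustrative:

    import org.apache.spark.mllib.linalg.Vectors

    Vectors.sparse(3, Array(1, 2), Array(3.0, 5.0))                  // OK: 2 entries fit in size 3
    Vectors.sparse(3, Array(1, 2, 3, 4), Array(3.0, 5.0, 7.0, 9.0))  // now throws IllegalArgumentException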
--- .../org/apache/spark/mllib/linalg/Vectors.scala | 2 ++ .../apache/spark/mllib/linalg/VectorsSuite.scala | 15 +++++++++++++++ 2 files changed, 17 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index 0cb28d78bec05..23c2c16d68d9a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -637,6 +637,8 @@ class SparseVector( require(indices.length == values.length, "Sparse vectors require that the dimension of the" + s" indices match the dimension of the values. You provided ${indices.length} indices and " + s" ${values.length} values.") + require(indices.length <= size, s"You provided ${indices.length} indices and values, " + + s"which exceeds the specified vector size ${size}.") override def toString: String = s"($size,${indices.mkString("[", ",", "]")},${values.mkString("[", ",", "]")})" diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala index 03be4119bdaca..1c37ea5123e82 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala @@ -57,6 +57,21 @@ class VectorsSuite extends SparkFunSuite with Logging { assert(vec.values === values) } + test("sparse vector construction with mismatched indices/values array") { + intercept[IllegalArgumentException] { + Vectors.sparse(4, Array(1, 2, 3), Array(3.0, 5.0, 7.0, 9.0)) + } + intercept[IllegalArgumentException] { + Vectors.sparse(4, Array(1, 2, 3), Array(3.0, 5.0)) + } + } + + test("sparse vector construction with too many indices vs size") { + intercept[IllegalArgumentException] { + Vectors.sparse(3, Array(1, 2, 3, 4), Array(3.0, 5.0, 7.0, 9.0)) + } + } + test("dense to array") { val vec = Vectors.dense(arr).asInstanceOf[DenseVector] assert(vec.toArray.eq(arr)) From 81464f2a8243c6ae2a39bac7ebdc50d4f60af451 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Thu, 30 Jul 2015 09:45:17 -0700 Subject: [PATCH 23/50] [MINOR] [MLLIB] fix doc for RegexTokenizer This is #7791 for Python. hhbyyh Author: Xiangrui Meng Closes #7798 from mengxr/regex-tok-py and squashes the following commits: baa2dcd [Xiangrui Meng] fix doc for RegexTokenizer --- python/pyspark/ml/feature.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 86e654dd0779f..015e7a9d4900a 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -525,7 +525,7 @@ class RegexTokenizer(JavaTransformer, HasInputCol, HasOutputCol): """ A regex based tokenizer that extracts tokens either by using the provided regex pattern (in Java dialect) to split the text - (default) or repeatedly matching the regex (if gaps is true). + (default) or repeatedly matching the regex (if gaps is false). Optional parameters also allow filtering tokens using a minimal length. It returns an array of strings that can be empty. 
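To make the corrected wording concrete, a hedged Scala sketch of the same `gaps` semantics (not part of this patch, and assuming the standard setGaps/setPattern/setInputCol/setOutputCol setters on the Scala RegexTokenizer, which are not shown in this diff):

    import org.apache.spark.ml.feature.RegexTokenizer

    // gaps = true (default): the pattern is a delimiter; tokens are what lies between matches
    val bySplitting = new RegexTokenizer().setInputCol("text").setOutputCol("words").setPattern("\\s+")
    // gaps = false: the pattern describes the tokens themselves; tokens are the matches
    val byMatching = new RegexTokenizer().setInputCol("text").setOutputCol("words").setGaps(false).setPattern("\\w+")

For plain whitespace-separated text the two configurations produce the same tokens; they differ once punctuation or other separators appear.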
From 7492a33fdd074446c30c657d771a69932a00246d Mon Sep 17 00:00:00 2001 From: Yuu ISHIKAWA Date: Thu, 30 Jul 2015 10:00:27 -0700 Subject: [PATCH 24/50] [SPARK-9248] [SPARKR] Closing curly-braces should always be on their own line ### JIRA [[SPARK-9248] Closing curly-braces should always be on their own line - ASF JIRA](https://issues.apache.org/jira/browse/SPARK-9248) ## The result of `dev/lint-r` [The result of `dev/lint-r` for SPARK-9248 at the revistion:6175d6cfe795fbd88e3ee713fac375038a3993a8](https://gist.github.com/yu-iskw/96cadcea4ce664c41f81) Author: Yuu ISHIKAWA Closes #7795 from yu-iskw/SPARK-9248 and squashes the following commits: c8eccd3 [Yuu ISHIKAWA] [SPARK-9248][SparkR] Closing curly-braces should always be on their own line --- R/pkg/R/generics.R | 14 +++++++------- R/pkg/R/pairRDD.R | 4 ++-- R/pkg/R/sparkR.R | 9 ++++++--- R/pkg/inst/tests/test_sparkSQL.R | 6 ++++-- 4 files changed, 19 insertions(+), 14 deletions(-) diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 836e0175c391f..a3a121058e165 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -254,8 +254,10 @@ setGeneric("flatMapValues", function(X, FUN) { standardGeneric("flatMapValues") # @rdname intersection # @export -setGeneric("intersection", function(x, other, numPartitions = 1) { - standardGeneric("intersection") }) +setGeneric("intersection", + function(x, other, numPartitions = 1) { + standardGeneric("intersection") + }) # @rdname keys # @export @@ -489,9 +491,7 @@ setGeneric("sample", #' @rdname sample #' @export setGeneric("sample_frac", - function(x, withReplacement, fraction, seed) { - standardGeneric("sample_frac") - }) + function(x, withReplacement, fraction, seed) { standardGeneric("sample_frac") }) #' @rdname saveAsParquetFile #' @export @@ -553,8 +553,8 @@ setGeneric("withColumn", function(x, colName, col) { standardGeneric("withColumn #' @rdname withColumnRenamed #' @export -setGeneric("withColumnRenamed", function(x, existingCol, newCol) { - standardGeneric("withColumnRenamed") }) +setGeneric("withColumnRenamed", + function(x, existingCol, newCol) { standardGeneric("withColumnRenamed") }) ###################### Column Methods ########################## diff --git a/R/pkg/R/pairRDD.R b/R/pkg/R/pairRDD.R index ebc6ff65e9d0f..83801d3209700 100644 --- a/R/pkg/R/pairRDD.R +++ b/R/pkg/R/pairRDD.R @@ -202,8 +202,8 @@ setMethod("partitionBy", packageNamesArr <- serialize(.sparkREnv$.packages, connection = NULL) - broadcastArr <- lapply(ls(.broadcastNames), function(name) { - get(name, .broadcastNames) }) + broadcastArr <- lapply(ls(.broadcastNames), + function(name) { get(name, .broadcastNames) }) jrdd <- getJRDD(x) # We create a PairwiseRRDD that extends RDD[(Int, Array[Byte])], diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R index 76c15875b50d5..e83104f116422 100644 --- a/R/pkg/R/sparkR.R +++ b/R/pkg/R/sparkR.R @@ -22,7 +22,8 @@ connExists <- function(env) { tryCatch({ exists(".sparkRCon", envir = env) && isOpen(env[[".sparkRCon"]]) - }, error = function(err) { + }, + error = function(err) { return(FALSE) }) } @@ -153,7 +154,8 @@ sparkR.init <- function( .sparkREnv$backendPort <- backendPort tryCatch({ connectBackend("localhost", backendPort) - }, error = function(err) { + }, + error = function(err) { stop("Failed to connect JVM\n") }) @@ -264,7 +266,8 @@ sparkRHive.init <- function(jsc = NULL) { ssc <- callJMethod(sc, "sc") hiveCtx <- tryCatch({ newJObject("org.apache.spark.sql.hive.HiveContext", ssc) - }, error = function(err) { + }, + error = function(err) { stop("Spark SQL is not built 
with Hive support") }) diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 62fe48a5d6c7b..d5db97248c770 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -112,7 +112,8 @@ test_that("create DataFrame from RDD", { df <- jsonFile(sqlContext, jsonPathNa) hiveCtx <- tryCatch({ newJObject("org.apache.spark.sql.hive.test.TestHiveContext", ssc) - }, error = function(err) { + }, + error = function(err) { skip("Hive is not build with SparkSQL, skipped") }) sql(hiveCtx, "CREATE TABLE people (name string, age double, height float)") @@ -602,7 +603,8 @@ test_that("write.df() as parquet file", { test_that("test HiveContext", { hiveCtx <- tryCatch({ newJObject("org.apache.spark.sql.hive.test.TestHiveContext", ssc) - }, error = function(err) { + }, + error = function(err) { skip("Hive is not build with SparkSQL, skipped") }) df <- createExternalTable(hiveCtx, "json", jsonPath, "json") From c0cc0eaec67208c087a30c1b1f50c00b2c1ebf08 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 30 Jul 2015 10:04:30 -0700 Subject: [PATCH 25/50] [SPARK-9390][SQL] create a wrapper for array type Author: Wenchen Fan Closes #7724 from cloud-fan/array-data and squashes the following commits: d0408a1 [Wenchen Fan] fix python 661e608 [Wenchen Fan] rebase f39256c [Wenchen Fan] fix hive... 6dbfa6f [Wenchen Fan] fix hive again... 8cb8842 [Wenchen Fan] remove element type parameter from getArray 43e9816 [Wenchen Fan] fix mllib e719afc [Wenchen Fan] fix hive 4346290 [Wenchen Fan] address comment d4a38da [Wenchen Fan] remove sizeInBytes and add license 7e283e2 [Wenchen Fan] create a wrapper for array type --- .../apache/spark/mllib/linalg/Matrices.scala | 16 +-- .../apache/spark/mllib/linalg/Vectors.scala | 15 +-- .../expressions/SpecializedGetters.java | 2 + .../sql/catalyst/CatalystTypeConverters.scala | 29 +++-- .../spark/sql/catalyst/InternalRow.scala | 2 + .../catalyst/expressions/BoundAttribute.scala | 2 +- .../spark/sql/catalyst/expressions/Cast.scala | 39 ++++-- .../expressions/codegen/CodeGenerator.scala | 28 ++-- .../codegen/GenerateUnsafeProjection.scala | 4 +- .../expressions/collectionOperations.scala | 10 +- .../expressions/complexTypeCreator.scala | 20 ++- .../expressions/complexTypeExtractors.scala | 59 ++++++--- .../sql/catalyst/expressions/generators.scala | 4 +- .../expressions/stringOperations.scala | 12 +- .../sql/catalyst/optimizer/Optimizer.scala | 3 +- .../apache/spark/sql/types/ArrayData.scala | 121 ++++++++++++++++++ .../spark/sql/types/GenericArrayData.scala | 59 +++++++++ .../sql/catalyst/expressions/CastSuite.scala | 21 ++- .../expressions/ComplexTypeSuite.scala | 2 +- .../spark/sql/execution/debug/package.scala | 4 +- .../spark/sql/execution/pythonUDFs.scala | 19 ++- .../sql/execution/stat/FrequentItems.scala | 4 +- .../apache/spark/sql/json/InferSchema.scala | 2 +- .../apache/spark/sql/json/JacksonParser.scala | 30 +++-- .../sql/parquet/CatalystRowConverter.scala | 2 +- .../spark/sql/parquet/ParquetConverter.scala | 3 +- .../sql/parquet/ParquetTableSupport.scala | 12 +- .../apache/spark/sql/JavaDataFrameSuite.java | 5 +- .../spark/sql/UserDefinedTypeSuite.scala | 8 +- .../spark/sql/sources/TableScanSuite.scala | 30 ++--- .../spark/sql/hive/HiveInspectors.scala | 28 ++-- .../hive/execution/ScriptTransformation.scala | 12 +- .../org/apache/spark/sql/hive/hiveUDFs.scala | 2 +- .../spark/sql/hive/HiveInspectorSuite.scala | 2 +- 34 files changed, 430 insertions(+), 181 deletions(-) create mode 100644 
sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayData.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala index d82ba2456df1a..88914fa875990 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala @@ -154,9 +154,9 @@ private[spark] class MatrixUDT extends UserDefinedType[Matrix] { row.setByte(0, 0) row.setInt(1, sm.numRows) row.setInt(2, sm.numCols) - row.update(3, sm.colPtrs.toSeq) - row.update(4, sm.rowIndices.toSeq) - row.update(5, sm.values.toSeq) + row.update(3, new GenericArrayData(sm.colPtrs.map(_.asInstanceOf[Any]))) + row.update(4, new GenericArrayData(sm.rowIndices.map(_.asInstanceOf[Any]))) + row.update(5, new GenericArrayData(sm.values.map(_.asInstanceOf[Any]))) row.setBoolean(6, sm.isTransposed) case dm: DenseMatrix => @@ -165,7 +165,7 @@ private[spark] class MatrixUDT extends UserDefinedType[Matrix] { row.setInt(2, dm.numCols) row.setNullAt(3) row.setNullAt(4) - row.update(5, dm.values.toSeq) + row.update(5, new GenericArrayData(dm.values.map(_.asInstanceOf[Any]))) row.setBoolean(6, dm.isTransposed) } row @@ -179,14 +179,12 @@ private[spark] class MatrixUDT extends UserDefinedType[Matrix] { val tpe = row.getByte(0) val numRows = row.getInt(1) val numCols = row.getInt(2) - val values = row.getAs[Seq[Double]](5, ArrayType(DoubleType, containsNull = false)).toArray + val values = row.getArray(5).toArray.map(_.asInstanceOf[Double]) val isTransposed = row.getBoolean(6) tpe match { case 0 => - val colPtrs = - row.getAs[Seq[Int]](3, ArrayType(IntegerType, containsNull = false)).toArray - val rowIndices = - row.getAs[Seq[Int]](4, ArrayType(IntegerType, containsNull = false)).toArray + val colPtrs = row.getArray(3).toArray.map(_.asInstanceOf[Int]) + val rowIndices = row.getArray(4).toArray.map(_.asInstanceOf[Int]) new SparseMatrix(numRows, numCols, colPtrs, rowIndices, values, isTransposed) case 1 => new DenseMatrix(numRows, numCols, values, isTransposed) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index 23c2c16d68d9a..89a1818db0d1d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -187,15 +187,15 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] { val row = new GenericMutableRow(4) row.setByte(0, 0) row.setInt(1, size) - row.update(2, indices.toSeq) - row.update(3, values.toSeq) + row.update(2, new GenericArrayData(indices.map(_.asInstanceOf[Any]))) + row.update(3, new GenericArrayData(values.map(_.asInstanceOf[Any]))) row case DenseVector(values) => val row = new GenericMutableRow(4) row.setByte(0, 1) row.setNullAt(1) row.setNullAt(2) - row.update(3, values.toSeq) + row.update(3, new GenericArrayData(values.map(_.asInstanceOf[Any]))) row } } @@ -209,14 +209,11 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] { tpe match { case 0 => val size = row.getInt(1) - val indices = - row.getAs[Seq[Int]](2, ArrayType(IntegerType, containsNull = false)).toArray - val values = - row.getAs[Seq[Double]](3, ArrayType(DoubleType, containsNull = false)).toArray + val indices = row.getArray(2).toArray().map(_.asInstanceOf[Int]) + val values = 
row.getArray(3).toArray().map(_.asInstanceOf[Double]) new SparseVector(size, indices, values) case 1 => - val values = - row.getAs[Seq[Double]](3, ArrayType(DoubleType, containsNull = false)).toArray + val values = row.getArray(3).toArray().map(_.asInstanceOf[Double]) new DenseVector(values) } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGetters.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGetters.java index bc345dcd00e49..f7cea13688876 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGetters.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGetters.java @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions; import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.types.ArrayData; import org.apache.spark.sql.types.Decimal; import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; @@ -50,4 +51,5 @@ public interface SpecializedGetters { InternalRow getStruct(int ordinal, int numFields); + ArrayData getArray(int ordinal); } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala index d1d89a1f48329..22452c0f201ef 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala @@ -55,7 +55,6 @@ object CatalystTypeConverters { private def isWholePrimitive(dt: DataType): Boolean = dt match { case dt if isPrimitive(dt) => true - case ArrayType(elementType, _) => isWholePrimitive(elementType) case MapType(keyType, valueType, _) => isWholePrimitive(keyType) && isWholePrimitive(valueType) case _ => false } @@ -154,39 +153,41 @@ object CatalystTypeConverters { /** Converter for arrays, sequences, and Java iterables. 
*/ private case class ArrayConverter( - elementType: DataType) extends CatalystTypeConverter[Any, Seq[Any], Seq[Any]] { + elementType: DataType) extends CatalystTypeConverter[Any, Seq[Any], ArrayData] { private[this] val elementConverter = getConverterForType(elementType) private[this] val isNoChange = isWholePrimitive(elementType) - override def toCatalystImpl(scalaValue: Any): Seq[Any] = { + override def toCatalystImpl(scalaValue: Any): ArrayData = { scalaValue match { - case a: Array[_] => a.toSeq.map(elementConverter.toCatalyst) - case s: Seq[_] => s.map(elementConverter.toCatalyst) + case a: Array[_] => + new GenericArrayData(a.map(elementConverter.toCatalyst)) + case s: Seq[_] => + new GenericArrayData(s.map(elementConverter.toCatalyst).toArray) case i: JavaIterable[_] => val iter = i.iterator - var convertedIterable: List[Any] = List() + val convertedIterable = scala.collection.mutable.ArrayBuffer.empty[Any] while (iter.hasNext) { val item = iter.next() - convertedIterable :+= elementConverter.toCatalyst(item) + convertedIterable += elementConverter.toCatalyst(item) } - convertedIterable + new GenericArrayData(convertedIterable.toArray) } } - override def toScala(catalystValue: Seq[Any]): Seq[Any] = { + override def toScala(catalystValue: ArrayData): Seq[Any] = { if (catalystValue == null) { null } else if (isNoChange) { - catalystValue + catalystValue.toArray() } else { - catalystValue.map(elementConverter.toScala) + catalystValue.toArray().map(elementConverter.toScala) } } override def toScalaImpl(row: InternalRow, column: Int): Seq[Any] = - toScala(row.get(column, ArrayType(elementType)).asInstanceOf[Seq[Any]]) + toScala(row.getArray(column)) } private case class MapConverter( @@ -402,9 +403,9 @@ object CatalystTypeConverters { case t: Timestamp => TimestampConverter.toCatalyst(t) case d: BigDecimal => BigDecimalConverter.toCatalyst(d) case d: JavaBigDecimal => BigDecimalConverter.toCatalyst(d) - case seq: Seq[Any] => seq.map(convertToCatalyst) + case seq: Seq[Any] => new GenericArrayData(seq.map(convertToCatalyst).toArray) case r: Row => InternalRow(r.toSeq.map(convertToCatalyst): _*) - case arr: Array[Any] => arr.map(convertToCatalyst) + case arr: Array[Any] => new GenericArrayData(arr.map(convertToCatalyst)) case m: Map[_, _] => m.map { case (k, v) => (convertToCatalyst(k), convertToCatalyst(v)) }.toMap case other => other diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala index a5999e64ec554..486ba036548c8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala @@ -76,6 +76,8 @@ abstract class InternalRow extends Serializable with SpecializedGetters { override def getStruct(ordinal: Int, numFields: Int): InternalRow = getAs[InternalRow](ordinal, null) + override def getArray(ordinal: Int): ArrayData = getAs(ordinal, null) + override def toString: String = s"[${this.mkString(",")}]" /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index 371681b5d494f..45709c1c8f554 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -65,7 +65,7 @@ case class 
BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val javaType = ctx.javaType(dataType) - val value = ctx.getColumn("i", dataType, ordinal) + val value = ctx.getValue("i", dataType, ordinal.toString) s""" boolean ${ev.isNull} = i.isNullAt($ordinal); $javaType ${ev.primitive} = ${ev.isNull} ? ${ctx.defaultValue(dataType)} : ($value); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 8c01c13c9ccd5..43be11c48ae7c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -363,7 +363,21 @@ case class Cast(child: Expression, dataType: DataType) private[this] def castArray(from: ArrayType, to: ArrayType): Any => Any = { val elementCast = cast(from.elementType, to.elementType) - buildCast[Seq[Any]](_, _.map(v => if (v == null) null else elementCast(v))) + // TODO: Could be faster? + buildCast[ArrayData](_, array => { + val length = array.numElements() + val values = new Array[Any](length) + var i = 0 + while (i < length) { + if (array.isNullAt(i)) { + values(i) = null + } else { + values(i) = elementCast(array.get(i)) + } + i += 1 + } + new GenericArrayData(values) + }) } private[this] def castMap(from: MapType, to: MapType): Any => Any = { @@ -789,37 +803,36 @@ case class Cast(child: Expression, dataType: DataType) private[this] def castArrayCode( from: ArrayType, to: ArrayType, ctx: CodeGenContext): CastFunction = { val elementCast = nullSafeCastFunction(from.elementType, to.elementType, ctx) - - val arraySeqClass = classOf[mutable.ArraySeq[Any]].getName + val arrayClass = classOf[GenericArrayData].getName val fromElementNull = ctx.freshName("feNull") val fromElementPrim = ctx.freshName("fePrim") val toElementNull = ctx.freshName("teNull") val toElementPrim = ctx.freshName("tePrim") val size = ctx.freshName("n") val j = ctx.freshName("j") - val result = ctx.freshName("result") + val values = ctx.freshName("values") (c, evPrim, evNull) => s""" - final int $size = $c.size(); - final $arraySeqClass $result = new $arraySeqClass($size); + final int $size = $c.numElements(); + final Object[] $values = new Object[$size]; for (int $j = 0; $j < $size; $j ++) { - if ($c.apply($j) == null) { - $result.update($j, null); + if ($c.isNullAt($j)) { + $values[$j] = null; } else { boolean $fromElementNull = false; ${ctx.javaType(from.elementType)} $fromElementPrim = - (${ctx.boxedType(from.elementType)}) $c.apply($j); + ${ctx.getValue(c, from.elementType, j)}; ${castCode(ctx, fromElementPrim, fromElementNull, toElementPrim, toElementNull, to.elementType, elementCast)} if ($toElementNull) { - $result.update($j, null); + $values[$j] = null; } else { - $result.update($j, $toElementPrim); + $values[$j] = $toElementPrim; } } } - $evPrim = $result; + $evPrim = new $arrayClass($values); """ } @@ -891,7 +904,7 @@ case class Cast(child: Expression, dataType: DataType) $result.setNullAt($i); } else { $fromType $fromFieldPrim = - ${ctx.getColumn(tmpRow, from.fields(i).dataType, i)}; + ${ctx.getValue(tmpRow, from.fields(i).dataType, i.toString)}; ${castCode(ctx, fromFieldPrim, fromFieldNull, toFieldPrim, toFieldNull, to.fields(i).dataType, cast)} if ($toFieldNull) { diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 092f4c9fb0bd2..c39e0df6fae2a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -100,17 +100,18 @@ class CodeGenContext { } /** - * Returns the code to access a column in Row for a given DataType. + * Returns the code to access a value in `SpecializedGetters` for a given DataType. */ - def getColumn(row: String, dataType: DataType, ordinal: Int): String = { + def getValue(getter: String, dataType: DataType, ordinal: String): String = { val jt = javaType(dataType) dataType match { - case _ if isPrimitiveType(jt) => s"$row.get${primitiveTypeName(jt)}($ordinal)" - case StringType => s"$row.getUTF8String($ordinal)" - case BinaryType => s"$row.getBinary($ordinal)" - case CalendarIntervalType => s"$row.getInterval($ordinal)" - case t: StructType => s"$row.getStruct($ordinal, ${t.size})" - case _ => s"($jt)$row.get($ordinal)" + case _ if isPrimitiveType(jt) => s"$getter.get${primitiveTypeName(jt)}($ordinal)" + case StringType => s"$getter.getUTF8String($ordinal)" + case BinaryType => s"$getter.getBinary($ordinal)" + case CalendarIntervalType => s"$getter.getInterval($ordinal)" + case t: StructType => s"$getter.getStruct($ordinal, ${t.size})" + case a: ArrayType => s"$getter.getArray($ordinal)" + case _ => s"($jt)$getter.get($ordinal)" // todo: remove generic getter. } } @@ -152,8 +153,8 @@ class CodeGenContext { case StringType => "UTF8String" case CalendarIntervalType => "CalendarInterval" case _: StructType => "InternalRow" - case _: ArrayType => s"scala.collection.Seq" - case _: MapType => s"scala.collection.Map" + case _: ArrayType => "ArrayData" + case _: MapType => "scala.collection.Map" case dt: OpenHashSetUDT if dt.elementType == IntegerType => classOf[IntegerHashSet].getName case dt: OpenHashSetUDT if dt.elementType == LongType => classOf[LongHashSet].getName case _ => "Object" @@ -214,7 +215,9 @@ class CodeGenContext { case dt: DataType if isPrimitiveType(dt) => s"($c1 > $c2 ? 1 : $c1 < $c2 ? 
-1 : 0)" case BinaryType => s"org.apache.spark.sql.catalyst.util.TypeUtils.compareBinary($c1, $c2)" case NullType => "0" - case other => s"$c1.compare($c2)" + case other if other.isInstanceOf[AtomicType] => s"$c1.compare($c2)" + case _ => throw new IllegalArgumentException( + "cannot generate compare code for un-comparable type") } /** @@ -293,7 +296,8 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin classOf[UnsafeRow].getName, classOf[UTF8String].getName, classOf[Decimal].getName, - classOf[CalendarInterval].getName + classOf[CalendarInterval].getName, + classOf[ArrayData].getName )) evaluator.setExtendedClass(classOf[GeneratedClass]) try { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 7be60114ce674..a662357fb6cf9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -153,14 +153,14 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val nestedStructEv = GeneratedExpressionCode( code = "", isNull = s"${input.primitive}.isNullAt($i)", - primitive = s"${ctx.getColumn(input.primitive, dt, i)}" + primitive = s"${ctx.getValue(input.primitive, dt, i.toString)}" ) createCodeForStruct(ctx, nestedStructEv, st) case _ => GeneratedExpressionCode( code = "", isNull = s"${input.primitive}.isNullAt($i)", - primitive = s"${ctx.getColumn(input.primitive, dt, i)}" + primitive = s"${ctx.getValue(input.primitive, dt, i.toString)}" ) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 2d92dcf23a86e..1a00dbc254de1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -27,11 +27,15 @@ case class Size(child: Expression) extends UnaryExpression with ExpectsInputType override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(ArrayType, MapType)) override def nullSafeEval(value: Any): Int = child.dataType match { - case ArrayType(_, _) => value.asInstanceOf[Seq[Any]].size - case MapType(_, _, _) => value.asInstanceOf[Map[Any, Any]].size + case _: ArrayType => value.asInstanceOf[ArrayData].numElements() + case _: MapType => value.asInstanceOf[Map[Any, Any]].size } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - nullSafeCodeGen(ctx, ev, c => s"${ev.primitive} = ($c).size();") + val sizeCall = child.dataType match { + case _: ArrayType => "numElements()" + case _: MapType => "size()" + } + nullSafeCodeGen(ctx, ev, c => s"${ev.primitive} = ($c).$sizeCall;") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 0517050a45109..a145dfb4bbf08 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -18,12 
+18,9 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.unsafe.types.UTF8String - -import scala.collection.mutable - import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, GeneratedExpressionCode, CodeGenContext} +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ @@ -46,25 +43,26 @@ case class CreateArray(children: Seq[Expression]) extends Expression { override def nullable: Boolean = false override def eval(input: InternalRow): Any = { - children.map(_.eval(input)) + new GenericArrayData(children.map(_.eval(input)).toArray) } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val arraySeqClass = classOf[mutable.ArraySeq[Any]].getName + val arrayClass = classOf[GenericArrayData].getName s""" - boolean ${ev.isNull} = false; - $arraySeqClass ${ev.primitive} = new $arraySeqClass(${children.size}); + final boolean ${ev.isNull} = false; + final Object[] values = new Object[${children.size}]; """ + children.zipWithIndex.map { case (e, i) => val eval = e.gen(ctx) eval.code + s""" if (${eval.isNull}) { - ${ev.primitive}.update($i, null); + values[$i] = null; } else { - ${ev.primitive}.update($i, ${eval.primitive}); + values[$i] = ${eval.primitive}; } """ - }.mkString("\n") + }.mkString("\n") + + s"final ${ctx.javaType(dataType)} ${ev.primitive} = new $arrayClass(values);" } override def prettyName: String = "array" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 6331a9eb603ca..99393c9c76ab6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -57,7 +57,8 @@ object ExtractValue { case (ArrayType(StructType(fields), containsNull), NonNullLiteral(v, StringType)) => val fieldName = v.toString val ordinal = findField(fields, fieldName, resolver) - GetArrayStructFields(child, fields(ordinal).copy(name = fieldName), ordinal, containsNull) + GetArrayStructFields(child, fields(ordinal).copy(name = fieldName), + ordinal, fields.length, containsNull) case (_: ArrayType, _) if extraction.dataType.isInstanceOf[IntegralType] => GetArrayItem(child, extraction) @@ -118,7 +119,7 @@ case class GetStructField(child: Expression, field: StructField, ordinal: Int) if ($eval.isNullAt($ordinal)) { ${ev.isNull} = true; } else { - ${ev.primitive} = ${ctx.getColumn(eval, dataType, ordinal)}; + ${ev.primitive} = ${ctx.getValue(eval, dataType, ordinal.toString)}; } """ }) @@ -134,6 +135,7 @@ case class GetArrayStructFields( child: Expression, field: StructField, ordinal: Int, + numFields: Int, containsNull: Boolean) extends UnaryExpression { override def dataType: DataType = ArrayType(field.dataType, containsNull) @@ -141,26 +143,45 @@ case class GetArrayStructFields( override def toString: String = s"$child.${field.name}" protected override def nullSafeEval(input: Any): Any = { - input.asInstanceOf[Seq[InternalRow]].map { row => - if (row == null) null else row.get(ordinal, field.dataType) + val array = input.asInstanceOf[ArrayData] + val length = array.numElements() + val result = new Array[Any](length) 
+ var i = 0 + while (i < length) { + if (array.isNullAt(i)) { + result(i) = null + } else { + val row = array.getStruct(i, numFields) + if (row.isNullAt(ordinal)) { + result(i) = null + } else { + result(i) = row.get(ordinal, field.dataType) + } + } + i += 1 } + new GenericArrayData(result) } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val arraySeqClass = "scala.collection.mutable.ArraySeq" - // TODO: consider using Array[_] for ArrayType child to avoid - // boxing of primitives + val arrayClass = classOf[GenericArrayData].getName nullSafeCodeGen(ctx, ev, eval => { s""" - final int n = $eval.size(); - final $arraySeqClass values = new $arraySeqClass(n); + final int n = $eval.numElements(); + final Object[] values = new Object[n]; for (int j = 0; j < n; j++) { - InternalRow row = (InternalRow) $eval.apply(j); - if (row != null && !row.isNullAt($ordinal)) { - values.update(j, ${ctx.getColumn("row", field.dataType, ordinal)}); + if ($eval.isNullAt(j)) { + values[j] = null; + } else { + final InternalRow row = $eval.getStruct(j, $numFields); + if (row.isNullAt($ordinal)) { + values[j] = null; + } else { + values[j] = ${ctx.getValue("row", field.dataType, ordinal.toString)}; + } } } - ${ev.primitive} = (${ctx.javaType(dataType)}) values; + ${ev.primitive} = new $arrayClass(values); """ }) } @@ -186,23 +207,23 @@ case class GetArrayItem(child: Expression, ordinal: Expression) extends BinaryEx protected override def nullSafeEval(value: Any, ordinal: Any): Any = { // TODO: consider using Array[_] for ArrayType child to avoid // boxing of primitives - val baseValue = value.asInstanceOf[Seq[_]] + val baseValue = value.asInstanceOf[ArrayData] val index = ordinal.asInstanceOf[Number].intValue() - if (index >= baseValue.size || index < 0) { + if (index >= baseValue.numElements() || index < 0) { null } else { - baseValue(index) + baseValue.get(index) } } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { nullSafeCodeGen(ctx, ev, (eval1, eval2) => { s""" - final int index = (int)$eval2; - if (index >= $eval1.size() || index < 0) { + final int index = (int) $eval2; + if (index >= $eval1.numElements() || index < 0) { ${ev.isNull} = true; } else { - ${ev.primitive} = (${ctx.boxedType(dataType)})$eval1.apply(index); + ${ev.primitive} = ${ctx.getValue(eval1, dataType, "index")}; } """ }) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index 2dbcf2830f876..8064235c64ef9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -121,8 +121,8 @@ case class Explode(child: Expression) extends UnaryExpression with Generator wit override def eval(input: InternalRow): TraversableOnce[InternalRow] = { child.dataType match { case ArrayType(_, _) => - val inputArray = child.eval(input).asInstanceOf[Seq[Any]] - if (inputArray == null) Nil else inputArray.map(v => InternalRow(v)) + val inputArray = child.eval(input).asInstanceOf[ArrayData] + if (inputArray == null) Nil else inputArray.toArray().map(v => InternalRow(v)) case MapType(_, _, _) => val inputMap = child.eval(input).asInstanceOf[Map[Any, Any]] if (inputMap == null) Nil diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 5b3a64a09679c..79c0ca56a8e79 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -92,7 +92,7 @@ case class ConcatWs(children: Seq[Expression]) val flatInputs = children.flatMap { child => child.eval(input) match { case s: UTF8String => Iterator(s) - case arr: Seq[_] => arr.asInstanceOf[Seq[UTF8String]] + case arr: ArrayData => arr.toArray().map(_.asInstanceOf[UTF8String]) case null => Iterator(null.asInstanceOf[UTF8String]) } } @@ -105,7 +105,7 @@ case class ConcatWs(children: Seq[Expression]) val evals = children.map(_.gen(ctx)) val inputs = evals.map { eval => - s"${eval.isNull} ? (UTF8String)null : ${eval.primitive}" + s"${eval.isNull} ? (UTF8String) null : ${eval.primitive}" }.mkString(", ") evals.map(_.code).mkString("\n") + s""" @@ -665,13 +665,15 @@ case class StringSplit(str: Expression, pattern: Expression) override def inputTypes: Seq[DataType] = Seq(StringType, StringType) override def nullSafeEval(string: Any, regex: Any): Any = { - string.asInstanceOf[UTF8String].split(regex.asInstanceOf[UTF8String], -1).toSeq + val strings = string.asInstanceOf[UTF8String].split(regex.asInstanceOf[UTF8String], -1) + new GenericArrayData(strings.asInstanceOf[Array[Any]]) } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val arrayClass = classOf[GenericArrayData].getName nullSafeCodeGen(ctx, ev, (str, pattern) => - s"""${ev.primitive} = scala.collection.JavaConversions.asScalaBuffer( - java.util.Arrays.asList($str.split($pattern, -1)));""") + // Array in java is covariant, so we don't need to cast UTF8String[] to Object[]. + s"""${ev.primitive} = new $arrayClass($str.split($pattern, -1));""") } override def prettyName: String = "split" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 813c62009666c..29d706dcb39a7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -312,7 +312,8 @@ object NullPropagation extends Rule[LogicalPlan] { case e @ GetMapValue(Literal(null, _), _) => Literal.create(null, e.dataType) case e @ GetMapValue(_, Literal(null, _)) => Literal.create(null, e.dataType) case e @ GetStructField(Literal(null, _), _, _) => Literal.create(null, e.dataType) - case e @ GetArrayStructFields(Literal(null, _), _, _, _) => Literal.create(null, e.dataType) + case e @ GetArrayStructFields(Literal(null, _), _, _, _, _) => + Literal.create(null, e.dataType) case e @ EqualNullSafe(Literal(null, _), r) => IsNull(r) case e @ EqualNullSafe(l, Literal(null, _)) => IsNull(l) case e @ Count(expr) if !expr.nullable => Count(Literal(1)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayData.scala new file mode 100644 index 0000000000000..14a7285877622 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayData.scala @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.types + +import org.apache.spark.sql.catalyst.expressions.SpecializedGetters + +abstract class ArrayData extends SpecializedGetters with Serializable { + // todo: remove this after we handle all types.(map type need special getter) + def get(ordinal: Int): Any + + def numElements(): Int + + // todo: need a more efficient way to iterate array type. + def toArray(): Array[Any] = { + val n = numElements() + val values = new Array[Any](n) + var i = 0 + while (i < n) { + if (isNullAt(i)) { + values(i) = null + } else { + values(i) = get(i) + } + i += 1 + } + values + } + + override def toString(): String = toArray.mkString("[", ",", "]") + + override def equals(o: Any): Boolean = { + if (!o.isInstanceOf[ArrayData]) { + return false + } + + val other = o.asInstanceOf[ArrayData] + if (other eq null) { + return false + } + + val len = numElements() + if (len != other.numElements()) { + return false + } + + var i = 0 + while (i < len) { + if (isNullAt(i) != other.isNullAt(i)) { + return false + } + if (!isNullAt(i)) { + val o1 = get(i) + val o2 = other.get(i) + o1 match { + case b1: Array[Byte] => + if (!o2.isInstanceOf[Array[Byte]] || + !java.util.Arrays.equals(b1, o2.asInstanceOf[Array[Byte]])) { + return false + } + case f1: Float if java.lang.Float.isNaN(f1) => + if (!o2.isInstanceOf[Float] || ! java.lang.Float.isNaN(o2.asInstanceOf[Float])) { + return false + } + case d1: Double if java.lang.Double.isNaN(d1) => + if (!o2.isInstanceOf[Double] || ! java.lang.Double.isNaN(o2.asInstanceOf[Double])) { + return false + } + case _ => if (o1 != o2) { + return false + } + } + } + i += 1 + } + true + } + + override def hashCode: Int = { + var result: Int = 37 + var i = 0 + val len = numElements() + while (i < len) { + val update: Int = + if (isNullAt(i)) { + 0 + } else { + get(i) match { + case b: Boolean => if (b) 0 else 1 + case b: Byte => b.toInt + case s: Short => s.toInt + case i: Int => i + case l: Long => (l ^ (l >>> 32)).toInt + case f: Float => java.lang.Float.floatToIntBits(f) + case d: Double => + val b = java.lang.Double.doubleToLongBits(d) + (b ^ (b >>> 32)).toInt + case a: Array[Byte] => java.util.Arrays.hashCode(a) + case other => other.hashCode() + } + } + result = 37 * result + update + i += 1 + } + result + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala new file mode 100644 index 0000000000000..7992ba947c069 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.types + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.unsafe.types.{UTF8String, CalendarInterval} + +class GenericArrayData(array: Array[Any]) extends ArrayData { + private def getAs[T](ordinal: Int) = get(ordinal).asInstanceOf[T] + + override def toArray(): Array[Any] = array + + override def get(ordinal: Int): Any = array(ordinal) + + override def isNullAt(ordinal: Int): Boolean = get(ordinal) == null + + override def getBoolean(ordinal: Int): Boolean = getAs(ordinal) + + override def getByte(ordinal: Int): Byte = getAs(ordinal) + + override def getShort(ordinal: Int): Short = getAs(ordinal) + + override def getInt(ordinal: Int): Int = getAs(ordinal) + + override def getLong(ordinal: Int): Long = getAs(ordinal) + + override def getFloat(ordinal: Int): Float = getAs(ordinal) + + override def getDouble(ordinal: Int): Double = getAs(ordinal) + + override def getDecimal(ordinal: Int): Decimal = getAs(ordinal) + + override def getUTF8String(ordinal: Int): UTF8String = getAs(ordinal) + + override def getBinary(ordinal: Int): Array[Byte] = getAs(ordinal) + + override def getInterval(ordinal: Int): CalendarInterval = getAs(ordinal) + + override def getStruct(ordinal: Int, numFields: Int): InternalRow = getAs(ordinal) + + override def getArray(ordinal: Int): ArrayData = getAs(ordinal) + + override def numElements(): Int = array.length +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index a517da9872852..4f35b653d73c0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -21,6 +21,7 @@ import java.sql.{Timestamp, Date} import java.util.{TimeZone, Calendar} import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ @@ -730,13 +731,10 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { test("complex casting") { val complex = Literal.create( - InternalRow( - Seq(UTF8String.fromString("123"), UTF8String.fromString("abc"), UTF8String.fromString("")), - Map( - UTF8String.fromString("a") -> UTF8String.fromString("123"), - UTF8String.fromString("b") -> UTF8String.fromString("abc"), - UTF8String.fromString("c") -> UTF8String.fromString("")), - InternalRow(0)), + Row( + Seq("123", "abc", ""), + Map("a" ->"123", "b" -> "abc", "c" -> ""), + Row(0)), StructType(Seq( StructField("a", ArrayType(StringType, containsNull = false), nullable = true), @@ -756,13 +754,10 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { StructField("l", LongType, nullable = true))))))) assert(ret.resolved === true) - 
checkEvaluation(ret, InternalRow( + checkEvaluation(ret, Row( Seq(123, null, null), - Map( - UTF8String.fromString("a") -> true, - UTF8String.fromString("b") -> true, - UTF8String.fromString("c") -> false), - InternalRow(0L))) + Map("a" -> true, "b" -> true, "c" -> false), + Row(0L))) } test("case between string and interval") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index 5de5ddce975d8..3fa246b69d1f1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -110,7 +110,7 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { expr.dataType match { case ArrayType(StructType(fields), containsNull) => val field = fields.find(_.name == fieldName).get - GetArrayStructFields(expr, field, fields.indexOf(field), containsNull) + GetArrayStructFields(expr, field, fields.indexOf(field), fields.length, containsNull) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala index aeeb0e45270dd..f26f41fb75d57 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala @@ -158,8 +158,8 @@ package object debug { case (row: InternalRow, StructType(fields)) => row.toSeq.zip(fields.map(_.dataType)).foreach { case(d, t) => typeCheck(d, t) } - case (s: Seq[_], ArrayType(elemType, _)) => - s.foreach(typeCheck(_, elemType)) + case (a: ArrayData, ArrayType(elemType, _)) => + a.toArray().foreach(typeCheck(_, elemType)) case (m: Map[_, _], MapType(keyType, valueType, _)) => m.keys.foreach(typeCheck(_, keyType)) m.values.foreach(typeCheck(_, valueType)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala index 3c38916fd7504..ef1c6e57dc08a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala @@ -134,8 +134,19 @@ object EvaluatePython { } new GenericInternalRowWithSchema(values, struct) - case (seq: Seq[Any], array: ArrayType) => - seq.map(x => toJava(x, array.elementType)).asJava + case (a: ArrayData, array: ArrayType) => + val length = a.numElements() + val values = new java.util.ArrayList[Any](length) + var i = 0 + while (i < length) { + if (a.isNullAt(i)) { + values.add(null) + } else { + values.add(toJava(a.get(i), array.elementType)) + } + i += 1 + } + values case (obj: Map[_, _], mt: MapType) => obj.map { case (k, v) => (toJava(k, mt.keyType), toJava(v, mt.valueType)) @@ -190,10 +201,10 @@ object EvaluatePython { case (c, BinaryType) if c.getClass.isArray && c.getClass.getComponentType.getName == "byte" => c case (c: java.util.List[_], ArrayType(elementType, _)) => - c.map { e => fromJava(e, elementType)}.toSeq + new GenericArrayData(c.map { e => fromJava(e, elementType)}.toArray) case (c, ArrayType(elementType, _)) if c.getClass.isArray => - c.asInstanceOf[Array[_]].map(e => fromJava(e, elementType)).toSeq + new GenericArrayData(c.asInstanceOf[Array[_]].map(e => fromJava(e, elementType))) case (c: java.util.Map[_, _], MapType(keyType, valueType, _)) => 
c.map { case (key, value) => (fromJava(key, keyType), fromJava(value, valueType)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala index 78da2840dad69..9329148aa233c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala @@ -22,7 +22,7 @@ import scala.collection.mutable.{Map => MutableMap} import org.apache.spark.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.plans.logical.LocalRelation -import org.apache.spark.sql.types.{DataType, ArrayType, StructField, StructType} +import org.apache.spark.sql.types._ import org.apache.spark.sql.{Column, DataFrame} private[sql] object FrequentItems extends Logging { @@ -110,7 +110,7 @@ private[sql] object FrequentItems extends Logging { baseCounts } ) - val justItems = freqItems.map(m => m.baseMap.keys.toSeq) + val justItems = freqItems.map(m => m.baseMap.keys.toArray).map(new GenericArrayData(_)) val resultRow = InternalRow(justItems : _*) // append frequent Items to the column name for easy debugging val outputCols = colInfo.map { v => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala index 0eb3b04007f8d..04ab5e2217882 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala @@ -125,7 +125,7 @@ private[sql] object InferSchema { * Convert NullType to StringType and remove StructTypes with no fields */ private def canonicalizeType: DataType => Option[DataType] = { - case at@ArrayType(elementType, _) => + case at @ ArrayType(elementType, _) => for { canonicalType <- canonicalizeType(elementType) } yield { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala index 381e7ed54428f..1c309f8794ef3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala @@ -110,8 +110,13 @@ private[sql] object JacksonParser { case (START_OBJECT, st: StructType) => convertObject(factory, parser, st) + case (START_ARRAY, st: StructType) => + // SPARK-3308: support reading top level JSON arrays and take every element + // in such an array as a row + convertArray(factory, parser, st) + case (START_ARRAY, ArrayType(st, _)) => - convertList(factory, parser, st) + convertArray(factory, parser, st) case (START_OBJECT, ArrayType(st, _)) => // the business end of SPARK-3308: @@ -165,16 +170,16 @@ private[sql] object JacksonParser { builder.result() } - private def convertList( + private def convertArray( factory: JsonFactory, parser: JsonParser, - schema: DataType): Seq[Any] = { - val builder = Seq.newBuilder[Any] + elementType: DataType): ArrayData = { + val values = scala.collection.mutable.ArrayBuffer.empty[Any] while (nextUntil(parser, JsonToken.END_ARRAY)) { - builder += convertField(factory, parser, schema) + values += convertField(factory, parser, elementType) } - builder.result() + new GenericArrayData(values.toArray) } private def parseJson( @@ -201,12 +206,15 @@ private[sql] object JacksonParser { val parser = factory.createParser(record) parser.nextToken() - // to support both object and 
arrays (see SPARK-3308) we'll start - // by converting the StructType schema to an ArrayType and let - // convertField wrap an object into a single value array when necessary. - convertField(factory, parser, ArrayType(schema)) match { + convertField(factory, parser, schema) match { case null => failedRecord(record) - case list: Seq[InternalRow @unchecked] => list + case row: InternalRow => row :: Nil + case array: ArrayData => + if (array.numElements() == 0) { + Nil + } else { + array.toArray().map(_.asInstanceOf[InternalRow]) + } case _ => sys.error( s"Failed to parse record $record. Please make sure that each line of the file " + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala index e00bd90edb3dd..172db8362afb6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala @@ -325,7 +325,7 @@ private[parquet] class CatalystRowConverter( override def getConverter(fieldIndex: Int): Converter = elementConverter - override def end(): Unit = updater.set(currentArray) + override def end(): Unit = updater.set(new GenericArrayData(currentArray.toArray)) // NOTE: We can't reuse the mutable `ArrayBuffer` here and must instantiate a new buffer for the // next value. `Row.copy()` only copies row cells, it doesn't do deep copy to objects stored diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala index ea51650fe9039..2332a36468dbc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.parquet import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.types.ArrayData // TODO Removes this while fixing SPARK-8848 private[sql] object CatalystConverter { @@ -32,7 +33,7 @@ private[sql] object CatalystConverter { val MAP_SCHEMA_NAME = "map" // TODO: consider using Array[T] for arrays to avoid boxing of primitive types - type ArrayScalaType[T] = Seq[T] + type ArrayScalaType[T] = ArrayData type StructScalaType[T] = InternalRow type MapScalaType[K, V] = Map[K, V] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala index 78ecfad1d57c6..79dd16b7b0c39 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala @@ -146,15 +146,15 @@ private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with Lo array: CatalystConverter.ArrayScalaType[_]): Unit = { val elementType = schema.elementType writer.startGroup() - if (array.size > 0) { + if (array.numElements() > 0) { if (schema.containsNull) { writer.startField(CatalystConverter.ARRAY_CONTAINS_NULL_BAG_SCHEMA_NAME, 0) var i = 0 - while (i < array.size) { + while (i < array.numElements()) { writer.startGroup() - if (array(i) != null) { + if (!array.isNullAt(i)) { writer.startField(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, 0) - writeValue(elementType, array(i)) + writeValue(elementType, array.get(i)) writer.endField(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, 0) } writer.endGroup() @@ 
-164,8 +164,8 @@ private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with Lo } else { writer.startField(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, 0) var i = 0 - while (i < array.size) { - writeValue(elementType, array(i)) + while (i < array.numElements()) { + writeValue(elementType, array.get(i)) i = i + 1 } writer.endField(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, 0) diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index 72c42f4fe376b..9e61d06f4036e 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -30,7 +30,6 @@ import scala.collection.JavaConversions; import scala.collection.Seq; -import scala.collection.mutable.Buffer; import java.io.Serializable; import java.util.Arrays; @@ -168,10 +167,10 @@ public void testCreateDataFrameFromJavaBeans() { for (int i = 0; i < result.length(); i++) { Assert.assertEquals(bean.getB()[i], result.apply(i)); } - Buffer outputBuffer = (Buffer) first.getJavaMap(2).get("hello"); + Seq outputBuffer = (Seq) first.getJavaMap(2).get("hello"); Assert.assertArrayEquals( bean.getC().get("hello"), - Ints.toArray(JavaConversions.bufferAsJavaList(outputBuffer))); + Ints.toArray(JavaConversions.seqAsJavaList(outputBuffer))); Seq d = first.getAs(3); Assert.assertEquals(bean.getD().size(), d.length()); for (int i = 0; i < d.length(); i++) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index 45c9f06941c10..77ed4a9c0d5ae 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -47,17 +47,17 @@ private[sql] class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] { override def sqlType: DataType = ArrayType(DoubleType, containsNull = false) - override def serialize(obj: Any): Seq[Double] = { + override def serialize(obj: Any): ArrayData = { obj match { case features: MyDenseVector => - features.data.toSeq + new GenericArrayData(features.data.map(_.asInstanceOf[Any])) } } override def deserialize(datum: Any): MyDenseVector = { datum match { - case data: Seq[_] => - new MyDenseVector(data.asInstanceOf[Seq[Double]].toArray) + case data: ArrayData => + new MyDenseVector(data.toArray.map(_.asInstanceOf[Double])) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala index 5e189c3563ca8..cfb03ff485b7c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala @@ -67,12 +67,12 @@ case class AllDataTypesScan( override def schema: StructType = userSpecifiedSchema - override def needConversion: Boolean = false + override def needConversion: Boolean = true override def buildScan(): RDD[Row] = { sqlContext.sparkContext.parallelize(from to to).map { i => - InternalRow( - UTF8String.fromString(s"str_$i"), + Row( + s"str_$i", s"str_$i".getBytes(), i % 2 == 0, i.toByte, @@ -81,19 +81,19 @@ case class AllDataTypesScan( i.toLong, i.toFloat, i.toDouble, - Decimal(new java.math.BigDecimal(i)), - Decimal(new java.math.BigDecimal(i)), - DateTimeUtils.fromJavaDate(new Date(1970, 1, 1)), - 
DateTimeUtils.fromJavaTimestamp(new Timestamp(20000 + i)), - UTF8String.fromString(s"varchar_$i"), + new java.math.BigDecimal(i), + new java.math.BigDecimal(i), + new Date(1970, 1, 1), + new Timestamp(20000 + i), + s"varchar_$i", Seq(i, i + 1), - Seq(Map(UTF8String.fromString(s"str_$i") -> InternalRow(i.toLong))), - Map(i -> UTF8String.fromString(i.toString)), - Map(Map(UTF8String.fromString(s"str_$i") -> i.toFloat) -> InternalRow(i.toLong)), - InternalRow(i, UTF8String.fromString(i.toString)), - InternalRow(Seq(UTF8String.fromString(s"str_$i"), UTF8String.fromString(s"str_${i + 1}")), - InternalRow(Seq(DateTimeUtils.fromJavaDate(new Date(1970, 1, i + 1)))))) - }.asInstanceOf[RDD[Row]] + Seq(Map(s"str_$i" -> Row(i.toLong))), + Map(i -> i.toString), + Map(Map(s"str_$i" -> i.toFloat) -> Row(i.toLong)), + Row(i, i.toString), + Row(Seq(s"str_$i", s"str_${i + 1}"), + Row(Seq(new Date(1970, 1, i + 1))))) + } } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index f467500259c91..5926ef9aa388b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -52,9 +52,8 @@ import scala.collection.JavaConversions._ * java.sql.Timestamp * Complex Types => * Map: scala.collection.immutable.Map - * List: scala.collection.immutable.Seq - * Struct: - * [[org.apache.spark.sql.catalyst.InternalRow]] + * List: [[org.apache.spark.sql.types.ArrayData]] + * Struct: [[org.apache.spark.sql.catalyst.InternalRow]] * Union: NOT SUPPORTED YET * The Complex types plays as a container, which can hold arbitrary data types. * @@ -297,7 +296,10 @@ private[hive] trait HiveInspectors { }.toMap case li: StandardConstantListObjectInspector => // take the value from the list inspector object, rather than the input data - li.getWritableConstantValue.map(unwrap(_, li.getListElementObjectInspector)).toSeq + val values = li.getWritableConstantValue + .map(unwrap(_, li.getListElementObjectInspector)) + .toArray + new GenericArrayData(values) // if the value is null, we don't care about the object inspector type case _ if data == null => null case poi: VoidObjectInspector => null // always be null for void object inspector @@ -339,7 +341,10 @@ private[hive] trait HiveInspectors { } case li: ListObjectInspector => Option(li.getList(data)) - .map(_.map(unwrap(_, li.getListElementObjectInspector)).toSeq) + .map { l => + val values = l.map(unwrap(_, li.getListElementObjectInspector)).toArray + new GenericArrayData(values) + } .orNull case mi: MapObjectInspector => Option(mi.getMap(data)).map( @@ -391,7 +396,13 @@ private[hive] trait HiveInspectors { case loi: ListObjectInspector => val wrapper = wrapperFor(loi.getListElementObjectInspector) - (o: Any) => if (o != null) seqAsJavaList(o.asInstanceOf[Seq[_]].map(wrapper)) else null + (o: Any) => { + if (o != null) { + seqAsJavaList(o.asInstanceOf[ArrayData].toArray().map(wrapper)) + } else { + null + } + } case moi: MapObjectInspector => // The Predef.Map is scala.collection.immutable.Map. 
@@ -520,7 +531,7 @@ private[hive] trait HiveInspectors { case x: ListObjectInspector => val list = new java.util.ArrayList[Object] val tpe = dataType.asInstanceOf[ArrayType].elementType - a.asInstanceOf[Seq[_]].foreach { + a.asInstanceOf[ArrayData].toArray().foreach { v => list.add(wrap(v, x.getListElementObjectInspector, tpe)) } list @@ -634,7 +645,8 @@ private[hive] trait HiveInspectors { ObjectInspectorFactory.getStandardConstantListObjectInspector(listObjectInspector, null) } else { val list = new java.util.ArrayList[Object]() - value.asInstanceOf[Seq[_]].foreach(v => list.add(wrap(v, listObjectInspector, dt))) + value.asInstanceOf[ArrayData].toArray() + .foreach(v => list.add(wrap(v, listObjectInspector, dt))) ObjectInspectorFactory.getStandardConstantListObjectInspector(listObjectInspector, list) } case Literal(value, MapType(keyType, valueType, _)) => diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala index 741c705e2a253..7e3342cc84c0e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala @@ -176,13 +176,13 @@ case class ScriptTransformation( val prevLine = curLine curLine = reader.readLine() if (!ioschema.schemaLess) { - new GenericInternalRow(CatalystTypeConverters.convertToCatalyst( - prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD"))) - .asInstanceOf[Array[Any]]) + new GenericInternalRow( + prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD")) + .map(CatalystTypeConverters.convertToCatalyst)) } else { - new GenericInternalRow(CatalystTypeConverters.convertToCatalyst( - prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD"), 2)) - .asInstanceOf[Array[Any]]) + new GenericInternalRow( + prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD"), 2) + .map(CatalystTypeConverters.convertToCatalyst)) } } else { val ret = deserialize() diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index 8732e9abf8d31..4a13022eddf60 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -431,7 +431,7 @@ private[hive] case class HiveWindowFunction( // if pivotResult is true, we will get a Seq having the same size with the size // of the window frame. At here, we will return the result at the position of // index in the output buffer. 
- outputBuffer.asInstanceOf[Seq[Any]].get(index) + outputBuffer.asInstanceOf[ArrayData].get(index) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala index 0330013f5325e..f719f2e06ab63 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala @@ -217,7 +217,7 @@ class HiveInspectorSuite extends SparkFunSuite with HiveInspectors { test("wrap / unwrap Array Type") { val dt = ArrayType(dataTypes(0)) - val d = row(0) :: row(0) :: Nil + val d = new GenericArrayData(Array(row(0), row(0))) checkValue(d, unwrap(wrap(d, toInspector(dt), dt), toInspector(dt))) checkValue(null, unwrap(wrap(null, toInspector(dt), dt), toInspector(dt))) checkValue(d, From 7bbf02f0bddefd19985372af79e906a38bc528b6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Garillot?= Date: Thu, 30 Jul 2015 18:14:08 +0100 Subject: [PATCH 26/50] [SPARK-9267] [CORE] Retire stringify(Partial)?Value from Accumulators MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit cc srowen Author: François Garillot Closes #7678 from huitseeker/master and squashes the following commits: 5e99f57 [François Garillot] [SPARK-9267][Core] Retire stringify(Partial)?Value from Accumulators --- core/src/main/scala/org/apache/spark/Accumulators.scala | 3 --- .../scala/org/apache/spark/scheduler/DAGScheduler.scala | 6 ++---- 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/Accumulators.scala b/core/src/main/scala/org/apache/spark/Accumulators.scala index 2f4fcac890eef..eb75f26718e19 100644 --- a/core/src/main/scala/org/apache/spark/Accumulators.scala +++ b/core/src/main/scala/org/apache/spark/Accumulators.scala @@ -341,7 +341,4 @@ private[spark] object Accumulators extends Logging { } } - def stringifyPartialValue(partialValue: Any): String = "%s".format(partialValue) - - def stringifyValue(value: Any): String = "%s".format(value) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index cdf6078421123..c4fa277c21254 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -916,11 +916,9 @@ class DAGScheduler( // To avoid UI cruft, ignore cases where value wasn't updated if (acc.name.isDefined && partialValue != acc.zero) { val name = acc.name.get - val stringPartialValue = Accumulators.stringifyPartialValue(partialValue) - val stringValue = Accumulators.stringifyValue(acc.value) - stage.latestInfo.accumulables(id) = AccumulableInfo(id, name, stringValue) + stage.latestInfo.accumulables(id) = AccumulableInfo(id, name, s"${acc.value}") event.taskInfo.accumulables += - AccumulableInfo(id, name, Some(stringPartialValue), stringValue) + AccumulableInfo(id, name, Some(s"$partialValue"), s"${acc.value}") } } } catch { From 5363ed71568c3e7c082146d654a9c669d692d894 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 30 Jul 2015 10:30:37 -0700 Subject: [PATCH 27/50] [SPARK-9361] [SQL] Refactor new aggregation code to reduce the times of checking compatibility JIRA: https://issues.apache.org/jira/browse/SPARK-9361 Currently, we call `aggregate.Utils.tryConvert` in many places to check it the logical.Aggregate can be run with new aggregation. 
But it looks like `aggregate.Utils.tryConvert` costs considerable time to run. We should only call `tryConvert` once, keep its value in `logical.Aggregate`, and reuse it. In `org.apache.spark.sql.execution.aggregate.Utils`, the code involving `tryConvert` should be moved to catalyst because it actually doesn't deal with execution details. Author: Liang-Chi Hsieh Closes #7677 from viirya/refactor_aggregate and squashes the following commits: babea30 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into refactor_aggregate 9a589d7 [Liang-Chi Hsieh] Fix scala style. 0a91329 [Liang-Chi Hsieh] Refactor new aggregation code to reduce the times to call tryConvert. --- .../expressions/aggregate/interfaces.scala | 4 +- .../expressions/aggregate/utils.scala | 167 ++++++++++++++++++ .../plans/logical/basicOperators.scala | 3 + .../spark/sql/execution/SparkStrategies.scala | 34 ++-- .../spark/sql/execution/aggregate/utils.scala | 144 --------------- 5 files changed, 188 insertions(+), 164 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala index 9fb7623172e78..d08f553cefe8c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala @@ -42,7 +42,7 @@ private[sql] case object Partial extends AggregateMode private[sql] case object PartialMerge extends AggregateMode /** - * An [[AggregateFunction2]] with [[PartialMerge]] mode is used to merge aggregation buffers + * An [[AggregateFunction2]] with [[Final]] mode is used to merge aggregation buffers * containing intermediate results for this function and then generate final result. * This function updates the given aggregation buffer by merging multiple aggregation buffers. * When it has processed all input rows, the final result of this function is returned. @@ -50,7 +50,7 @@ private[sql] case object PartialMerge extends AggregateMode private[sql] case object Final extends AggregateMode /** - * An [[AggregateFunction2]] with [[Partial]] mode is used to evaluate this function directly + * An [[AggregateFunction2]] with [[Complete]] mode is used to evaluate this function directly * from original input rows without any partial aggregation. * This function updates the given aggregation buffer with the original input of this * function. When it has processed all input rows, the final result of this function is returned. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala new file mode 100644 index 0000000000000..4a43318a95490 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala @@ -0,0 +1,167 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.aggregate + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan} +import org.apache.spark.sql.types.{StructType, MapType, ArrayType} + +/** + * Utility functions used by the query planner to convert our plan to new aggregation code path. + */ +object Utils { + // Right now, we do not support complex types in the grouping key schema. + private def supportsGroupingKeySchema(aggregate: Aggregate): Boolean = { + val hasComplexTypes = aggregate.groupingExpressions.map(_.dataType).exists { + case array: ArrayType => true + case map: MapType => true + case struct: StructType => true + case _ => false + } + + !hasComplexTypes + } + + private def doConvert(plan: LogicalPlan): Option[Aggregate] = plan match { + case p: Aggregate if supportsGroupingKeySchema(p) => + val converted = p.transformExpressionsDown { + case expressions.Average(child) => + aggregate.AggregateExpression2( + aggregateFunction = aggregate.Average(child), + mode = aggregate.Complete, + isDistinct = false) + + case expressions.Count(child) => + aggregate.AggregateExpression2( + aggregateFunction = aggregate.Count(child), + mode = aggregate.Complete, + isDistinct = false) + + // We do not support multiple COUNT DISTINCT columns for now. + case expressions.CountDistinct(children) if children.length == 1 => + aggregate.AggregateExpression2( + aggregateFunction = aggregate.Count(children.head), + mode = aggregate.Complete, + isDistinct = true) + + case expressions.First(child) => + aggregate.AggregateExpression2( + aggregateFunction = aggregate.First(child), + mode = aggregate.Complete, + isDistinct = false) + + case expressions.Last(child) => + aggregate.AggregateExpression2( + aggregateFunction = aggregate.Last(child), + mode = aggregate.Complete, + isDistinct = false) + + case expressions.Max(child) => + aggregate.AggregateExpression2( + aggregateFunction = aggregate.Max(child), + mode = aggregate.Complete, + isDistinct = false) + + case expressions.Min(child) => + aggregate.AggregateExpression2( + aggregateFunction = aggregate.Min(child), + mode = aggregate.Complete, + isDistinct = false) + + case expressions.Sum(child) => + aggregate.AggregateExpression2( + aggregateFunction = aggregate.Sum(child), + mode = aggregate.Complete, + isDistinct = false) + + case expressions.SumDistinct(child) => + aggregate.AggregateExpression2( + aggregateFunction = aggregate.Sum(child), + mode = aggregate.Complete, + isDistinct = true) + } + // Check if there is any expressions.AggregateExpression1 left. + // If so, we cannot convert this plan. + val hasAggregateExpression1 = converted.aggregateExpressions.exists { expr => + // For every expressions, check if it contains AggregateExpression1. + expr.find { + case agg: expressions.AggregateExpression1 => true + case other => false + }.isDefined + } + + // Check if there are multiple distinct columns. 
+ val aggregateExpressions = converted.aggregateExpressions.flatMap { expr => + expr.collect { + case agg: AggregateExpression2 => agg + } + }.toSet.toSeq + val functionsWithDistinct = aggregateExpressions.filter(_.isDistinct) + val hasMultipleDistinctColumnSets = + if (functionsWithDistinct.map(_.aggregateFunction.children).distinct.length > 1) { + true + } else { + false + } + + if (!hasAggregateExpression1 && !hasMultipleDistinctColumnSets) Some(converted) else None + + case other => None + } + + def checkInvalidAggregateFunction2(aggregate: Aggregate): Unit = { + // If the plan cannot be converted, we will do a final round check to see if the original + // logical.Aggregate contains both AggregateExpression1 and AggregateExpression2. If so, + // we need to throw an exception. + val aggregateFunction2s = aggregate.aggregateExpressions.flatMap { expr => + expr.collect { + case agg: AggregateExpression2 => agg.aggregateFunction + } + }.distinct + if (aggregateFunction2s.nonEmpty) { + // For functions implemented based on the new interface, prepare a list of function names. + val invalidFunctions = { + if (aggregateFunction2s.length > 1) { + s"${aggregateFunction2s.tail.map(_.nodeName).mkString(",")} " + + s"and ${aggregateFunction2s.head.nodeName} are" + } else { + s"${aggregateFunction2s.head.nodeName} is" + } + } + val errorMessage = + s"${invalidFunctions} implemented based on the new Aggregate Function " + + s"interface and it cannot be used with functions implemented based on " + + s"the old Aggregate Function interface." + throw new AnalysisException(errorMessage) + } + } + + def tryConvert(plan: LogicalPlan): Option[Aggregate] = plan match { + case p: Aggregate => + val converted = doConvert(p) + if (converted.isDefined) { + converted + } else { + checkInvalidAggregateFunction2(p) + None + } + case other => None + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index ad5af19578f33..a67f8de6b733a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.Utils import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.types._ import org.apache.spark.util.collection.OpenHashSet @@ -219,6 +220,8 @@ case class Aggregate( expressions.forall(_.resolved) && childrenResolved && !hasWindowExpressions } + lazy val newAggregation: Option[Aggregate] = Utils.tryConvert(this) + override def output: Seq[Attribute] = aggregateExpressions.map(_.toAttribute) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index f3ef066528ff8..52a9b02d373c7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression2 +import 
org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression2, Utils} import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{BroadcastHint, LogicalPlan} @@ -193,11 +193,15 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case _ => Nil } - def canBeConvertedToNewAggregation(plan: LogicalPlan): Boolean = { - aggregate.Utils.tryConvert( - plan, - sqlContext.conf.useSqlAggregate2, - sqlContext.conf.codegenEnabled).isDefined + def canBeConvertedToNewAggregation(plan: LogicalPlan): Boolean = plan match { + case a: logical.Aggregate => + if (sqlContext.conf.useSqlAggregate2 && sqlContext.conf.codegenEnabled) { + a.newAggregation.isDefined + } else { + Utils.checkInvalidAggregateFunction2(a) + false + } + case _ => false } def canBeCodeGened(aggs: Seq[AggregateExpression1]): Boolean = aggs.forall { @@ -217,12 +221,9 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { */ object Aggregation extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case p: logical.Aggregate => - val converted = - aggregate.Utils.tryConvert( - p, - sqlContext.conf.useSqlAggregate2, - sqlContext.conf.codegenEnabled) + case p: logical.Aggregate if sqlContext.conf.useSqlAggregate2 && + sqlContext.conf.codegenEnabled => + val converted = p.newAggregation converted match { case None => Nil // Cannot convert to new aggregation code path. case Some(logical.Aggregate(groupingExpressions, resultExpressions, child)) => @@ -377,17 +378,14 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case e @ logical.Expand(_, _, _, child) => execution.Expand(e.projections, e.output, planLater(child)) :: Nil case a @ logical.Aggregate(group, agg, child) => { - val useNewAggregation = - aggregate.Utils.tryConvert( - a, - sqlContext.conf.useSqlAggregate2, - sqlContext.conf.codegenEnabled).isDefined - if (useNewAggregation) { + val useNewAggregation = sqlContext.conf.useSqlAggregate2 && sqlContext.conf.codegenEnabled + if (useNewAggregation && a.newAggregation.isDefined) { // If this logical.Aggregate can be planned to use new aggregation code path // (i.e. it can be planned by the Strategy Aggregation), we will not use the old // aggregation code path. Nil } else { + Utils.checkInvalidAggregateFunction2(a) execution.Aggregate(partial = false, group, agg, planLater(child)) :: Nil } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala index 6549c87752a7d..03635baae4a5f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala @@ -29,150 +29,6 @@ import org.apache.spark.sql.types.{StructType, MapType, ArrayType} * Utility functions used by the query planner to convert our plan to new aggregation code path. */ object Utils { - // Right now, we do not support complex types in the grouping key schema. 
- private def supportsGroupingKeySchema(aggregate: Aggregate): Boolean = { - val hasComplexTypes = aggregate.groupingExpressions.map(_.dataType).exists { - case array: ArrayType => true - case map: MapType => true - case struct: StructType => true - case _ => false - } - - !hasComplexTypes - } - - private def tryConvert(plan: LogicalPlan): Option[Aggregate] = plan match { - case p: Aggregate if supportsGroupingKeySchema(p) => - val converted = p.transformExpressionsDown { - case expressions.Average(child) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.Average(child), - mode = aggregate.Complete, - isDistinct = false) - - case expressions.Count(child) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.Count(child), - mode = aggregate.Complete, - isDistinct = false) - - // We do not support multiple COUNT DISTINCT columns for now. - case expressions.CountDistinct(children) if children.length == 1 => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.Count(children.head), - mode = aggregate.Complete, - isDistinct = true) - - case expressions.First(child) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.First(child), - mode = aggregate.Complete, - isDistinct = false) - - case expressions.Last(child) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.Last(child), - mode = aggregate.Complete, - isDistinct = false) - - case expressions.Max(child) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.Max(child), - mode = aggregate.Complete, - isDistinct = false) - - case expressions.Min(child) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.Min(child), - mode = aggregate.Complete, - isDistinct = false) - - case expressions.Sum(child) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.Sum(child), - mode = aggregate.Complete, - isDistinct = false) - - case expressions.SumDistinct(child) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.Sum(child), - mode = aggregate.Complete, - isDistinct = true) - } - // Check if there is any expressions.AggregateExpression1 left. - // If so, we cannot convert this plan. - val hasAggregateExpression1 = converted.aggregateExpressions.exists { expr => - // For every expressions, check if it contains AggregateExpression1. - expr.find { - case agg: expressions.AggregateExpression1 => true - case other => false - }.isDefined - } - - // Check if there are multiple distinct columns. - val aggregateExpressions = converted.aggregateExpressions.flatMap { expr => - expr.collect { - case agg: AggregateExpression2 => agg - } - }.toSet.toSeq - val functionsWithDistinct = aggregateExpressions.filter(_.isDistinct) - val hasMultipleDistinctColumnSets = - if (functionsWithDistinct.map(_.aggregateFunction.children).distinct.length > 1) { - true - } else { - false - } - - if (!hasAggregateExpression1 && !hasMultipleDistinctColumnSets) Some(converted) else None - - case other => None - } - - private def checkInvalidAggregateFunction2(aggregate: Aggregate): Unit = { - // If the plan cannot be converted, we will do a final round check to if the original - // logical.Aggregate contains both AggregateExpression1 and AggregateExpression2. If so, - // we need to throw an exception. 
- val aggregateFunction2s = aggregate.aggregateExpressions.flatMap { expr => - expr.collect { - case agg: AggregateExpression2 => agg.aggregateFunction - } - }.distinct - if (aggregateFunction2s.nonEmpty) { - // For functions implemented based on the new interface, prepare a list of function names. - val invalidFunctions = { - if (aggregateFunction2s.length > 1) { - s"${aggregateFunction2s.tail.map(_.nodeName).mkString(",")} " + - s"and ${aggregateFunction2s.head.nodeName} are" - } else { - s"${aggregateFunction2s.head.nodeName} is" - } - } - val errorMessage = - s"${invalidFunctions} implemented based on the new Aggregate Function " + - s"interface and it cannot be used with functions implemented based on " + - s"the old Aggregate Function interface." - throw new AnalysisException(errorMessage) - } - } - - def tryConvert( - plan: LogicalPlan, - useNewAggregation: Boolean, - codeGenEnabled: Boolean): Option[Aggregate] = plan match { - case p: Aggregate if useNewAggregation && codeGenEnabled => - val converted = tryConvert(p) - if (converted.isDefined) { - converted - } else { - checkInvalidAggregateFunction2(p) - None - } - case p: Aggregate => - checkInvalidAggregateFunction2(p) - None - case other => None - } - def planAggregateWithoutDistinct( groupingExpressions: Seq[Expression], aggregateExpressions: Seq[AggregateExpression2], From e53534655d6198e5b8a507010d26c7b4c4e7f1fd Mon Sep 17 00:00:00 2001 From: Mridul Muralidharan Date: Thu, 30 Jul 2015 10:37:53 -0700 Subject: [PATCH 28/50] [SPARK-8297] [YARN] Scheduler backend is not notified in case node fails in YARN This change adds code to notify the scheduler backend when a container dies in YARN. Author: Mridul Muralidharan Author: Marcelo Vanzin Closes #7431 from vanzin/SPARK-8297 and squashes the following commits: 471e4a0 [Marcelo Vanzin] Fix unit test after merge. d4adf4e [Marcelo Vanzin] Merge branch 'master' into SPARK-8297 3b262e8 [Marcelo Vanzin] Merge branch 'master' into SPARK-8297 537da6f [Marcelo Vanzin] Make an expected log less scary. 04dc112 [Marcelo Vanzin] Use driver <-> AM communication to send "remove executor" request. 
8855b97 [Marcelo Vanzin] Merge remote-tracking branch 'mridul/fix_yarn_scheduler_bug' into SPARK-8297 687790f [Mridul Muralidharan] Merge branch 'fix_yarn_scheduler_bug' of github.com:mridulm/spark into fix_yarn_scheduler_bug e1b0067 [Mridul Muralidharan] Fix failing testcase, fix merge issue from our 1.3 -> master 9218fcc [Mridul Muralidharan] Fix failing testcase 362d64a [Mridul Muralidharan] Merge branch 'fix_yarn_scheduler_bug' of github.com:mridulm/spark into fix_yarn_scheduler_bug 62ad0cc [Mridul Muralidharan] Merge branch 'fix_yarn_scheduler_bug' of github.com:mridulm/spark into fix_yarn_scheduler_bug bbf8811 [Mridul Muralidharan] Merge branch 'fix_yarn_scheduler_bug' of github.com:mridulm/spark into fix_yarn_scheduler_bug 9ee1307 [Mridul Muralidharan] Fix SPARK-8297 a3a0f01 [Mridul Muralidharan] Fix SPARK-8297 --- .../CoarseGrainedSchedulerBackend.scala | 2 +- .../cluster/YarnSchedulerBackend.scala | 2 ++ .../spark/deploy/yarn/ApplicationMaster.scala | 22 +++++++++---- .../spark/deploy/yarn/YarnAllocator.scala | 32 +++++++++++++++---- .../spark/deploy/yarn/YarnRMClient.scala | 5 ++- .../deploy/yarn/YarnAllocatorSuite.scala | 29 +++++++++++++++++ 6 files changed, 77 insertions(+), 15 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 660702f6e6fd0..bd89160af4ffa 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -241,7 +241,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp scheduler.executorLost(executorId, SlaveLost(reason)) listenerBus.post( SparkListenerExecutorRemoved(System.currentTimeMillis(), executorId, reason)) - case None => logError(s"Asked to remove non-existent executor $executorId") + case None => logInfo(s"Asked to remove non-existent executor $executorId") } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala index 074282d1be37d..044f6288fabdd 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala @@ -109,6 +109,8 @@ private[spark] abstract class YarnSchedulerBackend( case AddWebUIFilter(filterName, filterParams, proxyBase) => addWebUIFilter(filterName, filterParams, proxyBase) + case RemoveExecutor(executorId, reason) => + removeExecutor(executorId, reason) } override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 44acc7374d024..1d67b3ebb51b7 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -229,7 +229,11 @@ private[spark] class ApplicationMaster( sparkContextRef.compareAndSet(sc, null) } - private def registerAM(_rpcEnv: RpcEnv, uiAddress: String, securityMgr: SecurityManager) = { + private def registerAM( + _rpcEnv: RpcEnv, + driverRef: RpcEndpointRef, + uiAddress: String, + securityMgr: SecurityManager) = { val sc = sparkContextRef.get() val 
appId = client.getAttemptId().getApplicationId().toString() @@ -246,6 +250,7 @@ private[spark] class ApplicationMaster( RpcAddress(_sparkConf.get("spark.driver.host"), _sparkConf.get("spark.driver.port").toInt), CoarseGrainedSchedulerBackend.ENDPOINT_NAME) allocator = client.register(driverUrl, + driverRef, yarnConf, _sparkConf, if (sc != null) sc.preferredNodeLocationData else Map(), @@ -262,17 +267,20 @@ private[spark] class ApplicationMaster( * * In cluster mode, the AM and the driver belong to same process * so the AMEndpoint need not monitor lifecycle of the driver. + * + * @return A reference to the driver's RPC endpoint. */ private def runAMEndpoint( host: String, port: String, - isClusterMode: Boolean): Unit = { + isClusterMode: Boolean): RpcEndpointRef = { val driverEndpoint = rpcEnv.setupEndpointRef( SparkEnv.driverActorSystemName, RpcAddress(host, port.toInt), YarnSchedulerBackend.ENDPOINT_NAME) amEndpoint = rpcEnv.setupEndpoint("YarnAM", new AMEndpoint(rpcEnv, driverEndpoint, isClusterMode)) + driverEndpoint } private def runDriver(securityMgr: SecurityManager): Unit = { @@ -290,11 +298,11 @@ private[spark] class ApplicationMaster( "Timed out waiting for SparkContext.") } else { rpcEnv = sc.env.rpcEnv - runAMEndpoint( + val driverRef = runAMEndpoint( sc.getConf.get("spark.driver.host"), sc.getConf.get("spark.driver.port"), isClusterMode = true) - registerAM(rpcEnv, sc.ui.map(_.appUIAddress).getOrElse(""), securityMgr) + registerAM(rpcEnv, driverRef, sc.ui.map(_.appUIAddress).getOrElse(""), securityMgr) userClassThread.join() } } @@ -302,9 +310,9 @@ private[spark] class ApplicationMaster( private def runExecutorLauncher(securityMgr: SecurityManager): Unit = { val port = sparkConf.getInt("spark.yarn.am.port", 0) rpcEnv = RpcEnv.create("sparkYarnAM", Utils.localHostName, port, sparkConf, securityMgr) - waitForSparkDriver() + val driverRef = waitForSparkDriver() addAmIpFilter() - registerAM(rpcEnv, sparkConf.get("spark.driver.appUIAddress", ""), securityMgr) + registerAM(rpcEnv, driverRef, sparkConf.get("spark.driver.appUIAddress", ""), securityMgr) // In client mode the actor will stop the reporter thread. 
reporterThread.join() @@ -428,7 +436,7 @@ private[spark] class ApplicationMaster( } } - private def waitForSparkDriver(): Unit = { + private def waitForSparkDriver(): RpcEndpointRef = { logInfo("Waiting for Spark driver to be reachable.") var driverUp = false val hostport = args.userArgs(0) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index 6c103394af098..59caa787b6e20 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -36,6 +36,9 @@ import org.apache.log4j.{Level, Logger} import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._ +import org.apache.spark.rpc.RpcEndpointRef +import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend +import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ /** * YarnAllocator is charged with requesting containers from the YARN ResourceManager and deciding @@ -52,6 +55,7 @@ import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._ */ private[yarn] class YarnAllocator( driverUrl: String, + driverRef: RpcEndpointRef, conf: Configuration, sparkConf: SparkConf, amClient: AMRMClient[ContainerRequest], @@ -88,6 +92,9 @@ private[yarn] class YarnAllocator( // Visible for testing. private[yarn] val executorIdToContainer = new HashMap[String, Container] + private var numUnexpectedContainerRelease = 0L + private val containerIdToExecutorId = new HashMap[ContainerId, String] + // Executor memory in MB. protected val executorMemory = args.executorMemory // Additional memory overhead. @@ -184,6 +191,7 @@ private[yarn] class YarnAllocator( def killExecutor(executorId: String): Unit = synchronized { if (executorIdToContainer.contains(executorId)) { val container = executorIdToContainer.remove(executorId).get + containerIdToExecutorId.remove(container.getId) internalReleaseContainer(container) numExecutorsRunning -= 1 } else { @@ -383,6 +391,7 @@ private[yarn] class YarnAllocator( logInfo("Launching container %s for on host %s".format(containerId, executorHostname)) executorIdToContainer(executorId) = container + containerIdToExecutorId(container.getId) = executorId val containerSet = allocatedHostToContainersMap.getOrElseUpdate(executorHostname, new HashSet[ContainerId]) @@ -413,12 +422,8 @@ private[yarn] class YarnAllocator( private[yarn] def processCompletedContainers(completedContainers: Seq[ContainerStatus]): Unit = { for (completedContainer <- completedContainers) { val containerId = completedContainer.getContainerId - - if (releasedContainers.contains(containerId)) { - // Already marked the container for release, so remove it from - // `releasedContainers`. - releasedContainers.remove(containerId) - } else { + val alreadyReleased = releasedContainers.remove(containerId) + if (!alreadyReleased) { // Decrement the number of executors running. The next iteration of // the ApplicationMaster's reporting thread will take care of allocating. 
numExecutorsRunning -= 1 @@ -460,6 +465,18 @@ private[yarn] class YarnAllocator( allocatedContainerToHostMap.remove(containerId) } + + containerIdToExecutorId.remove(containerId).foreach { eid => + executorIdToContainer.remove(eid) + + if (!alreadyReleased) { + // The executor could have gone away (like no route to host, node failure, etc) + // Notify backend about the failure of the executor + numUnexpectedContainerRelease += 1 + driverRef.send(RemoveExecutor(eid, + s"Yarn deallocated the executor $eid (container $containerId)")) + } + } } } @@ -467,6 +484,9 @@ private[yarn] class YarnAllocator( releasedContainers.add(container.getId()) amClient.releaseAssignedContainer(container.getId()) } + + private[yarn] def getNumUnexpectedContainerRelease = numUnexpectedContainerRelease + } private object YarnAllocator { diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala index 7f533ee55e8bb..4999f9c06210a 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala @@ -33,6 +33,7 @@ import org.apache.hadoop.yarn.util.ConverterUtils import org.apache.hadoop.yarn.webapp.util.WebAppUtils import org.apache.spark.{Logging, SecurityManager, SparkConf} +import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.scheduler.SplitInfo import org.apache.spark.util.Utils @@ -56,6 +57,7 @@ private[spark] class YarnRMClient(args: ApplicationMasterArguments) extends Logg */ def register( driverUrl: String, + driverRef: RpcEndpointRef, conf: YarnConfiguration, sparkConf: SparkConf, preferredNodeLocations: Map[String, Set[SplitInfo]], @@ -73,7 +75,8 @@ private[spark] class YarnRMClient(args: ApplicationMasterArguments) extends Logg amClient.registerApplicationMaster(Utils.localHostName(), 0, uiAddress) registered = true } - new YarnAllocator(driverUrl, conf, sparkConf, amClient, getAttemptId(), args, securityMgr) + new YarnAllocator(driverUrl, driverRef, conf, sparkConf, amClient, getAttemptId(), args, + securityMgr) } /** diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala index 37a789fcd375b..58318bf9bcc08 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala @@ -27,10 +27,14 @@ import org.apache.hadoop.yarn.client.api.AMRMClient import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest import org.scalatest.{BeforeAndAfterEach, Matchers} +import org.scalatest.{BeforeAndAfterEach, Matchers} +import org.mockito.Mockito._ + import org.apache.spark.{SecurityManager, SparkFunSuite} import org.apache.spark.SparkConf import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._ import org.apache.spark.deploy.yarn.YarnAllocator._ +import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.scheduler.SplitInfo class MockResolver extends DNSToSwitchMapping { @@ -90,6 +94,7 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter "--class", "SomeClass") new YarnAllocator( "not used", + mock(classOf[RpcEndpointRef]), conf, sparkConf, rmClient, @@ -230,6 +235,30 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter handler.getNumPendingAllocate should be (1) } + test("lost executor removed from backend") { + val handler = createAllocator(4) 
+ handler.updateResourceRequests() + handler.getNumExecutorsRunning should be (0) + handler.getNumPendingAllocate should be (4) + + val container1 = createContainer("host1") + val container2 = createContainer("host2") + handler.handleAllocatedContainers(Array(container1, container2)) + + handler.requestTotalExecutorsWithPreferredLocalities(2, 0, Map()) + + val statuses = Seq(container1, container2).map { c => + ContainerStatus.newInstance(c.getId(), ContainerState.COMPLETE, "Failed", -1) + } + handler.updateResourceRequests() + handler.processCompletedContainers(statuses.toSeq) + handler.updateResourceRequests() + handler.getNumExecutorsRunning should be (0) + handler.getNumPendingAllocate should be (2) + handler.getNumExecutorsFailed should be (2) + handler.getNumUnexpectedContainerRelease should be (2) + } + test("memory exceeded diagnostic regexes") { val diagnostics = "Container [pid=12465,containerID=container_1412887393566_0003_01_000002] is running " + From ab78b1d2a6ce26833ea3878a63921efd805a3737 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Thu, 30 Jul 2015 10:40:04 -0700 Subject: [PATCH 29/50] [SPARK-9388] [YARN] Make executor info log messages easier to read. Author: Marcelo Vanzin Closes #7706 from vanzin/SPARK-9388 and squashes the following commits: 028b990 [Marcelo Vanzin] Single log statement. 3c5fb6a [Marcelo Vanzin] YARN not Yarn. 5bcd7a0 [Marcelo Vanzin] [SPARK-9388] [yarn] Make executor info log messages easier to read. --- .../scala/org/apache/spark/deploy/yarn/Client.scala | 2 +- .../apache/spark/deploy/yarn/ExecutorRunnable.scala | 13 ++++++++++--- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index bc28ce5eeae72..4ac3397f1ad28 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -767,7 +767,7 @@ private[spark] class Client( amContainer.setCommands(printableCommands) logDebug("===============================================================================") - logDebug("Yarn AM launch context:") + logDebug("YARN AM launch context:") logDebug(s" user class: ${Option(args.userClass).getOrElse("N/A")}") logDebug(" env:") launchEnv.foreach { case (k, v) => logDebug(s" $k -> $v") } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala index 78e27fb7f3337..52580deb372c2 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala @@ -86,10 +86,17 @@ class ExecutorRunnable( val commands = prepareCommand(masterAddress, slaveId, hostname, executorMemory, executorCores, appId, localResources) - logInfo(s"Setting up executor with environment: $env") - logInfo("Setting up executor with commands: " + commands) - ctx.setCommands(commands) + logInfo(s""" + |=============================================================================== + |YARN executor launch context: + | env: + |${env.map { case (k, v) => s" $k -> $v\n" }.mkString} + | command: + | ${commands.mkString(" ")} + |=============================================================================== + """.stripMargin) + ctx.setCommands(commands) ctx.setApplicationACLs(YarnSparkHadoopUtil.getApplicationAclsForYarn(securityMgr)) // If external shuffle service is enabled, register with 
the Yarn shuffle service already From 520ec0ff9db75267f627dc4615b2316a1a3d44d7 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 30 Jul 2015 10:45:32 -0700 Subject: [PATCH 30/50] [SPARK-8850] [SQL] Enable Unsafe mode by default This pull request enables Unsafe mode by default in Spark SQL. In order to do this, we had to fix a number of small issues: **List of fixed blockers**: - [x] Make some default buffer sizes configurable so that HiveCompatibilitySuite can run properly (#7741). - [x] Memory leak on grouped aggregation of empty input (fixed by #7560 to fix this) - [x] Update planner to also check whether codegen is enabled before planning unsafe operators. - [x] Investigate failing HiveThriftBinaryServerSuite test. This turns out to be caused by a ClassCastException that occurs when Exchange tries to apply an interpreted RowOrdering to an UnsafeRow when range partitioning an RDD. This could be fixed by #7408, but a shorter-term fix is to just skip the Unsafe exchange path when RangePartitioner is used. - [x] Memory leak exceptions masking exceptions that actually caused tasks to fail (will be fixed by #7603). - [x] ~~https://issues.apache.org/jira/browse/SPARK-9162, to implement code generation for ScalaUDF. This is necessary for `UDFSuite` to pass. For now, I've just ignored this test in order to try to find other problems while we wait for a fix.~~ This is no longer necessary as of #7682. - [x] Memory leaks from Limit after UnsafeExternalSort cause the memory leak detector to fail tests. This is a huge problem in the HiveCompatibilitySuite (fixed by f4ac642a4e5b2a7931c5e04e086bb10e263b1db6). - [x] Tests in `AggregationQuerySuite` are failing due to NaN-handling issues in UnsafeRow, which were fixed in #7736. - [x] `org.apache.spark.sql.ColumnExpressionSuite.rand` needs to be updated so that the planner check also matches `TungstenProject`. - [x] After having lowered the buffer sizes to 4MB so that most of HiveCompatibilitySuite runs: - [x] Wrong answer in `join_1to1` (fixed by #7680) - [x] Wrong answer in `join_nulls` (fixed by #7680) - [x] Managed memory OOM / leak in `lateral_view` - [x] Seems to hang indefinitely in `partcols1`. This might be a deadlock in script transformation or a bug in error-handling code? The hang was fixed by #7710. - [x] Error while freeing memory in `partcols1`: will be fixed by #7734. - [x] After fixing the `partcols1` hang, it appears that a number of later tests have issues as well. - [x] Fix thread-safety bug in codegen fallback expression evaluation (#7759). 
Author: Josh Rosen Closes #7564 from JoshRosen/unsafe-by-default and squashes the following commits: 83c0c56 [Josh Rosen] Merge remote-tracking branch 'origin/master' into unsafe-by-default f4cc859 [Josh Rosen] Merge remote-tracking branch 'origin/master' into unsafe-by-default 963f567 [Josh Rosen] Reduce buffer size for R tests d6986de [Josh Rosen] Lower page size in PySpark tests 013b9da [Josh Rosen] Also match TungstenProject in checkNumProjects 5d0b2d3 [Josh Rosen] Add task completion callback to avoid leak in limit after sort ea250da [Josh Rosen] Disable unsafe Exchange path when RangePartitioning is used 715517b [Josh Rosen] Enable Unsafe by default --- R/run-tests.sh | 2 +- .../unsafe/sort/UnsafeExternalSorter.java | 14 +++++++++++++ python/pyspark/java_gateway.py | 6 +++++- .../scala/org/apache/spark/sql/SQLConf.scala | 2 +- .../apache/spark/sql/execution/Exchange.scala | 7 ++++++- .../spark/sql/ColumnExpressionSuite.scala | 3 ++- .../execution/UnsafeExternalSortSuite.scala | 20 +------------------ 7 files changed, 30 insertions(+), 24 deletions(-) diff --git a/R/run-tests.sh b/R/run-tests.sh index e82ad0ba2cd06..18a1e13bdc655 100755 --- a/R/run-tests.sh +++ b/R/run-tests.sh @@ -23,7 +23,7 @@ FAILED=0 LOGFILE=$FWDIR/unit-tests.out rm -f $LOGFILE -SPARK_TESTING=1 $FWDIR/../bin/sparkR --driver-java-options "-Dlog4j.configuration=file:$FWDIR/log4j.properties" $FWDIR/pkg/tests/run-all.R 2>&1 | tee -a $LOGFILE +SPARK_TESTING=1 $FWDIR/../bin/sparkR --conf spark.buffer.pageSize=4m --driver-java-options "-Dlog4j.configuration=file:$FWDIR/log4j.properties" $FWDIR/pkg/tests/run-all.R 2>&1 | tee -a $LOGFILE FAILED=$((PIPESTATUS[0]||$FAILED)) if [[ $FAILED != 0 ]]; then diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index c21990f4e4778..866e0b4151577 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -20,6 +20,9 @@ import java.io.IOException; import java.util.LinkedList; +import scala.runtime.AbstractFunction0; +import scala.runtime.BoxedUnit; + import com.google.common.annotations.VisibleForTesting; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -90,6 +93,17 @@ public UnsafeExternalSorter( this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024; this.pageSizeBytes = conf.getSizeAsBytes("spark.buffer.pageSize", "64m"); initializeForWriting(); + + // Register a cleanup task with TaskContext to ensure that memory is guaranteed to be freed at + // the end of the task. This is necessary to avoid memory leaks in when the downstream operator + // does not fully consume the sorter's output (e.g. sort followed by limit). 
+ taskContext.addOnCompleteCallback(new AbstractFunction0() { + @Override + public BoxedUnit apply() { + freeMemory(); + return null; + } + }); } // TODO: metrics tracking + integration with shuffle write metrics diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py index 90cd342a6cf7f..60be85e53e2aa 100644 --- a/python/pyspark/java_gateway.py +++ b/python/pyspark/java_gateway.py @@ -52,7 +52,11 @@ def launch_gateway(): script = "./bin/spark-submit.cmd" if on_windows else "./bin/spark-submit" submit_args = os.environ.get("PYSPARK_SUBMIT_ARGS", "pyspark-shell") if os.environ.get("SPARK_TESTING"): - submit_args = "--conf spark.ui.enabled=false " + submit_args + submit_args = ' '.join([ + "--conf spark.ui.enabled=false", + "--conf spark.buffer.pageSize=4mb", + submit_args + ]) command = [os.path.join(SPARK_HOME, script)] + shlex.split(submit_args) # Start a socket that will be used by PythonGatewayServer to communicate its port to us diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 2564bbd2077bf..6644e85d4a037 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -229,7 +229,7 @@ private[spark] object SQLConf { " a specific query.") val UNSAFE_ENABLED = booleanConf("spark.sql.unsafe.enabled", - defaultValue = Some(false), + defaultValue = Some(true), doc = "When true, use the new optimized Tungsten physical execution backend.") val DIALECT = stringConf( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 41a0c519ba527..70e5031fb63c0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -47,7 +47,12 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una override def canProcessSafeRows: Boolean = true - override def canProcessUnsafeRows: Boolean = true + override def canProcessUnsafeRows: Boolean = { + // Do not use the Unsafe path if we are using a RangePartitioning, since this may lead to + // an interpreted RowOrdering being applied to an UnsafeRow, which will lead to + // ClassCastExceptions at runtime. This check can be removed after SPARK-9054 is fixed. + !newPartitioning.isInstanceOf[RangePartitioning] + } /** * Determines whether records must be defensively copied before being sent to the shuffle. 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 5c1102410879a..eb64684ae0fd9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql import org.scalatest.Matchers._ -import org.apache.spark.sql.execution.Project +import org.apache.spark.sql.execution.{Project, TungstenProject} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ import org.apache.spark.sql.test.SQLTestUtils @@ -538,6 +538,7 @@ class ColumnExpressionSuite extends QueryTest with SQLTestUtils { def checkNumProjects(df: DataFrame, expectedNumProjects: Int): Unit = { val projects = df.queryExecution.executedPlan.collect { case project: Project => project + case tungstenProject: TungstenProject => tungstenProject } assert(projects.size === expectedNumProjects) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala index 7a4baa9e4a49d..138636b0c65b8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala @@ -36,10 +36,7 @@ class UnsafeExternalSortSuite extends SparkPlanTest with BeforeAndAfterAll { TestSQLContext.conf.setConf(SQLConf.CODEGEN_ENABLED, SQLConf.CODEGEN_ENABLED.defaultValue.get) } - ignore("sort followed by limit should not leak memory") { - // TODO: this test is going to fail until we implement a proper iterator interface - // with a close() method. 
- TestSQLContext.sparkContext.conf.set("spark.unsafe.exceptionOnMemoryLeak", "false") + test("sort followed by limit") { checkThatPlansAgree( (1 to 100).map(v => Tuple1(v)).toDF("a"), (child: SparkPlan) => Limit(10, UnsafeExternalSort('a.asc :: Nil, true, child)), @@ -48,21 +45,6 @@ class UnsafeExternalSortSuite extends SparkPlanTest with BeforeAndAfterAll { ) } - test("sort followed by limit") { - TestSQLContext.sparkContext.conf.set("spark.unsafe.exceptionOnMemoryLeak", "false") - try { - checkThatPlansAgree( - (1 to 100).map(v => Tuple1(v)).toDF("a"), - (child: SparkPlan) => Limit(10, UnsafeExternalSort('a.asc :: Nil, true, child)), - (child: SparkPlan) => Limit(10, Sort('a.asc :: Nil, global = true, child)), - sortAnswers = false - ) - } finally { - TestSQLContext.sparkContext.conf.set("spark.unsafe.exceptionOnMemoryLeak", "false") - - } - } - test("sorting does not crash for large inputs") { val sortOrder = 'a.asc :: Nil val stringLength = 1024 * 1024 * 2 From 06b6a074fb224b3fe23922bdc89fc5f7c2ffaaf6 Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Thu, 30 Jul 2015 10:46:26 -0700 Subject: [PATCH 31/50] [SPARK-9437] [CORE] avoid overflow in SizeEstimator https://issues.apache.org/jira/browse/SPARK-9437 Author: Imran Rashid Closes #7750 from squito/SPARK-9437_size_estimator_overflow and squashes the following commits: 29493f1 [Imran Rashid] prevent another potential overflow bc1cb82 [Imran Rashid] avoid overflow --- .../main/scala/org/apache/spark/util/SizeEstimator.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala index 7d84468f62ab1..14b1f2a17e707 100644 --- a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala +++ b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala @@ -217,10 +217,10 @@ object SizeEstimator extends Logging { var arrSize: Long = alignSize(objectSize + INT_SIZE) if (elementClass.isPrimitive) { - arrSize += alignSize(length * primitiveSize(elementClass)) + arrSize += alignSize(length.toLong * primitiveSize(elementClass)) state.size += arrSize } else { - arrSize += alignSize(length * pointerSize) + arrSize += alignSize(length.toLong * pointerSize) state.size += arrSize if (length <= ARRAY_SIZE_FOR_SAMPLING) { @@ -336,7 +336,7 @@ object SizeEstimator extends Logging { // hg.openjdk.java.net/jdk8/jdk8/hotspot/file/tip/src/share/vm/classfile/classFileParser.cpp var alignedSize = shellSize for (size <- fieldSizes if sizeCount(size) > 0) { - val count = sizeCount(size) + val count = sizeCount(size).toLong // If there are internal gaps, smaller field can fit in. alignedSize = math.max(alignedSize, alignSizeUp(shellSize, size) + size * count) shellSize += size * count From 6d94bf6ac10ac851636c62439f8f2737f3526a2a Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Thu, 30 Jul 2015 11:13:15 -0700 Subject: [PATCH 32/50] [SPARK-8174] [SPARK-8175] [SQL] function unix_timestamp, from_unixtime unix_timestamp(): long Gets current Unix timestamp in seconds. 
unix_timestamp(string|date): long Converts time string in format yyyy-MM-dd HH:mm:ss to Unix timestamp (in seconds), using the default timezone and the default locale, return null if fail: unix_timestamp('2009-03-20 11:30:01') = 1237573801 unix_timestamp(string date, string pattern): long Convert time string with given pattern (see [http://docs.oracle.com/javase/tutorial/i18n/format/simpleDateFormat.html]) to Unix time stamp (in seconds), return null if fail: unix_timestamp('2009-03-20', 'yyyy-MM-dd') = 1237532400. from_unixtime(bigint unixtime[, string format]): string Converts the number of seconds from unix epoch (1970-01-01 00:00:00 UTC) to a string representing the timestamp of that moment in the current system time zone in the format of "1970-01-01 00:00:00". Jira: https://issues.apache.org/jira/browse/SPARK-8174 https://issues.apache.org/jira/browse/SPARK-8175 Author: Daoyuan Wang Closes #7644 from adrian-wang/udfunixtime and squashes the following commits: 2fe20c4 [Daoyuan Wang] util.Date ea2ec16 [Daoyuan Wang] use util.Date for better performance a2cf929 [Daoyuan Wang] doc return null instead of 0 f6f070a [Daoyuan Wang] address comments from davies 6a4cbb3 [Daoyuan Wang] temp 56ded53 [Daoyuan Wang] rebase and address comments 14a8b37 [Daoyuan Wang] function unix_timestamp, from_unixtime --- .../catalyst/analysis/FunctionRegistry.scala | 2 + .../expressions/datetimeFunctions.scala | 219 +++++++++++++++++- .../expressions/DateExpressionsSuite.scala | 59 ++++- .../org/apache/spark/sql/functions.scala | 42 ++++ .../apache/spark/sql/DateFunctionsSuite.scala | 56 +++++ 5 files changed, 374 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 378df4f57d9e2..d663f12bc6d0d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -211,6 +211,7 @@ object FunctionRegistry { expression[DayOfMonth]("day"), expression[DayOfYear]("dayofyear"), expression[DayOfMonth]("dayofmonth"), + expression[FromUnixTime]("from_unixtime"), expression[Hour]("hour"), expression[LastDay]("last_day"), expression[Minute]("minute"), @@ -218,6 +219,7 @@ object FunctionRegistry { expression[NextDay]("next_day"), expression[Quarter]("quarter"), expression[Second]("second"), + expression[UnixTimestamp]("unix_timestamp"), expression[WeekOfYear]("weekofyear"), expression[Year]("year"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala index efecb771f2f5d..a5e6249e438d2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.expressions -import java.sql.Date import java.text.SimpleDateFormat import java.util.{Calendar, TimeZone} @@ -28,6 +27,8 @@ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String +import scala.util.Try + /** * Returns the current date at the start of query evaluation. * All calls of current_date within the same query return the same value. 
@@ -236,20 +237,232 @@ case class DateFormatClass(left: Expression, right: Expression) extends BinaryEx override protected def nullSafeEval(timestamp: Any, format: Any): Any = { val sdf = new SimpleDateFormat(format.toString) - UTF8String.fromString(sdf.format(new Date(timestamp.asInstanceOf[Long] / 1000))) + UTF8String.fromString(sdf.format(new java.util.Date(timestamp.asInstanceOf[Long] / 1000))) } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val sdf = classOf[SimpleDateFormat].getName defineCodeGen(ctx, ev, (timestamp, format) => { s"""UTF8String.fromString((new $sdf($format.toString())) - .format(new java.sql.Date($timestamp / 1000)))""" + .format(new java.util.Date($timestamp / 1000)))""" }) } override def prettyName: String = "date_format" } +/** + * Converts time string with given pattern + * (see [http://docs.oracle.com/javase/tutorial/i18n/format/simpleDateFormat.html]) + * to Unix time stamp (in seconds), returns null if fail. + * Note that hive Language Manual says it returns 0 if fail, but in fact it returns null. + * If the second parameter is missing, use "yyyy-MM-dd HH:mm:ss". + * If no parameters provided, the first parameter will be current_timestamp. + * If the first parameter is a Date or Timestamp instead of String, we will ignore the + * second parameter. + */ +case class UnixTimestamp(timeExp: Expression, format: Expression) + extends BinaryExpression with ExpectsInputTypes { + + override def left: Expression = timeExp + override def right: Expression = format + + def this(time: Expression) = { + this(time, Literal("yyyy-MM-dd HH:mm:ss")) + } + + def this() = { + this(CurrentTimestamp()) + } + + override def inputTypes: Seq[AbstractDataType] = + Seq(TypeCollection(StringType, DateType, TimestampType), StringType) + + override def dataType: DataType = LongType + + private lazy val constFormat: UTF8String = right.eval().asInstanceOf[UTF8String] + + override def eval(input: InternalRow): Any = { + val t = left.eval(input) + if (t == null) { + null + } else { + left.dataType match { + case DateType => + DateTimeUtils.daysToMillis(t.asInstanceOf[Int]) / 1000L + case TimestampType => + t.asInstanceOf[Long] / 1000000L + case StringType if right.foldable => + if (constFormat != null) { + Try(new SimpleDateFormat(constFormat.toString).parse( + t.asInstanceOf[UTF8String].toString).getTime / 1000L).getOrElse(null) + } else { + null + } + case StringType => + val f = format.eval(input) + if (f == null) { + null + } else { + val formatString = f.asInstanceOf[UTF8String].toString + Try(new SimpleDateFormat(formatString).parse( + t.asInstanceOf[UTF8String].toString).getTime / 1000L).getOrElse(null) + } + } + } + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + left.dataType match { + case StringType if right.foldable => + val sdf = classOf[SimpleDateFormat].getName + val fString = if (constFormat == null) null else constFormat.toString + val formatter = ctx.freshName("formatter") + if (fString == null) { + s""" + boolean ${ev.isNull} = true; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + """ + } else { + val eval1 = left.gen(ctx) + s""" + ${eval1.code} + boolean ${ev.isNull} = ${eval1.isNull}; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + try { + $sdf $formatter = new $sdf("$fString"); + ${ev.primitive} = + $formatter.parse(${eval1.primitive}.toString()).getTime() / 1000L; + } catch (java.lang.Throwable e) { + 
${ev.isNull} = true; + } + } + """ + } + case StringType => + val sdf = classOf[SimpleDateFormat].getName + nullSafeCodeGen(ctx, ev, (string, format) => { + s""" + try { + ${ev.primitive} = + (new $sdf($format.toString())).parse($string.toString()).getTime() / 1000L; + } catch (java.lang.Throwable e) { + ${ev.isNull} = true; + } + """ + }) + case TimestampType => + val eval1 = left.gen(ctx) + s""" + ${eval1.code} + boolean ${ev.isNull} = ${eval1.isNull}; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${ev.primitive} = ${eval1.primitive} / 1000000L; + } + """ + case DateType => + val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + val eval1 = left.gen(ctx) + s""" + ${eval1.code} + boolean ${ev.isNull} = ${eval1.isNull}; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${ev.primitive} = $dtu.daysToMillis(${eval1.primitive}) / 1000L; + } + """ + } + } +} + +/** + * Converts the number of seconds from unix epoch (1970-01-01 00:00:00 UTC) to a string + * representing the timestamp of that moment in the current system time zone in the given + * format. If the format is missing, using format like "1970-01-01 00:00:00". + * Note that hive Language Manual says it returns 0 if fail, but in fact it returns null. + */ +case class FromUnixTime(sec: Expression, format: Expression) + extends BinaryExpression with ImplicitCastInputTypes { + + override def left: Expression = sec + override def right: Expression = format + + def this(unix: Expression) = { + this(unix, Literal("yyyy-MM-dd HH:mm:ss")) + } + + override def dataType: DataType = StringType + + override def inputTypes: Seq[AbstractDataType] = Seq(LongType, StringType) + + private lazy val constFormat: UTF8String = right.eval().asInstanceOf[UTF8String] + + override def eval(input: InternalRow): Any = { + val time = left.eval(input) + if (time == null) { + null + } else { + if (format.foldable) { + if (constFormat == null) { + null + } else { + Try(UTF8String.fromString(new SimpleDateFormat(constFormat.toString).format( + new java.util.Date(time.asInstanceOf[Long] * 1000L)))).getOrElse(null) + } + } else { + val f = format.eval(input) + if (f == null) { + null + } else { + Try(UTF8String.fromString(new SimpleDateFormat( + f.asInstanceOf[UTF8String].toString).format(new java.util.Date( + time.asInstanceOf[Long] * 1000L)))).getOrElse(null) + } + } + } + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val sdf = classOf[SimpleDateFormat].getName + if (format.foldable) { + if (constFormat == null) { + s""" + boolean ${ev.isNull} = true; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + """ + } else { + val t = left.gen(ctx) + s""" + ${t.code} + boolean ${ev.isNull} = ${t.isNull}; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + try { + ${ev.primitive} = UTF8String.fromString(new $sdf("${constFormat.toString}").format( + new java.util.Date(${t.primitive} * 1000L))); + } catch (java.lang.Throwable e) { + ${ev.isNull} = true; + } + } + """ + } + } else { + nullSafeCodeGen(ctx, ev, (seconds, f) => { + s""" + try { + ${ev.primitive} = UTF8String.fromString((new $sdf($f.toString())).format( + new java.util.Date($seconds * 1000L))); + } catch (java.lang.Throwable e) { + ${ev.isNull} = true; + }""".stripMargin + }) + } + } + +} + /** * Returns the last day of the month which the date belongs to. 
*/ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala index aca8d6eb3500c..e1387f945ffa4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala @@ -22,8 +22,9 @@ import java.text.SimpleDateFormat import java.util.Calendar import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.types.{StringType, TimestampType, DateType} +import org.apache.spark.sql.types._ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -303,4 +304,60 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation( NextDay(Literal(Date.valueOf("2015-07-23")), Literal.create(null, StringType)), null) } + + test("from_unixtime") { + val sdf1 = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss") + val fmt2 = "yyyy-MM-dd HH:mm:ss.SSS" + val sdf2 = new SimpleDateFormat(fmt2) + checkEvaluation( + FromUnixTime(Literal(0L), Literal("yyyy-MM-dd HH:mm:ss")), sdf1.format(new Timestamp(0))) + checkEvaluation(FromUnixTime( + Literal(1000L), Literal("yyyy-MM-dd HH:mm:ss")), sdf1.format(new Timestamp(1000000))) + checkEvaluation( + FromUnixTime(Literal(-1000L), Literal(fmt2)), sdf2.format(new Timestamp(-1000000))) + checkEvaluation( + FromUnixTime(Literal.create(null, LongType), Literal.create(null, StringType)), null) + checkEvaluation( + FromUnixTime(Literal.create(null, LongType), Literal("yyyy-MM-dd HH:mm:ss")), null) + checkEvaluation(FromUnixTime(Literal(1000L), Literal.create(null, StringType)), null) + checkEvaluation( + FromUnixTime(Literal(0L), Literal("not a valid format")), null) + } + + test("unix_timestamp") { + val sdf1 = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss") + val fmt2 = "yyyy-MM-dd HH:mm:ss.SSS" + val sdf2 = new SimpleDateFormat(fmt2) + val fmt3 = "yy-MM-dd" + val sdf3 = new SimpleDateFormat(fmt3) + val date1 = Date.valueOf("2015-07-24") + checkEvaluation( + UnixTimestamp(Literal(sdf1.format(new Timestamp(0))), Literal("yyyy-MM-dd HH:mm:ss")), 0L) + checkEvaluation(UnixTimestamp( + Literal(sdf1.format(new Timestamp(1000000))), Literal("yyyy-MM-dd HH:mm:ss")), 1000L) + checkEvaluation( + UnixTimestamp(Literal(new Timestamp(1000000)), Literal("yyyy-MM-dd HH:mm:ss")), 1000L) + checkEvaluation( + UnixTimestamp(Literal(date1), Literal("yyyy-MM-dd HH:mm:ss")), + DateTimeUtils.daysToMillis(DateTimeUtils.fromJavaDate(date1)) / 1000L) + checkEvaluation( + UnixTimestamp(Literal(sdf2.format(new Timestamp(-1000000))), Literal(fmt2)), -1000L) + checkEvaluation(UnixTimestamp( + Literal(sdf3.format(Date.valueOf("2015-07-24"))), Literal(fmt3)), + DateTimeUtils.daysToMillis(DateTimeUtils.fromJavaDate(Date.valueOf("2015-07-24"))) / 1000L) + val t1 = UnixTimestamp( + CurrentTimestamp(), Literal("yyyy-MM-dd HH:mm:ss")).eval().asInstanceOf[Long] + val t2 = UnixTimestamp( + CurrentTimestamp(), Literal("yyyy-MM-dd HH:mm:ss")).eval().asInstanceOf[Long] + assert(t2 - t1 <= 1) + checkEvaluation( + UnixTimestamp(Literal.create(null, DateType), Literal.create(null, StringType)), null) + checkEvaluation( + UnixTimestamp(Literal.create(null, DateType), Literal("yyyy-MM-dd HH:mm:ss")), null) + checkEvaluation(UnixTimestamp( + Literal(date1), Literal.create(null, 
StringType)), date1.getTime / 1000L) + checkEvaluation( + UnixTimestamp(Literal("2015-07-24"), Literal("not a valid format")), null) + } + } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index a2fece62f61f9..3f440e062eb96 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2110,6 +2110,48 @@ object functions { */ def weekofyear(columnName: String): Column = weekofyear(Column(columnName)) + /** + * Converts the number of seconds from unix epoch (1970-01-01 00:00:00 UTC) to a string + * representing the timestamp of that moment in the current system time zone in the given + * format. + * @group datetime_funcs + * @since 1.5.0 + */ + def from_unixtime(ut: Column): Column = FromUnixTime(ut.expr, Literal("yyyy-MM-dd HH:mm:ss")) + + /** + * Converts the number of seconds from unix epoch (1970-01-01 00:00:00 UTC) to a string + * representing the timestamp of that moment in the current system time zone in the given + * format. + * @group datetime_funcs + * @since 1.5.0 + */ + def from_unixtime(ut: Column, f: String): Column = FromUnixTime(ut.expr, Literal(f)) + + /** + * Gets current Unix timestamp in seconds. + * @group datetime_funcs + * @since 1.5.0 + */ + def unix_timestamp(): Column = UnixTimestamp(CurrentTimestamp(), Literal("yyyy-MM-dd HH:mm:ss")) + + /** + * Converts time string in format yyyy-MM-dd HH:mm:ss to Unix timestamp (in seconds), + * using the default timezone and the default locale, return null if fail. + * @group datetime_funcs + * @since 1.5.0 + */ + def unix_timestamp(s: Column): Column = UnixTimestamp(s.expr, Literal("yyyy-MM-dd HH:mm:ss")) + + /** + * Convert time string with given pattern + * (see [http://docs.oracle.com/javase/tutorial/i18n/format/simpleDateFormat.html]) + * to Unix time stamp (in seconds), return null if fail. 
+ * @group datetime_funcs + * @since 1.5.0 + */ + def unix_timestamp(s: Column, p: String): Column = UnixTimestamp(s.expr, Literal(p)) + ////////////////////////////////////////////////////////////////////////////////////////////// // Collection functions ////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala index 07eb6e4a8d8cd..df4cb57ac5b21 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala @@ -228,4 +228,60 @@ class DateFunctionsSuite extends QueryTest { Seq(Row(Date.valueOf("2015-07-30")), Row(Date.valueOf("2015-07-30")))) } + test("from_unixtime") { + val sdf1 = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss") + val fmt2 = "yyyy-MM-dd HH:mm:ss.SSS" + val sdf2 = new SimpleDateFormat(fmt2) + val fmt3 = "yy-MM-dd HH-mm-ss" + val sdf3 = new SimpleDateFormat(fmt3) + val df = Seq((1000, "yyyy-MM-dd HH:mm:ss.SSS"), (-1000, "yy-MM-dd HH-mm-ss")).toDF("a", "b") + checkAnswer( + df.select(from_unixtime(col("a"))), + Seq(Row(sdf1.format(new Timestamp(1000000))), Row(sdf1.format(new Timestamp(-1000000))))) + checkAnswer( + df.select(from_unixtime(col("a"), fmt2)), + Seq(Row(sdf2.format(new Timestamp(1000000))), Row(sdf2.format(new Timestamp(-1000000))))) + checkAnswer( + df.select(from_unixtime(col("a"), fmt3)), + Seq(Row(sdf3.format(new Timestamp(1000000))), Row(sdf3.format(new Timestamp(-1000000))))) + checkAnswer( + df.selectExpr("from_unixtime(a)"), + Seq(Row(sdf1.format(new Timestamp(1000000))), Row(sdf1.format(new Timestamp(-1000000))))) + checkAnswer( + df.selectExpr(s"from_unixtime(a, '$fmt2')"), + Seq(Row(sdf2.format(new Timestamp(1000000))), Row(sdf2.format(new Timestamp(-1000000))))) + checkAnswer( + df.selectExpr(s"from_unixtime(a, '$fmt3')"), + Seq(Row(sdf3.format(new Timestamp(1000000))), Row(sdf3.format(new Timestamp(-1000000))))) + } + + test("unix_timestamp") { + val date1 = Date.valueOf("2015-07-24") + val date2 = Date.valueOf("2015-07-25") + val ts1 = Timestamp.valueOf("2015-07-24 10:00:00.3") + val ts2 = Timestamp.valueOf("2015-07-25 02:02:02.2") + val s1 = "2015/07/24 10:00:00.5" + val s2 = "2015/07/25 02:02:02.6" + val ss1 = "2015-07-24 10:00:00" + val ss2 = "2015-07-25 02:02:02" + val fmt = "yyyy/MM/dd HH:mm:ss.S" + val df = Seq((date1, ts1, s1, ss1), (date2, ts2, s2, ss2)).toDF("d", "ts", "s", "ss") + checkAnswer(df.select(unix_timestamp(col("ts"))), Seq( + Row(ts1.getTime / 1000L), Row(ts2.getTime / 1000L))) + checkAnswer(df.select(unix_timestamp(col("ss"))), Seq( + Row(ts1.getTime / 1000L), Row(ts2.getTime / 1000L))) + checkAnswer(df.select(unix_timestamp(col("d"), fmt)), Seq( + Row(date1.getTime / 1000L), Row(date2.getTime / 1000L))) + checkAnswer(df.select(unix_timestamp(col("s"), fmt)), Seq( + Row(ts1.getTime / 1000L), Row(ts2.getTime / 1000L))) + checkAnswer(df.selectExpr("unix_timestamp(ts)"), Seq( + Row(ts1.getTime / 1000L), Row(ts2.getTime / 1000L))) + checkAnswer(df.selectExpr("unix_timestamp(ss)"), Seq( + Row(ts1.getTime / 1000L), Row(ts2.getTime / 1000L))) + checkAnswer(df.selectExpr(s"unix_timestamp(d, '$fmt')"), Seq( + Row(date1.getTime / 1000L), Row(date2.getTime / 1000L))) + checkAnswer(df.selectExpr(s"unix_timestamp(s, '$fmt')"), Seq( + Row(ts1.getTime / 1000L), Row(ts2.getTime / 1000L))) + } + } From a20e743fb863de809863652931bc982aac2d1f86 Mon Sep 17 00:00:00 
2001 From: Reynold Xin Date: Thu, 30 Jul 2015 13:09:43 -0700 Subject: [PATCH 33/50] [SPARK-9460] Fix prefix generation for UTF8String. Previously we could be getting garbage data if the number of bytes is 0, or on JVMs that are 4 byte aligned, or when compressedoops is on. Author: Reynold Xin Closes #7789 from rxin/utf8string and squashes the following commits: 86ffa3e [Reynold Xin] Mask out data outside of valid range. 4d647ed [Reynold Xin] Mask out data. c6e8794 [Reynold Xin] [SPARK-9460] Fix prefix generation for UTF8String. --- .../apache/spark/unsafe/types/UTF8String.java | 36 +++++++++++++++++-- .../spark/unsafe/types/UTF8StringSuite.java | 8 +++++ 2 files changed, 41 insertions(+), 3 deletions(-) diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index 57522003ba2ba..c38953f65d7d7 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -65,6 +65,19 @@ public static UTF8String fromBytes(byte[] bytes) { } } + /** + * Creates an UTF8String from byte array, which should be encoded in UTF-8. + * + * Note: `bytes` will be hold by returned UTF8String. + */ + public static UTF8String fromBytes(byte[] bytes, int offset, int numBytes) { + if (bytes != null) { + return new UTF8String(bytes, BYTE_ARRAY_OFFSET + offset, numBytes); + } else { + return null; + } + } + /** * Creates an UTF8String from String. */ @@ -89,10 +102,10 @@ public static UTF8String blankString(int length) { return fromBytes(spaces); } - protected UTF8String(Object base, long offset, int size) { + protected UTF8String(Object base, long offset, int numBytes) { this.base = base; this.offset = offset; - this.numBytes = size; + this.numBytes = numBytes; } /** @@ -141,7 +154,24 @@ public int numChars() { * Returns a 64-bit integer that can be used as the prefix used in sorting. */ public long getPrefix() { - long p = PlatformDependent.UNSAFE.getLong(base, offset); + // Since JVMs are either 4-byte aligned or 8-byte aligned, we check the size of the string. + // If size is 0, just return 0. + // If size is between 0 and 4 (inclusive), assume data is 4-byte aligned under the hood and + // use a getInt to fetch the prefix. + // If size is greater than 4, assume we have at least 8 bytes of data to fetch. + // After getting the data, we use a mask to mask out data that is not part of the string. 
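+    // Illustrative example (not part of the original comment): for numBytes = 3 the mask is
+    // (1L << 3 * 8) - 1 = 0xFFFFFF, so the bytes read beyond the string are zeroed out before
+    // the byte order is reversed below.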
+ long p; + if (numBytes >= 8) { + p = PlatformDependent.UNSAFE.getLong(base, offset); + } else if (numBytes > 4) { + p = PlatformDependent.UNSAFE.getLong(base, offset); + p = p & ((1L << numBytes * 8) - 1); + } else if (numBytes > 0) { + p = (long) PlatformDependent.UNSAFE.getInt(base, offset); + p = p & ((1L << numBytes * 8) - 1); + } else { + p = 0; + } p = java.lang.Long.reverseBytes(p); return p; } diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java index 42e09e435a412..f2cc19ca6b172 100644 --- a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java +++ b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java @@ -71,6 +71,14 @@ public void prefix() { fromString("abbbbbbbbbbbasdf").getPrefix() - fromString("bbbbbbbbbbbbasdf").getPrefix() < 0); assertTrue(fromString("").getPrefix() - fromString("a").getPrefix() < 0); assertTrue(fromString("你好").getPrefix() - fromString("世界").getPrefix() > 0); + + byte[] buf1 = {1, 2, 3, 4, 5, 6, 7, 8, 9}; + byte[] buf2 = {1, 2, 3}; + UTF8String str1 = UTF8String.fromBytes(buf1, 0, 3); + UTF8String str2 = UTF8String.fromBytes(buf1, 0, 8); + UTF8String str3 = UTF8String.fromBytes(buf2); + assertTrue(str1.getPrefix() - str2.getPrefix() < 0); + assertEquals(str1.getPrefix(), str3.getPrefix()); } @Test From d8cfd531c7c50c9b00ab546be458f44f84c386ae Mon Sep 17 00:00:00 2001 From: Feynman Liang Date: Thu, 30 Jul 2015 13:17:54 -0700 Subject: [PATCH 34/50] [SPARK-5567] [MLLIB] Add predict method to LocalLDAModel jkbradley hhbyyh Adds `topicDistributions` to LocalLDAModel. Please review after #7757 is merged. Author: Feynman Liang Closes #7760 from feynmanliang/SPARK-5567-predict-in-LDA and squashes the following commits: 0ad1134 [Feynman Liang] Remove println 27b3877 [Feynman Liang] Code review fixes 6bfb87c [Feynman Liang] Remove extra newline 476f788 [Feynman Liang] Fix checks and doc for variationalInference 061780c [Feynman Liang] Code review cleanup 3be2947 [Feynman Liang] Rename topicDistribution -> topicDistributions 2a821a6 [Feynman Liang] Add predict methods to LocalLDAModel --- .../spark/mllib/clustering/LDAModel.scala | 42 +++++++++++-- .../spark/mllib/clustering/LDAOptimizer.scala | 5 +- .../spark/mllib/clustering/LDASuite.scala | 63 +++++++++++++++++++ 3 files changed, 102 insertions(+), 8 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala index ece28848aa02c..6cfad3fbbdb87 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala @@ -186,7 +186,6 @@ abstract class LDAModel private[clustering] extends Saveable { * This model stores only the inferred topics. * It may be used for computing topics for new documents, but it may give less accurate answers * than the [[DistributedLDAModel]]. - * * @param topics Inferred topics (vocabSize x k matrix). */ @Experimental @@ -221,9 +220,6 @@ class LocalLDAModel private[clustering] ( // TODO // override def logLikelihood(documents: RDD[(Long, Vector)]): Double = ??? - // TODO: - // override def topicDistributions(documents: RDD[(Long, Vector)]): RDD[(Long, Vector)] = ??? - /** * Calculate the log variational bound on perplexity. See Equation (16) in original Online * LDA paper. 
@@ -269,7 +265,7 @@ class LocalLDAModel private[clustering] ( // by topic (columns of lambda) val Elogbeta = LDAUtils.dirichletExpectation(lambda.t).t - var score = documents.filter(_._2.numActives > 0).map { case (id: Long, termCounts: Vector) => + var score = documents.filter(_._2.numNonzeros > 0).map { case (id: Long, termCounts: Vector) => var docScore = 0.0D val (gammad: BDV[Double], _) = OnlineLDAOptimizer.variationalTopicInference( termCounts, exp(Elogbeta), brzAlpha, gammaShape, k) @@ -277,7 +273,7 @@ class LocalLDAModel private[clustering] ( // E[log p(doc | theta, beta)] termCounts.foreachActive { case (idx, count) => - docScore += LDAUtils.logSumExp(Elogthetad + Elogbeta(idx, ::).t) + docScore += count * LDAUtils.logSumExp(Elogthetad + Elogbeta(idx, ::).t) } // E[log p(theta | alpha) - log q(theta | gamma)]; assumes alpha is a vector docScore += sum((brzAlpha - gammad) :* Elogthetad) @@ -297,6 +293,40 @@ class LocalLDAModel private[clustering] ( score } + /** + * Predicts the topic mixture distribution for each document (often called "theta" in the + * literature). Returns a vector of zeros for an empty document. + * + * This uses a variational approximation following Hoffman et al. (2010), where the approximate + * distribution is called "gamma." Technically, this method returns this approximation "gamma" + * for each document. + * @param documents documents to predict topic mixture distributions for + * @return An RDD of (document ID, topic mixture distribution for document) + */ + // TODO: declare in LDAModel and override once implemented in DistributedLDAModel + def topicDistributions(documents: RDD[(Long, Vector)]): RDD[(Long, Vector)] = { + // Double transpose because dirichletExpectation normalizes by row and we need to normalize + // by topic (columns of lambda) + val expElogbeta = exp(LDAUtils.dirichletExpectation(topicsMatrix.toBreeze.toDenseMatrix.t).t) + val docConcentrationBrz = this.docConcentration.toBreeze + val gammaShape = this.gammaShape + val k = this.k + + documents.map { case (id: Long, termCounts: Vector) => + if (termCounts.numNonzeros == 0) { + (id, Vectors.zeros(k)) + } else { + val (gamma, _) = OnlineLDAOptimizer.variationalTopicInference( + termCounts, + expElogbeta, + docConcentrationBrz, + gammaShape, + k) + (id, Vectors.dense(normalize(gamma, 1.0).toArray)) + } + } + } + } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala index 4b90fbdf0ce7e..9dbec41efeada 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala @@ -394,7 +394,7 @@ final class OnlineLDAOptimizer extends LDAOptimizer { val gammaShape = this.gammaShape val stats: RDD[(BDM[Double], List[BDV[Double]])] = batch.mapPartitions { docs => - val nonEmptyDocs = docs.filter(_._2.numActives > 0) + val nonEmptyDocs = docs.filter(_._2.numNonzeros > 0) val stat = BDM.zeros[Double](k, vocabSize) var gammaPart = List[BDV[Double]]() @@ -461,7 +461,8 @@ final class OnlineLDAOptimizer extends LDAOptimizer { private[clustering] object OnlineLDAOptimizer { /** * Uses variational inference to infer the topic distribution `gammad` given the term counts - * for a document. `termCounts` must be non-empty, otherwise Breeze will throw a BLAS error. + * for a document. `termCounts` must contain at least one non-zero entry, otherwise Breeze will + * throw a BLAS error. 
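+   * (Callers in this patch guard against that: both the online optimizer and
+   * LocalLDAModel.topicDistributions skip documents whose numNonzeros is 0.)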
* * An optimization (Lee, Seung: Algorithms for non-negative matrix factorization, NIPS 2001) * avoids explicit computation of variational parameter `phi`. diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala index 61d2edfd9fb5f..d74482d3a7598 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala @@ -242,6 +242,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext { val alpha = 0.01 val eta = 0.01 val gammaShape = 100 + // obtained from LDA model trained in gensim, see below val topics = new DenseMatrix(numRows = vocabSize, numCols = k, values = Array( 1.86738052, 1.94056535, 1.89981687, 0.0833265, 0.07405918, 0.07940597, 0.15081551, 0.08637973, 0.12428538, 1.9474897, 1.94615165, 1.95204124)) @@ -281,6 +282,68 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext { assert(ldaModel.logPerplexity(docs) ~== -3.690D relTol 1E-3D) } + test("LocalLDAModel predict") { + val k = 2 + val vocabSize = 6 + val alpha = 0.01 + val eta = 0.01 + val gammaShape = 100 + // obtained from LDA model trained in gensim, see below + val topics = new DenseMatrix(numRows = vocabSize, numCols = k, values = Array( + 1.86738052, 1.94056535, 1.89981687, 0.0833265, 0.07405918, 0.07940597, + 0.15081551, 0.08637973, 0.12428538, 1.9474897, 1.94615165, 1.95204124)) + + def toydata: Array[(Long, Vector)] = Array( + Vectors.sparse(6, Array(0, 1), Array(1, 1)), + Vectors.sparse(6, Array(1, 2), Array(1, 1)), + Vectors.sparse(6, Array(0, 2), Array(1, 1)), + Vectors.sparse(6, Array(3, 4), Array(1, 1)), + Vectors.sparse(6, Array(3, 5), Array(1, 1)), + Vectors.sparse(6, Array(4, 5), Array(1, 1)) + ).zipWithIndex.map { case (wordCounts, docId) => (docId.toLong, wordCounts) } + val docs = sc.parallelize(toydata) + + val ldaModel: LocalLDAModel = new LocalLDAModel( + topics, Vectors.dense(Array.fill(k)(alpha)), eta, gammaShape) + + /* Verify results using gensim: + import numpy as np + from gensim import models + corpus = [ + [(0, 1.0), (1, 1.0)], + [(1, 1.0), (2, 1.0)], + [(0, 1.0), (2, 1.0)], + [(3, 1.0), (4, 1.0)], + [(3, 1.0), (5, 1.0)], + [(4, 1.0), (5, 1.0)]] + np.random.seed(2345) + lda = models.ldamodel.LdaModel( + corpus=corpus, alpha=0.01, eta=0.01, num_topics=2, update_every=0, passes=100, + decay=0.51, offset=1024) + print(list(lda.get_document_topics(corpus))) + > [[(0, 0.99504950495049516)], [(0, 0.99504950495049516)], + > [(0, 0.99504950495049516)], [(1, 0.99504950495049516)], + > [(1, 0.99504950495049516)], [(1, 0.99504950495049516)]] + */ + + val expectedPredictions = List( + (0, 0.99504), (0, 0.99504), + (0, 0.99504), (1, 0.99504), + (1, 0.99504), (1, 0.99504)) + + val actualPredictions = ldaModel.topicDistributions(docs).map { case (id, topics) => + // convert results to expectedPredictions format, which only has highest probability topic + val topicsBz = topics.toBreeze.toDenseVector + (id, (argmax(topicsBz), max(topicsBz))) + }.sortByKey() + .values + .collect() + + expectedPredictions.zip(actualPredictions).forall { case (expected, actual) => + expected._1 === actual._1 && (expected._2 ~== actual._2 relTol 1E-3D) + } + } + test("OnlineLDAOptimizer with asymmetric prior") { def toydata: Array[(Long, Vector)] = Array( Vectors.sparse(6, Array(0, 1), Array(1, 1)), From 1abf7dc16ca1ba1777fe874c8b81fe6f2b0a6de5 Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Thu, 30 Jul 2015 13:21:46 
-0700 Subject: [PATCH 35/50] [SPARK-8186] [SPARK-8187] [SPARK-8194] [SPARK-8198] [SPARK-9133] [SPARK-9290] [SQL] functions: date_add, date_sub, add_months, months_between, time-interval calculation This PR is based on #7589 , thanks to adrian-wang Added SQL function date_add, date_sub, add_months, month_between, also add a rule for add/subtract of date/timestamp and interval. Closes #7589 cc rxin Author: Daoyuan Wang Author: Davies Liu Closes #7754 from davies/date_add and squashes the following commits: e8c633a [Davies Liu] Merge branch 'master' of github.com:apache/spark into date_add 9e8e085 [Davies Liu] Merge branch 'master' of github.com:apache/spark into date_add 6224ce4 [Davies Liu] fix conclict bd18cd4 [Davies Liu] Merge branch 'master' of github.com:apache/spark into date_add e47ff2c [Davies Liu] add python api, fix date functions 01943d0 [Davies Liu] Merge branch 'master' into date_add 522e91a [Daoyuan Wang] fix e8a639a [Daoyuan Wang] fix 42df486 [Daoyuan Wang] fix style 87c4b77 [Daoyuan Wang] function add_months, months_between and some fixes 1a68e03 [Daoyuan Wang] poc of time interval calculation c506661 [Daoyuan Wang] function date_add , date_sub --- python/pyspark/sql/functions.py | 76 ++++++- .../catalyst/analysis/FunctionRegistry.scala | 4 + .../catalyst/analysis/HiveTypeCoercion.scala | 22 ++ .../expressions/datetimeFunctions.scala | 155 ++++++++++++- .../sql/catalyst/util/DateTimeUtils.scala | 139 ++++++++++++ .../analysis/HiveTypeCoercionSuite.scala | 30 +++ .../expressions/DateExpressionsSuite.scala | 176 +++++++++------ .../catalyst/util/DateTimeUtilsSuite.scala | 205 +++++++++++------- .../org/apache/spark/sql/functions.scala | 29 +++ .../apache/spark/sql/DateFunctionsSuite.scala | 117 ++++++++++ 10 files changed, 791 insertions(+), 162 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index d930f7db25d25..a7295e25f0aa5 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -59,7 +59,7 @@ __all__ += ['lag', 'lead', 'ntile'] __all__ += [ - 'date_format', + 'date_format', 'date_add', 'date_sub', 'add_months', 'months_between', 'year', 'quarter', 'month', 'hour', 'minute', 'second', 'dayofmonth', 'dayofyear', 'weekofyear'] @@ -716,7 +716,7 @@ def date_format(dateCol, format): [Row(date=u'04/08/2015')] """ sc = SparkContext._active_spark_context - return Column(sc._jvm.functions.date_format(dateCol, format)) + return Column(sc._jvm.functions.date_format(_to_java_column(dateCol), format)) @since(1.5) @@ -729,7 +729,7 @@ def year(col): [Row(year=2015)] """ sc = SparkContext._active_spark_context - return Column(sc._jvm.functions.year(col)) + return Column(sc._jvm.functions.year(_to_java_column(col))) @since(1.5) @@ -742,7 +742,7 @@ def quarter(col): [Row(quarter=2)] """ sc = SparkContext._active_spark_context - return Column(sc._jvm.functions.quarter(col)) + return Column(sc._jvm.functions.quarter(_to_java_column(col))) @since(1.5) @@ -755,7 +755,7 @@ def month(col): [Row(month=4)] """ sc = SparkContext._active_spark_context - return Column(sc._jvm.functions.month(col)) + return Column(sc._jvm.functions.month(_to_java_column(col))) @since(1.5) @@ -768,7 +768,7 @@ def dayofmonth(col): [Row(day=8)] """ sc = SparkContext._active_spark_context - return Column(sc._jvm.functions.dayofmonth(col)) + return Column(sc._jvm.functions.dayofmonth(_to_java_column(col))) @since(1.5) @@ -781,7 +781,7 @@ def dayofyear(col): [Row(day=98)] """ sc = SparkContext._active_spark_context - return 
Column(sc._jvm.functions.dayofyear(col)) + return Column(sc._jvm.functions.dayofyear(_to_java_column(col))) @since(1.5) @@ -794,7 +794,7 @@ def hour(col): [Row(hour=13)] """ sc = SparkContext._active_spark_context - return Column(sc._jvm.functions.hour(col)) + return Column(sc._jvm.functions.hour(_to_java_column(col))) @since(1.5) @@ -807,7 +807,7 @@ def minute(col): [Row(minute=8)] """ sc = SparkContext._active_spark_context - return Column(sc._jvm.functions.minute(col)) + return Column(sc._jvm.functions.minute(_to_java_column(col))) @since(1.5) @@ -820,7 +820,7 @@ def second(col): [Row(second=15)] """ sc = SparkContext._active_spark_context - return Column(sc._jvm.functions.second(col)) + return Column(sc._jvm.functions.second(_to_java_column(col))) @since(1.5) @@ -829,11 +829,63 @@ def weekofyear(col): Extract the week number of a given date as integer. >>> df = sqlContext.createDataFrame([('2015-04-08',)], ['a']) - >>> df.select(weekofyear('a').alias('week')).collect() + >>> df.select(weekofyear(df.a).alias('week')).collect() [Row(week=15)] """ sc = SparkContext._active_spark_context - return Column(sc._jvm.functions.weekofyear(col)) + return Column(sc._jvm.functions.weekofyear(_to_java_column(col))) + + +@since(1.5) +def date_add(start, days): + """ + Returns the date that is `days` days after `start` + + >>> df = sqlContext.createDataFrame([('2015-04-08',)], ['d']) + >>> df.select(date_add(df.d, 1).alias('d')).collect() + [Row(d=datetime.date(2015, 4, 9))] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.date_add(_to_java_column(start), days)) + + +@since(1.5) +def date_sub(start, days): + """ + Returns the date that is `days` days before `start` + + >>> df = sqlContext.createDataFrame([('2015-04-08',)], ['d']) + >>> df.select(date_sub(df.d, 1).alias('d')).collect() + [Row(d=datetime.date(2015, 4, 7))] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.date_sub(_to_java_column(start), days)) + + +@since(1.5) +def add_months(start, months): + """ + Returns the date that is `months` months after `start` + + >>> df = sqlContext.createDataFrame([('2015-04-08',)], ['d']) + >>> df.select(add_months(df.d, 1).alias('d')).collect() + [Row(d=datetime.date(2015, 5, 8))] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.add_months(_to_java_column(start), months)) + + +@since(1.5) +def months_between(date1, date2): + """ + Returns the number of months between date1 and date2. 
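+    The fractional part is based on a 31-day month and rounded to 8 digits
+    (see DateTimeUtils.monthsBetween).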
+ + >>> df = sqlContext.createDataFrame([('1997-02-28 10:30:00', '1996-10-30')], ['t', 'd']) + >>> df.select(months_between(df.t, df.d).alias('months')).collect() + [Row(months=3.9495967...)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.months_between(_to_java_column(date1), _to_java_column(date2))) @since(1.5) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index d663f12bc6d0d..6c7c481fab8db 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -205,9 +205,12 @@ object FunctionRegistry { expression[Upper]("upper"), // datetime functions + expression[AddMonths]("add_months"), expression[CurrentDate]("current_date"), expression[CurrentTimestamp]("current_timestamp"), + expression[DateAdd]("date_add"), expression[DateFormatClass]("date_format"), + expression[DateSub]("date_sub"), expression[DayOfMonth]("day"), expression[DayOfYear]("dayofyear"), expression[DayOfMonth]("dayofmonth"), @@ -216,6 +219,7 @@ object FunctionRegistry { expression[LastDay]("last_day"), expression[Minute]("minute"), expression[Month]("month"), + expression[MonthsBetween]("months_between"), expression[NextDay]("next_day"), expression[Quarter]("quarter"), expression[Second]("second"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index ecc48986e35d8..603afc4032a37 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -47,6 +47,7 @@ object HiveTypeCoercion { Division :: PropagateTypes :: ImplicitTypeCasts :: + DateTimeOperations :: Nil // See https://cwiki.apache.org/confluence/display/Hive/LanguageManual+Types. @@ -638,6 +639,27 @@ object HiveTypeCoercion { } } + /** + * Turns Add/Subtract of DateType/TimestampType/StringType and CalendarIntervalType + * to TimeAdd/TimeSub + */ + object DateTimeOperations extends Rule[LogicalPlan] { + + private val acceptedTypes = Seq(DateType, TimestampType, StringType) + + def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + // Skip nodes who's children have not been resolved yet. + case e if !e.childrenResolved => e + + case Add(l @ CalendarIntervalType(), r) if acceptedTypes.contains(r.dataType) => + Cast(TimeAdd(r, l), r.dataType) + case Add(l, r @ CalendarIntervalType()) if acceptedTypes.contains(l.dataType) => + Cast(TimeAdd(l, r), l.dataType) + case Subtract(l, r @ CalendarIntervalType()) if acceptedTypes.contains(l.dataType) => + Cast(TimeSub(l, r), l.dataType) + } + } + /** * Casts types according to the expected input types for [[Expression]]s. 
*/ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala index a5e6249e438d2..9795673ee0664 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} import scala.util.Try @@ -63,6 +63,53 @@ case class CurrentTimestamp() extends LeafExpression with CodegenFallback { } } +/** + * Adds a number of days to startdate. + */ +case class DateAdd(startDate: Expression, days: Expression) + extends BinaryExpression with ImplicitCastInputTypes { + + override def left: Expression = startDate + override def right: Expression = days + + override def inputTypes: Seq[AbstractDataType] = Seq(DateType, IntegerType) + + override def dataType: DataType = DateType + + override def nullSafeEval(start: Any, d: Any): Any = { + start.asInstanceOf[Int] + d.asInstanceOf[Int] + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + nullSafeCodeGen(ctx, ev, (sd, d) => { + s"""${ev.primitive} = $sd + $d;""" + }) + } +} + +/** + * Subtracts a number of days to startdate. + */ +case class DateSub(startDate: Expression, days: Expression) + extends BinaryExpression with ImplicitCastInputTypes { + override def left: Expression = startDate + override def right: Expression = days + + override def inputTypes: Seq[AbstractDataType] = Seq(DateType, IntegerType) + + override def dataType: DataType = DateType + + override def nullSafeEval(start: Any, d: Any): Any = { + start.asInstanceOf[Int] - d.asInstanceOf[Int] + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + nullSafeCodeGen(ctx, ev, (sd, d) => { + s"""${ev.primitive} = $sd - $d;""" + }) + } +} + case class Hour(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType) @@ -543,3 +590,109 @@ case class NextDay(startDate: Expression, dayOfWeek: Expression) override def prettyName: String = "next_day" } + +/** + * Adds an interval to timestamp. + */ +case class TimeAdd(start: Expression, interval: Expression) + extends BinaryExpression with ImplicitCastInputTypes { + + override def left: Expression = start + override def right: Expression = interval + + override def toString: String = s"$left + $right" + override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, CalendarIntervalType) + + override def dataType: DataType = TimestampType + + override def nullSafeEval(start: Any, interval: Any): Any = { + val itvl = interval.asInstanceOf[CalendarInterval] + DateTimeUtils.timestampAddInterval( + start.asInstanceOf[Long], itvl.months, itvl.microseconds) + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + defineCodeGen(ctx, ev, (sd, i) => { + s"""$dtu.timestampAddInterval($sd, $i.months, $i.microseconds)""" + }) + } +} + +/** + * Subtracts an interval from timestamp. 
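+ * For example, subtracting an interval of 1 month from timestamp '2016-03-31 10:00:00'
+ * yields '2016-02-29 10:00:00' (see the time_sub case in DateExpressionsSuite).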
+ */ +case class TimeSub(start: Expression, interval: Expression) + extends BinaryExpression with ImplicitCastInputTypes { + + override def left: Expression = start + override def right: Expression = interval + + override def toString: String = s"$left - $right" + override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, CalendarIntervalType) + + override def dataType: DataType = TimestampType + + override def nullSafeEval(start: Any, interval: Any): Any = { + val itvl = interval.asInstanceOf[CalendarInterval] + DateTimeUtils.timestampAddInterval( + start.asInstanceOf[Long], 0 - itvl.months, 0 - itvl.microseconds) + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + defineCodeGen(ctx, ev, (sd, i) => { + s"""$dtu.timestampAddInterval($sd, 0 - $i.months, 0 - $i.microseconds)""" + }) + } +} + +/** + * Returns the date that is num_months after start_date. + */ +case class AddMonths(startDate: Expression, numMonths: Expression) + extends BinaryExpression with ImplicitCastInputTypes { + + override def left: Expression = startDate + override def right: Expression = numMonths + + override def inputTypes: Seq[AbstractDataType] = Seq(DateType, IntegerType) + + override def dataType: DataType = DateType + + override def nullSafeEval(start: Any, months: Any): Any = { + DateTimeUtils.dateAddMonths(start.asInstanceOf[Int], months.asInstanceOf[Int]) + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + defineCodeGen(ctx, ev, (sd, m) => { + s"""$dtu.dateAddMonths($sd, $m)""" + }) + } +} + +/** + * Returns number of months between dates date1 and date2. + */ +case class MonthsBetween(date1: Expression, date2: Expression) + extends BinaryExpression with ImplicitCastInputTypes { + + override def left: Expression = date1 + override def right: Expression = date2 + + override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, TimestampType) + + override def dataType: DataType = DoubleType + + override def nullSafeEval(t1: Any, t2: Any): Any = { + DateTimeUtils.monthsBetween(t1.asInstanceOf[Long], t2.asInstanceOf[Long]) + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + defineCodeGen(ctx, ev, (l, r) => { + s"""$dtu.monthsBetween($l, $r)""" + }) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index 93966a503c27c..53abdf6618eac 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -45,6 +45,7 @@ object DateTimeUtils { final val to2001 = -11323 // this is year -17999, calculation: 50 * daysIn400Year + final val YearZero = -17999 final val toYearZero = to2001 + 7304850 @transient lazy val defaultTimeZone = TimeZone.getDefault @@ -575,6 +576,144 @@ object DateTimeUtils { } /** + * The number of days for each month (not leap year) + */ + private val monthDays = Array(31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31) + + /** + * Returns the date value for the first day of the given month. + * The month is expressed in months since year zero (17999 BC), starting from 0. 
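+   * (Equivalently, absoluteMonth = (year + 17999) * 12 + (month - 1), matching the
+   * computation used in dateAddMonths below.)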
+ */ + private def firstDayOfMonth(absoluteMonth: Int): Int = { + val absoluteYear = absoluteMonth / 12 + var monthInYear = absoluteMonth - absoluteYear * 12 + var date = getDateFromYear(absoluteYear) + if (monthInYear >= 2 && isLeapYear(absoluteYear + YearZero)) { + date += 1 + } + while (monthInYear > 0) { + date += monthDays(monthInYear - 1) + monthInYear -= 1 + } + date + } + + /** + * Returns the date value for January 1 of the given year. + * The year is expressed in years since year zero (17999 BC), starting from 0. + */ + private def getDateFromYear(absoluteYear: Int): Int = { + val absoluteDays = (absoluteYear * 365 + absoluteYear / 400 - absoluteYear / 100 + + absoluteYear / 4) + absoluteDays - toYearZero + } + + /** + * Add date and year-month interval. + * Returns a date value, expressed in days since 1.1.1970. + */ + def dateAddMonths(days: Int, months: Int): Int = { + val absoluteMonth = (getYear(days) - YearZero) * 12 + getMonth(days) - 1 + months + val currentMonthInYear = absoluteMonth % 12 + val currentYear = absoluteMonth / 12 + val leapDay = if (currentMonthInYear == 1 && isLeapYear(currentYear + YearZero)) 1 else 0 + val lastDayOfMonth = monthDays(currentMonthInYear) + leapDay + + val dayOfMonth = getDayOfMonth(days) + val currentDayInMonth = if (getDayOfMonth(days + 1) == 1 || dayOfMonth >= lastDayOfMonth) { + // last day of the month + lastDayOfMonth + } else { + dayOfMonth + } + firstDayOfMonth(absoluteMonth) + currentDayInMonth - 1 + } + + /** + * Add timestamp and full interval. + * Returns a timestamp value, expressed in microseconds since 1.1.1970 00:00:00. + */ + def timestampAddInterval(start: Long, months: Int, microseconds: Long): Long = { + val days = millisToDays(start / 1000L) + val newDays = dateAddMonths(days, months) + daysToMillis(newDays) * 1000L + start - daysToMillis(days) * 1000L + microseconds + } + + /** + * Returns the last dayInMonth in the month it belongs to. The date is expressed + * in days since 1.1.1970. the return value starts from 1. + */ + private def getLastDayInMonthOfMonth(date: Int): Int = { + var (year, dayInYear) = getYearAndDayInYear(date) + if (isLeapYear(year)) { + if (dayInYear > 31 && dayInYear <= 60) { + return 29 + } else if (dayInYear > 60) { + dayInYear = dayInYear - 1 + } + } + if (dayInYear <= 31) { + 31 + } else if (dayInYear <= 59) { + 28 + } else if (dayInYear <= 90) { + 31 + } else if (dayInYear <= 120) { + 30 + } else if (dayInYear <= 151) { + 31 + } else if (dayInYear <= 181) { + 30 + } else if (dayInYear <= 212) { + 31 + } else if (dayInYear <= 243) { + 31 + } else if (dayInYear <= 273) { + 30 + } else if (dayInYear <= 304) { + 31 + } else if (dayInYear <= 334) { + 30 + } else { + 31 + } + } + + /** + * Returns number of months between time1 and time2. time1 and time2 are expressed in + * microseconds since 1.1.1970. + * + * If time1 and time2 having the same day of month, or both are the last day of month, + * it returns an integer (time under a day will be ignored). + * + * Otherwise, the difference is calculated based on 31 days per month, and rounding to + * 8 digits. 
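+   * For example (worked through, matching DateExpressionsSuite):
+   * months_between('1997-02-28 10:30:00', '1996-10-30 00:00:00')
+   *   = 4 + (28 - 30 + 10.5 / 24) / 31 = 3.94959677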
+ */ + def monthsBetween(time1: Long, time2: Long): Double = { + val millis1 = time1 / 1000L + val millis2 = time2 / 1000L + val date1 = millisToDays(millis1) + val date2 = millisToDays(millis2) + // TODO(davies): get year, month, dayOfMonth from single function + val dayInMonth1 = getDayOfMonth(date1) + val dayInMonth2 = getDayOfMonth(date2) + val months1 = getYear(date1) * 12 + getMonth(date1) + val months2 = getYear(date2) * 12 + getMonth(date2) + + if (dayInMonth1 == dayInMonth2 || (dayInMonth1 == getLastDayInMonthOfMonth(date1) + && dayInMonth2 == getLastDayInMonthOfMonth(date2))) { + return (months1 - months2).toDouble + } + // milliseconds is enough for 8 digits precision on the right side + val timeInDay1 = millis1 - daysToMillis(date1) + val timeInDay2 = millis2 - daysToMillis(date2) + val timesBetween = (timeInDay1 - timeInDay2).toDouble / MILLIS_PER_DAY + val diff = (months1 - months2).toDouble + (dayInMonth1 - dayInMonth2 + timesBetween) / 31.0 + // rounding to 8 digits + math.round(diff * 1e8) / 1e8 + } + + /* * Returns day of week from String. Starting from Thursday, marked as 0. * (Because 1970-01-01 is Thursday). */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala index 1d9ee5ddf3a5a..70608771dd110 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala @@ -17,12 +17,15 @@ package org.apache.spark.sql.catalyst.analysis +import java.sql.Timestamp + import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.CalendarInterval class HiveTypeCoercionSuite extends PlanTest { @@ -400,6 +403,33 @@ class HiveTypeCoercionSuite extends PlanTest { } } + test("rule for date/timestamp operations") { + val dateTimeOperations = HiveTypeCoercion.DateTimeOperations + val date = Literal(new java.sql.Date(0L)) + val timestamp = Literal(new Timestamp(0L)) + val interval = Literal(new CalendarInterval(0, 0)) + val str = Literal("2015-01-01") + + ruleTest(dateTimeOperations, Add(date, interval), Cast(TimeAdd(date, interval), DateType)) + ruleTest(dateTimeOperations, Add(interval, date), Cast(TimeAdd(date, interval), DateType)) + ruleTest(dateTimeOperations, Add(timestamp, interval), + Cast(TimeAdd(timestamp, interval), TimestampType)) + ruleTest(dateTimeOperations, Add(interval, timestamp), + Cast(TimeAdd(timestamp, interval), TimestampType)) + ruleTest(dateTimeOperations, Add(str, interval), Cast(TimeAdd(str, interval), StringType)) + ruleTest(dateTimeOperations, Add(interval, str), Cast(TimeAdd(str, interval), StringType)) + + ruleTest(dateTimeOperations, Subtract(date, interval), Cast(TimeSub(date, interval), DateType)) + ruleTest(dateTimeOperations, Subtract(timestamp, interval), + Cast(TimeSub(timestamp, interval), TimestampType)) + ruleTest(dateTimeOperations, Subtract(str, interval), Cast(TimeSub(str, interval), StringType)) + + // interval operations should not be effected + ruleTest(dateTimeOperations, Add(interval, interval), Add(interval, interval)) + ruleTest(dateTimeOperations, Subtract(interval, interval), Subtract(interval, interval)) + } + + /** * There are rules 
that need to not fire before child expressions get resolved. * We use this test to make sure those rules do not fire early. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala index e1387f945ffa4..fd1d6c1d25497 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala @@ -22,8 +22,8 @@ import java.text.SimpleDateFormat import java.util.Calendar import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.unsafe.types.CalendarInterval import org.apache.spark.sql.types._ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -48,56 +48,8 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("DayOfYear") { val sdfDay = new SimpleDateFormat("D") - (2002 to 2004).foreach { y => - (0 to 11).foreach { m => - (0 to 5).foreach { i => - val c = Calendar.getInstance() - c.set(y, m, 28, 0, 0, 0) - c.add(Calendar.DATE, i) - checkEvaluation(DayOfYear(Literal(new Date(c.getTimeInMillis))), - sdfDay.format(c.getTime).toInt) - } - } - } - (1998 to 2002).foreach { y => - (0 to 11).foreach { m => - (0 to 5).foreach { i => - val c = Calendar.getInstance() - c.set(y, m, 28, 0, 0, 0) - c.add(Calendar.DATE, i) - checkEvaluation(DayOfYear(Literal(new Date(c.getTimeInMillis))), - sdfDay.format(c.getTime).toInt) - } - } - } - - (1969 to 1970).foreach { y => - (0 to 11).foreach { m => - (0 to 5).foreach { i => - val c = Calendar.getInstance() - c.set(y, m, 28, 0, 0, 0) - c.add(Calendar.DATE, i) - checkEvaluation(DayOfYear(Literal(new Date(c.getTimeInMillis))), - sdfDay.format(c.getTime).toInt) - } - } - } - - (2402 to 2404).foreach { y => - (0 to 11).foreach { m => - (0 to 5).foreach { i => - val c = Calendar.getInstance() - c.set(y, m, 28, 0, 0, 0) - c.add(Calendar.DATE, i) - checkEvaluation(DayOfYear(Literal(new Date(c.getTimeInMillis))), - sdfDay.format(c.getTime).toInt) - } - } - } - - (2398 to 2402).foreach { y => - (0 to 11).foreach { m => + (0 to 3).foreach { m => (0 to 5).foreach { i => val c = Calendar.getInstance() c.set(y, m, 28, 0, 0, 0) @@ -117,7 +69,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Year(Cast(Literal(ts), DateType)), 2013) val c = Calendar.getInstance() - (2000 to 2010).foreach { y => + (2000 to 2002).foreach { y => (0 to 11 by 11).foreach { m => c.set(y, m, 28) (0 to 5 * 24).foreach { i => @@ -155,20 +107,8 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Month(Cast(Literal(ts), DateType)), 11) (2003 to 2004).foreach { y => - (0 to 11).foreach { m => - (0 to 5 * 24).foreach { i => - val c = Calendar.getInstance() - c.set(y, m, 28, 0, 0, 0) - c.add(Calendar.HOUR_OF_DAY, i) - checkEvaluation(Month(Literal(new Date(c.getTimeInMillis))), - c.get(Calendar.MONTH) + 1) - } - } - } - - (1999 to 2000).foreach { y => - (0 to 11).foreach { m => - (0 to 5 * 24).foreach { i => + (0 to 3).foreach { m => + (0 to 2 * 24).foreach { i => val c = Calendar.getInstance() c.set(y, m, 28, 0, 0, 0) c.add(Calendar.HOUR_OF_DAY, i) @@ -262,6 +202,112 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } } + test("date_add") 
{ + checkEvaluation( + DateAdd(Literal(Date.valueOf("2016-02-28")), Literal(1)), + DateTimeUtils.fromJavaDate(Date.valueOf("2016-02-29"))) + checkEvaluation( + DateAdd(Literal(Date.valueOf("2016-02-28")), Literal(-365)), + DateTimeUtils.fromJavaDate(Date.valueOf("2015-02-28"))) + checkEvaluation(DateAdd(Literal.create(null, DateType), Literal(1)), null) + checkEvaluation(DateAdd(Literal(Date.valueOf("2016-02-28")), Literal.create(null, IntegerType)), + null) + checkEvaluation(DateAdd(Literal.create(null, DateType), Literal.create(null, IntegerType)), + null) + } + + test("date_sub") { + checkEvaluation( + DateSub(Literal(Date.valueOf("2015-01-01")), Literal(1)), + DateTimeUtils.fromJavaDate(Date.valueOf("2014-12-31"))) + checkEvaluation( + DateSub(Literal(Date.valueOf("2015-01-01")), Literal(-1)), + DateTimeUtils.fromJavaDate(Date.valueOf("2015-01-02"))) + checkEvaluation(DateSub(Literal.create(null, DateType), Literal(1)), null) + checkEvaluation(DateSub(Literal(Date.valueOf("2016-02-28")), Literal.create(null, IntegerType)), + null) + checkEvaluation(DateSub(Literal.create(null, DateType), Literal.create(null, IntegerType)), + null) + } + + test("time_add") { + checkEvaluation( + TimeAdd(Literal(Timestamp.valueOf("2016-01-29 10:00:00")), + Literal(new CalendarInterval(1, 123000L))), + DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2016-02-29 10:00:00.123"))) + + checkEvaluation( + TimeAdd(Literal.create(null, TimestampType), Literal(new CalendarInterval(1, 123000L))), + null) + checkEvaluation( + TimeAdd(Literal(Timestamp.valueOf("2016-01-29 10:00:00")), + Literal.create(null, CalendarIntervalType)), + null) + checkEvaluation( + TimeAdd(Literal.create(null, TimestampType), Literal.create(null, CalendarIntervalType)), + null) + } + + test("time_sub") { + checkEvaluation( + TimeSub(Literal(Timestamp.valueOf("2016-03-31 10:00:00")), + Literal(new CalendarInterval(1, 0))), + DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2016-02-29 10:00:00"))) + checkEvaluation( + TimeSub( + Literal(Timestamp.valueOf("2016-03-30 00:00:01")), + Literal(new CalendarInterval(1, 2000000.toLong))), + DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2016-02-28 23:59:59"))) + + checkEvaluation( + TimeSub(Literal.create(null, TimestampType), Literal(new CalendarInterval(1, 123000L))), + null) + checkEvaluation( + TimeSub(Literal(Timestamp.valueOf("2016-01-29 10:00:00")), + Literal.create(null, CalendarIntervalType)), + null) + checkEvaluation( + TimeSub(Literal.create(null, TimestampType), Literal.create(null, CalendarIntervalType)), + null) + } + + test("add_months") { + checkEvaluation(AddMonths(Literal(Date.valueOf("2015-01-30")), Literal(1)), + DateTimeUtils.fromJavaDate(Date.valueOf("2015-02-28"))) + checkEvaluation(AddMonths(Literal(Date.valueOf("2016-03-30")), Literal(-1)), + DateTimeUtils.fromJavaDate(Date.valueOf("2016-02-29"))) + checkEvaluation( + AddMonths(Literal(Date.valueOf("2015-01-30")), Literal.create(null, IntegerType)), + null) + checkEvaluation(AddMonths(Literal.create(null, DateType), Literal(1)), null) + checkEvaluation(AddMonths(Literal.create(null, DateType), Literal.create(null, IntegerType)), + null) + } + + test("months_between") { + checkEvaluation( + MonthsBetween(Literal(Timestamp.valueOf("1997-02-28 10:30:00")), + Literal(Timestamp.valueOf("1996-10-30 00:00:00"))), + 3.94959677) + checkEvaluation( + MonthsBetween(Literal(Timestamp.valueOf("2015-01-30 11:52:00")), + Literal(Timestamp.valueOf("2015-01-30 11:50:00"))), + 0.0) + checkEvaluation( + 
MonthsBetween(Literal(Timestamp.valueOf("2015-01-31 00:00:00")), + Literal(Timestamp.valueOf("2015-03-31 22:00:00"))), + -2.0) + checkEvaluation( + MonthsBetween(Literal(Timestamp.valueOf("2015-03-31 22:00:00")), + Literal(Timestamp.valueOf("2015-02-28 00:00:00"))), + 1.0) + val t = Literal(Timestamp.valueOf("2015-03-31 22:00:00")) + val tnull = Literal.create(null, TimestampType) + checkEvaluation(MonthsBetween(t, tnull), null) + checkEvaluation(MonthsBetween(tnull, t), null) + checkEvaluation(MonthsBetween(tnull, tnull), null) + } + test("last_day") { checkEvaluation(LastDay(Literal(Date.valueOf("2015-02-28"))), Date.valueOf("2015-02-28")) checkEvaluation(LastDay(Literal(Date.valueOf("2015-03-27"))), Date.valueOf("2015-03-31")) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala index fab9eb9cd4c9f..60d2bcfe13757 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala @@ -19,47 +19,48 @@ package org.apache.spark.sql.catalyst.util import java.sql.{Date, Timestamp} import java.text.SimpleDateFormat -import java.util.{TimeZone, Calendar} +import java.util.{Calendar, TimeZone} import org.apache.spark.SparkFunSuite import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.sql.catalyst.util.DateTimeUtils._ class DateTimeUtilsSuite extends SparkFunSuite { private[this] def getInUTCDays(timestamp: Long): Int = { val tz = TimeZone.getDefault - ((timestamp + tz.getOffset(timestamp)) / DateTimeUtils.MILLIS_PER_DAY).toInt + ((timestamp + tz.getOffset(timestamp)) / MILLIS_PER_DAY).toInt } test("timestamp and us") { val now = new Timestamp(System.currentTimeMillis()) now.setNanos(1000) - val ns = DateTimeUtils.fromJavaTimestamp(now) + val ns = fromJavaTimestamp(now) assert(ns % 1000000L === 1) - assert(DateTimeUtils.toJavaTimestamp(ns) === now) + assert(toJavaTimestamp(ns) === now) List(-111111111111L, -1L, 0, 1L, 111111111111L).foreach { t => - val ts = DateTimeUtils.toJavaTimestamp(t) - assert(DateTimeUtils.fromJavaTimestamp(ts) === t) - assert(DateTimeUtils.toJavaTimestamp(DateTimeUtils.fromJavaTimestamp(ts)) === ts) + val ts = toJavaTimestamp(t) + assert(fromJavaTimestamp(ts) === t) + assert(toJavaTimestamp(fromJavaTimestamp(ts)) === ts) } } test("us and julian day") { - val (d, ns) = DateTimeUtils.toJulianDay(0) - assert(d === DateTimeUtils.JULIAN_DAY_OF_EPOCH) - assert(ns === DateTimeUtils.SECONDS_PER_DAY / 2 * DateTimeUtils.NANOS_PER_SECOND) - assert(DateTimeUtils.fromJulianDay(d, ns) == 0L) + val (d, ns) = toJulianDay(0) + assert(d === JULIAN_DAY_OF_EPOCH) + assert(ns === SECONDS_PER_DAY / 2 * NANOS_PER_SECOND) + assert(fromJulianDay(d, ns) == 0L) val t = new Timestamp(61394778610000L) // (2015, 6, 11, 10, 10, 10, 100) - val (d1, ns1) = DateTimeUtils.toJulianDay(DateTimeUtils.fromJavaTimestamp(t)) - val t2 = DateTimeUtils.toJavaTimestamp(DateTimeUtils.fromJulianDay(d1, ns1)) + val (d1, ns1) = toJulianDay(fromJavaTimestamp(t)) + val t2 = toJavaTimestamp(fromJulianDay(d1, ns1)) assert(t.equals(t2)) } test("SPARK-6785: java date conversion before and after epoch") { def checkFromToJavaDate(d1: Date): Unit = { - val d2 = DateTimeUtils.toJavaDate(DateTimeUtils.fromJavaDate(d1)) + val d2 = toJavaDate(fromJavaDate(d1)) assert(d2.toString === d1.toString) } @@ -95,157 +96,156 @@ class DateTimeUtilsSuite extends 
SparkFunSuite { } test("string to date") { - import DateTimeUtils.millisToDays var c = Calendar.getInstance() c.set(2015, 0, 28, 0, 0, 0) c.set(Calendar.MILLISECOND, 0) - assert(DateTimeUtils.stringToDate(UTF8String.fromString("2015-01-28")).get === + assert(stringToDate(UTF8String.fromString("2015-01-28")).get === millisToDays(c.getTimeInMillis)) c.set(2015, 0, 1, 0, 0, 0) c.set(Calendar.MILLISECOND, 0) - assert(DateTimeUtils.stringToDate(UTF8String.fromString("2015")).get === + assert(stringToDate(UTF8String.fromString("2015")).get === millisToDays(c.getTimeInMillis)) c = Calendar.getInstance() c.set(2015, 2, 1, 0, 0, 0) c.set(Calendar.MILLISECOND, 0) - assert(DateTimeUtils.stringToDate(UTF8String.fromString("2015-03")).get === + assert(stringToDate(UTF8String.fromString("2015-03")).get === millisToDays(c.getTimeInMillis)) c = Calendar.getInstance() c.set(2015, 2, 18, 0, 0, 0) c.set(Calendar.MILLISECOND, 0) - assert(DateTimeUtils.stringToDate(UTF8String.fromString("2015-03-18")).get === + assert(stringToDate(UTF8String.fromString("2015-03-18")).get === millisToDays(c.getTimeInMillis)) - assert(DateTimeUtils.stringToDate(UTF8String.fromString("2015-03-18 ")).get === + assert(stringToDate(UTF8String.fromString("2015-03-18 ")).get === millisToDays(c.getTimeInMillis)) - assert(DateTimeUtils.stringToDate(UTF8String.fromString("2015-03-18 123142")).get === + assert(stringToDate(UTF8String.fromString("2015-03-18 123142")).get === millisToDays(c.getTimeInMillis)) - assert(DateTimeUtils.stringToDate(UTF8String.fromString("2015-03-18T123123")).get === + assert(stringToDate(UTF8String.fromString("2015-03-18T123123")).get === millisToDays(c.getTimeInMillis)) - assert(DateTimeUtils.stringToDate(UTF8String.fromString("2015-03-18T")).get === + assert(stringToDate(UTF8String.fromString("2015-03-18T")).get === millisToDays(c.getTimeInMillis)) - assert(DateTimeUtils.stringToDate(UTF8String.fromString("2015-03-18X")).isEmpty) - assert(DateTimeUtils.stringToDate(UTF8String.fromString("2015/03/18")).isEmpty) - assert(DateTimeUtils.stringToDate(UTF8String.fromString("2015.03.18")).isEmpty) - assert(DateTimeUtils.stringToDate(UTF8String.fromString("20150318")).isEmpty) - assert(DateTimeUtils.stringToDate(UTF8String.fromString("2015-031-8")).isEmpty) + assert(stringToDate(UTF8String.fromString("2015-03-18X")).isEmpty) + assert(stringToDate(UTF8String.fromString("2015/03/18")).isEmpty) + assert(stringToDate(UTF8String.fromString("2015.03.18")).isEmpty) + assert(stringToDate(UTF8String.fromString("20150318")).isEmpty) + assert(stringToDate(UTF8String.fromString("2015-031-8")).isEmpty) } test("string to timestamp") { var c = Calendar.getInstance() c.set(1969, 11, 31, 16, 0, 0) c.set(Calendar.MILLISECOND, 0) - assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("1969-12-31 16:00:00")).get === + assert(stringToTimestamp(UTF8String.fromString("1969-12-31 16:00:00")).get === c.getTimeInMillis * 1000) c.set(2015, 0, 1, 0, 0, 0) c.set(Calendar.MILLISECOND, 0) - assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015")).get === + assert(stringToTimestamp(UTF8String.fromString("2015")).get === c.getTimeInMillis * 1000) c = Calendar.getInstance() c.set(2015, 2, 1, 0, 0, 0) c.set(Calendar.MILLISECOND, 0) - assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03")).get === + assert(stringToTimestamp(UTF8String.fromString("2015-03")).get === c.getTimeInMillis * 1000) c = Calendar.getInstance() c.set(2015, 2, 18, 0, 0, 0) c.set(Calendar.MILLISECOND, 0) - 
assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03-18")).get === + assert(stringToTimestamp(UTF8String.fromString("2015-03-18")).get === c.getTimeInMillis * 1000) - assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03-18 ")).get === + assert(stringToTimestamp(UTF8String.fromString("2015-03-18 ")).get === c.getTimeInMillis * 1000) - assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03-18T")).get === + assert(stringToTimestamp(UTF8String.fromString("2015-03-18T")).get === c.getTimeInMillis * 1000) c = Calendar.getInstance() c.set(2015, 2, 18, 12, 3, 17) c.set(Calendar.MILLISECOND, 0) - assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03-18 12:03:17")).get === + assert(stringToTimestamp(UTF8String.fromString("2015-03-18 12:03:17")).get === c.getTimeInMillis * 1000) - assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03-18T12:03:17")).get === + assert(stringToTimestamp(UTF8String.fromString("2015-03-18T12:03:17")).get === c.getTimeInMillis * 1000) c = Calendar.getInstance(TimeZone.getTimeZone("GMT-13:53")) c.set(2015, 2, 18, 12, 3, 17) c.set(Calendar.MILLISECOND, 0) - assert(DateTimeUtils.stringToTimestamp( + assert(stringToTimestamp( UTF8String.fromString("2015-03-18T12:03:17-13:53")).get === c.getTimeInMillis * 1000) c = Calendar.getInstance(TimeZone.getTimeZone("UTC")) c.set(2015, 2, 18, 12, 3, 17) c.set(Calendar.MILLISECOND, 0) - assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03-18T12:03:17Z")).get === + assert(stringToTimestamp(UTF8String.fromString("2015-03-18T12:03:17Z")).get === c.getTimeInMillis * 1000) - assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03-18 12:03:17Z")).get === + assert(stringToTimestamp(UTF8String.fromString("2015-03-18 12:03:17Z")).get === c.getTimeInMillis * 1000) c = Calendar.getInstance(TimeZone.getTimeZone("GMT-01:00")) c.set(2015, 2, 18, 12, 3, 17) c.set(Calendar.MILLISECOND, 0) - assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03-18T12:03:17-1:0")).get === + assert(stringToTimestamp(UTF8String.fromString("2015-03-18T12:03:17-1:0")).get === c.getTimeInMillis * 1000) - assert(DateTimeUtils.stringToTimestamp( + assert(stringToTimestamp( UTF8String.fromString("2015-03-18T12:03:17-01:00")).get === c.getTimeInMillis * 1000) c = Calendar.getInstance(TimeZone.getTimeZone("GMT+07:30")) c.set(2015, 2, 18, 12, 3, 17) c.set(Calendar.MILLISECOND, 0) - assert(DateTimeUtils.stringToTimestamp( + assert(stringToTimestamp( UTF8String.fromString("2015-03-18T12:03:17+07:30")).get === c.getTimeInMillis * 1000) c = Calendar.getInstance(TimeZone.getTimeZone("GMT+07:03")) c.set(2015, 2, 18, 12, 3, 17) c.set(Calendar.MILLISECOND, 0) - assert(DateTimeUtils.stringToTimestamp( + assert(stringToTimestamp( UTF8String.fromString("2015-03-18T12:03:17+07:03")).get === c.getTimeInMillis * 1000) c = Calendar.getInstance() c.set(2015, 2, 18, 12, 3, 17) c.set(Calendar.MILLISECOND, 123) - assert(DateTimeUtils.stringToTimestamp( + assert(stringToTimestamp( UTF8String.fromString("2015-03-18 12:03:17.123")).get === c.getTimeInMillis * 1000) - assert(DateTimeUtils.stringToTimestamp( + assert(stringToTimestamp( UTF8String.fromString("2015-03-18T12:03:17.123")).get === c.getTimeInMillis * 1000) c = Calendar.getInstance(TimeZone.getTimeZone("UTC")) c.set(2015, 2, 18, 12, 3, 17) c.set(Calendar.MILLISECOND, 456) - assert(DateTimeUtils.stringToTimestamp( + assert(stringToTimestamp( UTF8String.fromString("2015-03-18T12:03:17.456Z")).get === c.getTimeInMillis * 
1000) - assert(DateTimeUtils.stringToTimestamp( + assert(stringToTimestamp( UTF8String.fromString("2015-03-18 12:03:17.456Z")).get === c.getTimeInMillis * 1000) c = Calendar.getInstance(TimeZone.getTimeZone("GMT-01:00")) c.set(2015, 2, 18, 12, 3, 17) c.set(Calendar.MILLISECOND, 123) - assert(DateTimeUtils.stringToTimestamp( + assert(stringToTimestamp( UTF8String.fromString("2015-03-18T12:03:17.123-1:0")).get === c.getTimeInMillis * 1000) - assert(DateTimeUtils.stringToTimestamp( + assert(stringToTimestamp( UTF8String.fromString("2015-03-18T12:03:17.123-01:00")).get === c.getTimeInMillis * 1000) c = Calendar.getInstance(TimeZone.getTimeZone("GMT+07:30")) c.set(2015, 2, 18, 12, 3, 17) c.set(Calendar.MILLISECOND, 123) - assert(DateTimeUtils.stringToTimestamp( + assert(stringToTimestamp( UTF8String.fromString("2015-03-18T12:03:17.123+07:30")).get === c.getTimeInMillis * 1000) c = Calendar.getInstance(TimeZone.getTimeZone("GMT+07:30")) c.set(2015, 2, 18, 12, 3, 17) c.set(Calendar.MILLISECOND, 123) - assert(DateTimeUtils.stringToTimestamp( + assert(stringToTimestamp( UTF8String.fromString("2015-03-18T12:03:17.123+07:30")).get === c.getTimeInMillis * 1000) c = Calendar.getInstance(TimeZone.getTimeZone("GMT+07:30")) c.set(2015, 2, 18, 12, 3, 17) c.set(Calendar.MILLISECOND, 123) - assert(DateTimeUtils.stringToTimestamp( + assert(stringToTimestamp( UTF8String.fromString("2015-03-18T12:03:17.123121+7:30")).get === c.getTimeInMillis * 1000 + 121) c = Calendar.getInstance(TimeZone.getTimeZone("GMT+07:30")) c.set(2015, 2, 18, 12, 3, 17) c.set(Calendar.MILLISECOND, 123) - assert(DateTimeUtils.stringToTimestamp( + assert(stringToTimestamp( UTF8String.fromString("2015-03-18T12:03:17.12312+7:30")).get === c.getTimeInMillis * 1000 + 120) @@ -254,7 +254,7 @@ class DateTimeUtilsSuite extends SparkFunSuite { c.set(Calendar.MINUTE, 12) c.set(Calendar.SECOND, 15) c.set(Calendar.MILLISECOND, 0) - assert(DateTimeUtils.stringToTimestamp( + assert(stringToTimestamp( UTF8String.fromString("18:12:15")).get === c.getTimeInMillis * 1000) @@ -263,7 +263,7 @@ class DateTimeUtilsSuite extends SparkFunSuite { c.set(Calendar.MINUTE, 12) c.set(Calendar.SECOND, 15) c.set(Calendar.MILLISECOND, 123) - assert(DateTimeUtils.stringToTimestamp( + assert(stringToTimestamp( UTF8String.fromString("T18:12:15.12312+7:30")).get === c.getTimeInMillis * 1000 + 120) @@ -272,93 +272,130 @@ class DateTimeUtilsSuite extends SparkFunSuite { c.set(Calendar.MINUTE, 12) c.set(Calendar.SECOND, 15) c.set(Calendar.MILLISECOND, 123) - assert(DateTimeUtils.stringToTimestamp( + assert(stringToTimestamp( UTF8String.fromString("18:12:15.12312+7:30")).get === c.getTimeInMillis * 1000 + 120) c = Calendar.getInstance() c.set(2011, 4, 6, 7, 8, 9) c.set(Calendar.MILLISECOND, 100) - assert(DateTimeUtils.stringToTimestamp( + assert(stringToTimestamp( UTF8String.fromString("2011-05-06 07:08:09.1000")).get === c.getTimeInMillis * 1000) - assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("238")).isEmpty) - assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03-18 123142")).isEmpty) - assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03-18T123123")).isEmpty) - assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03-18X")).isEmpty) - assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015/03/18")).isEmpty) - assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015.03.18")).isEmpty) - assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("20150318")).isEmpty) - 
assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-031-8")).isEmpty) - assert(DateTimeUtils.stringToTimestamp( + assert(stringToTimestamp(UTF8String.fromString("238")).isEmpty) + assert(stringToTimestamp(UTF8String.fromString("2015-03-18 123142")).isEmpty) + assert(stringToTimestamp(UTF8String.fromString("2015-03-18T123123")).isEmpty) + assert(stringToTimestamp(UTF8String.fromString("2015-03-18X")).isEmpty) + assert(stringToTimestamp(UTF8String.fromString("2015/03/18")).isEmpty) + assert(stringToTimestamp(UTF8String.fromString("2015.03.18")).isEmpty) + assert(stringToTimestamp(UTF8String.fromString("20150318")).isEmpty) + assert(stringToTimestamp(UTF8String.fromString("2015-031-8")).isEmpty) + assert(stringToTimestamp( UTF8String.fromString("2015-03-18T12:03.17-20:0")).isEmpty) - assert(DateTimeUtils.stringToTimestamp( + assert(stringToTimestamp( UTF8String.fromString("2015-03-18T12:03.17-0:70")).isEmpty) - assert(DateTimeUtils.stringToTimestamp( + assert(stringToTimestamp( UTF8String.fromString("2015-03-18T12:03.17-1:0:0")).isEmpty) } test("hours") { val c = Calendar.getInstance() c.set(2015, 2, 18, 13, 2, 11) - assert(DateTimeUtils.getHours(c.getTimeInMillis * 1000) === 13) + assert(getHours(c.getTimeInMillis * 1000) === 13) c.set(2015, 12, 8, 2, 7, 9) - assert(DateTimeUtils.getHours(c.getTimeInMillis * 1000) === 2) + assert(getHours(c.getTimeInMillis * 1000) === 2) } test("minutes") { val c = Calendar.getInstance() c.set(2015, 2, 18, 13, 2, 11) - assert(DateTimeUtils.getMinutes(c.getTimeInMillis * 1000) === 2) + assert(getMinutes(c.getTimeInMillis * 1000) === 2) c.set(2015, 2, 8, 2, 7, 9) - assert(DateTimeUtils.getMinutes(c.getTimeInMillis * 1000) === 7) + assert(getMinutes(c.getTimeInMillis * 1000) === 7) } test("seconds") { val c = Calendar.getInstance() c.set(2015, 2, 18, 13, 2, 11) - assert(DateTimeUtils.getSeconds(c.getTimeInMillis * 1000) === 11) + assert(getSeconds(c.getTimeInMillis * 1000) === 11) c.set(2015, 2, 8, 2, 7, 9) - assert(DateTimeUtils.getSeconds(c.getTimeInMillis * 1000) === 9) + assert(getSeconds(c.getTimeInMillis * 1000) === 9) } test("get day in year") { val c = Calendar.getInstance() c.set(2015, 2, 18, 0, 0, 0) - assert(DateTimeUtils.getDayInYear(getInUTCDays(c.getTimeInMillis)) === 77) + assert(getDayInYear(getInUTCDays(c.getTimeInMillis)) === 77) c.set(2012, 2, 18, 0, 0, 0) - assert(DateTimeUtils.getDayInYear(getInUTCDays(c.getTimeInMillis)) === 78) + assert(getDayInYear(getInUTCDays(c.getTimeInMillis)) === 78) } test("get year") { val c = Calendar.getInstance() c.set(2015, 2, 18, 0, 0, 0) - assert(DateTimeUtils.getYear(getInUTCDays(c.getTimeInMillis)) === 2015) + assert(getYear(getInUTCDays(c.getTimeInMillis)) === 2015) c.set(2012, 2, 18, 0, 0, 0) - assert(DateTimeUtils.getYear(getInUTCDays(c.getTimeInMillis)) === 2012) + assert(getYear(getInUTCDays(c.getTimeInMillis)) === 2012) } test("get quarter") { val c = Calendar.getInstance() c.set(2015, 2, 18, 0, 0, 0) - assert(DateTimeUtils.getQuarter(getInUTCDays(c.getTimeInMillis)) === 1) + assert(getQuarter(getInUTCDays(c.getTimeInMillis)) === 1) c.set(2012, 11, 18, 0, 0, 0) - assert(DateTimeUtils.getQuarter(getInUTCDays(c.getTimeInMillis)) === 4) + assert(getQuarter(getInUTCDays(c.getTimeInMillis)) === 4) } test("get month") { val c = Calendar.getInstance() c.set(2015, 2, 18, 0, 0, 0) - assert(DateTimeUtils.getMonth(getInUTCDays(c.getTimeInMillis)) === 3) + assert(getMonth(getInUTCDays(c.getTimeInMillis)) === 3) c.set(2012, 11, 18, 0, 0, 0) - 
assert(DateTimeUtils.getMonth(getInUTCDays(c.getTimeInMillis)) === 12) + assert(getMonth(getInUTCDays(c.getTimeInMillis)) === 12) } test("get day of month") { val c = Calendar.getInstance() c.set(2015, 2, 18, 0, 0, 0) - assert(DateTimeUtils.getDayOfMonth(getInUTCDays(c.getTimeInMillis)) === 18) + assert(getDayOfMonth(getInUTCDays(c.getTimeInMillis)) === 18) c.set(2012, 11, 24, 0, 0, 0) - assert(DateTimeUtils.getDayOfMonth(getInUTCDays(c.getTimeInMillis)) === 24) + assert(getDayOfMonth(getInUTCDays(c.getTimeInMillis)) === 24) + } + + test("date add months") { + val c1 = Calendar.getInstance() + c1.set(1997, 1, 28, 10, 30, 0) + val days1 = millisToDays(c1.getTimeInMillis) + val c2 = Calendar.getInstance() + c2.set(2000, 1, 29) + assert(dateAddMonths(days1, 36) === millisToDays(c2.getTimeInMillis)) + c2.set(1996, 0, 31) + assert(dateAddMonths(days1, -13) === millisToDays(c2.getTimeInMillis)) + } + + test("timestamp add months") { + val c1 = Calendar.getInstance() + c1.set(1997, 1, 28, 10, 30, 0) + c1.set(Calendar.MILLISECOND, 0) + val ts1 = c1.getTimeInMillis * 1000L + val c2 = Calendar.getInstance() + c2.set(2000, 1, 29, 10, 30, 0) + c2.set(Calendar.MILLISECOND, 123) + val ts2 = c2.getTimeInMillis * 1000L + assert(timestampAddInterval(ts1, 36, 123000) === ts2) + } + + test("monthsBetween") { + val c1 = Calendar.getInstance() + c1.set(1997, 1, 28, 10, 30, 0) + val c2 = Calendar.getInstance() + c2.set(1996, 9, 30, 0, 0, 0) + assert(monthsBetween(c1.getTimeInMillis * 1000L, c2.getTimeInMillis * 1000L) === 3.94959677) + c2.set(2000, 1, 28, 0, 0, 0) + assert(monthsBetween(c1.getTimeInMillis * 1000L, c2.getTimeInMillis * 1000L) === -36) + c2.set(2000, 1, 29, 0, 0, 0) + assert(monthsBetween(c1.getTimeInMillis * 1000L, c2.getTimeInMillis * 1000L) === -36) + c2.set(1996, 2, 31, 0, 0, 0) + assert(monthsBetween(c1.getTimeInMillis * 1000L, c2.getTimeInMillis * 1000L) === 11) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 3f440e062eb96..168894d66117d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1927,6 +1927,14 @@ object functions { // DateTime functions ////////////////////////////////////////////////////////////////////////////////////////////// + /** + * Returns the date that is numMonths after startDate. + * @group datetime_funcs + * @since 1.5.0 + */ + def add_months(startDate: Column, numMonths: Int): Column = + AddMonths(startDate.expr, Literal(numMonths)) + /** * Converts a date/timestamp/string to a value of string in the format specified by the date * format given by the second argument. @@ -1959,6 +1967,20 @@ object functions { def date_format(dateColumnName: String, format: String): Column = date_format(Column(dateColumnName), format) + /** + * Returns the date that is `days` days after `start` + * @group datetime_funcs + * @since 1.5.0 + */ + def date_add(start: Column, days: Int): Column = DateAdd(start.expr, Literal(days)) + + /** + * Returns the date that is `days` days before `start` + * @group datetime_funcs + * @since 1.5.0 + */ + def date_sub(start: Column, days: Int): Column = DateSub(start.expr, Literal(days)) + /** * Extracts the year as an integer from a given date/timestamp/string. * @group datetime_funcs @@ -2067,6 +2089,13 @@ object functions { */ def minute(columnName: String): Column = minute(Column(columnName)) + /* + * Returns number of months between dates `date1` and `date2`. 
+ * @group datetime_funcs + * @since 1.5.0 + */ + def months_between(date1: Column, date2: Column): Column = MonthsBetween(date1.expr, date2.expr) + /** * Given a date column, returns the first date which is later than the value of the date column * that is on the specified day of the week. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala index df4cb57ac5b21..b7267c413165a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala @@ -22,6 +22,7 @@ import java.text.SimpleDateFormat import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.functions._ +import org.apache.spark.unsafe.types.CalendarInterval class DateFunctionsSuite extends QueryTest { private lazy val ctx = org.apache.spark.sql.test.TestSQLContext @@ -206,6 +207,122 @@ class DateFunctionsSuite extends QueryTest { Row(15, 15, 15)) } + test("function date_add") { + val st1 = "2015-06-01 12:34:56" + val st2 = "2015-06-02 12:34:56" + val t1 = Timestamp.valueOf(st1) + val t2 = Timestamp.valueOf(st2) + val s1 = "2015-06-01" + val s2 = "2015-06-02" + val d1 = Date.valueOf(s1) + val d2 = Date.valueOf(s2) + val df = Seq((t1, d1, s1, st1), (t2, d2, s2, st2)).toDF("t", "d", "s", "ss") + checkAnswer( + df.select(date_add(col("d"), 1)), + Seq(Row(Date.valueOf("2015-06-02")), Row(Date.valueOf("2015-06-03")))) + checkAnswer( + df.select(date_add(col("t"), 3)), + Seq(Row(Date.valueOf("2015-06-04")), Row(Date.valueOf("2015-06-05")))) + checkAnswer( + df.select(date_add(col("s"), 5)), + Seq(Row(Date.valueOf("2015-06-06")), Row(Date.valueOf("2015-06-07")))) + checkAnswer( + df.select(date_add(col("ss"), 7)), + Seq(Row(Date.valueOf("2015-06-08")), Row(Date.valueOf("2015-06-09")))) + + checkAnswer(df.selectExpr("DATE_ADD(null, 1)"), Seq(Row(null), Row(null))) + checkAnswer( + df.selectExpr("""DATE_ADD(d, 1)"""), + Seq(Row(Date.valueOf("2015-06-02")), Row(Date.valueOf("2015-06-03")))) + } + + test("function date_sub") { + val st1 = "2015-06-01 12:34:56" + val st2 = "2015-06-02 12:34:56" + val t1 = Timestamp.valueOf(st1) + val t2 = Timestamp.valueOf(st2) + val s1 = "2015-06-01" + val s2 = "2015-06-02" + val d1 = Date.valueOf(s1) + val d2 = Date.valueOf(s2) + val df = Seq((t1, d1, s1, st1), (t2, d2, s2, st2)).toDF("t", "d", "s", "ss") + checkAnswer( + df.select(date_sub(col("d"), 1)), + Seq(Row(Date.valueOf("2015-05-31")), Row(Date.valueOf("2015-06-01")))) + checkAnswer( + df.select(date_sub(col("t"), 1)), + Seq(Row(Date.valueOf("2015-05-31")), Row(Date.valueOf("2015-06-01")))) + checkAnswer( + df.select(date_sub(col("s"), 1)), + Seq(Row(Date.valueOf("2015-05-31")), Row(Date.valueOf("2015-06-01")))) + checkAnswer( + df.select(date_sub(col("ss"), 1)), + Seq(Row(Date.valueOf("2015-05-31")), Row(Date.valueOf("2015-06-01")))) + checkAnswer( + df.select(date_sub(lit(null), 1)).limit(1), Row(null)) + + checkAnswer(df.selectExpr("""DATE_SUB(d, null)"""), Seq(Row(null), Row(null))) + checkAnswer( + df.selectExpr("""DATE_SUB(d, 1)"""), + Seq(Row(Date.valueOf("2015-05-31")), Row(Date.valueOf("2015-06-01")))) + } + + test("time_add") { + val t1 = Timestamp.valueOf("2015-07-31 23:59:59") + val t2 = Timestamp.valueOf("2015-12-31 00:00:00") + val d1 = Date.valueOf("2015-07-31") + val d2 = Date.valueOf("2015-12-31") + val i = new CalendarInterval(2, 2000000L) + val df = Seq((1, t1, d1), (3, t2, d2)).toDF("n", "t", "d") + checkAnswer( + 
df.selectExpr(s"d + $i"), + Seq(Row(Date.valueOf("2015-09-30")), Row(Date.valueOf("2016-02-29")))) + checkAnswer( + df.selectExpr(s"t + $i"), + Seq(Row(Timestamp.valueOf("2015-10-01 00:00:01")), + Row(Timestamp.valueOf("2016-02-29 00:00:02")))) + } + + test("time_sub") { + val t1 = Timestamp.valueOf("2015-10-01 00:00:01") + val t2 = Timestamp.valueOf("2016-02-29 00:00:02") + val d1 = Date.valueOf("2015-09-30") + val d2 = Date.valueOf("2016-02-29") + val i = new CalendarInterval(2, 2000000L) + val df = Seq((1, t1, d1), (3, t2, d2)).toDF("n", "t", "d") + checkAnswer( + df.selectExpr(s"d - $i"), + Seq(Row(Date.valueOf("2015-07-30")), Row(Date.valueOf("2015-12-30")))) + checkAnswer( + df.selectExpr(s"t - $i"), + Seq(Row(Timestamp.valueOf("2015-07-31 23:59:59")), + Row(Timestamp.valueOf("2015-12-31 00:00:00")))) + } + + test("function add_months") { + val d1 = Date.valueOf("2015-08-31") + val d2 = Date.valueOf("2015-02-28") + val df = Seq((1, d1), (2, d2)).toDF("n", "d") + checkAnswer( + df.select(add_months(col("d"), 1)), + Seq(Row(Date.valueOf("2015-09-30")), Row(Date.valueOf("2015-03-31")))) + checkAnswer( + df.selectExpr("add_months(d, -1)"), + Seq(Row(Date.valueOf("2015-07-31")), Row(Date.valueOf("2015-01-31")))) + } + + test("function months_between") { + val d1 = Date.valueOf("2015-07-31") + val d2 = Date.valueOf("2015-02-16") + val t1 = Timestamp.valueOf("2014-09-30 23:30:00") + val t2 = Timestamp.valueOf("2015-09-16 12:00:00") + val s1 = "2014-09-15 11:30:00" + val s2 = "2015-10-01 00:00:00" + val df = Seq((t1, d1, s1), (t2, d2, s2)).toDF("t", "d", "s") + checkAnswer(df.select(months_between(col("t"), col("d"))), Seq(Row(-10.0), Row(7.0))) + checkAnswer(df.selectExpr("months_between(t, s)"), Seq(Row(0.5), Row(-0.5))) + } + test("function last_day") { val df1 = Seq((1, "2015-07-23"), (2, "2015-07-24")).toDF("i", "d") val df2 = Seq((1, "2015-07-23 00:11:22"), (2, "2015-07-24 11:22:33")).toDF("i", "t") From 89cda69ecd5ef942a68ad13fc4e1f4184010f087 Mon Sep 17 00:00:00 2001 From: Feynman Liang Date: Thu, 30 Jul 2015 14:08:59 -0700 Subject: [PATCH 36/50] [SPARK-9454] Change LDASuite tests to use vector comparisons jkbradley Changes the current hacky string-comparison for vector compares. Author: Feynman Liang Closes #7775 from feynmanliang/SPARK-9454-ldasuite-vector-compare and squashes the following commits: bd91a82 [Feynman Liang] Remove println 905c76e [Feynman Liang] Fix string compare in distributed EM 2f24c13 [Feynman Liang] Improve LDASuite tests --- .../spark/mllib/clustering/LDASuite.scala | 33 ++++++++----------- 1 file changed, 14 insertions(+), 19 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala index d74482d3a7598..c43e1e575c09c 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala @@ -83,21 +83,14 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext { assert(model.topicsMatrix === localModel.topicsMatrix) // Check: topic summaries - // The odd decimal formatting and sorting is a hack to do a robust comparison. 
- val roundedTopicSummary = model.describeTopics().map { case (terms, termWeights) => - // cut values to 3 digits after the decimal place - terms.zip(termWeights).map { case (term, weight) => - ("%.3f".format(weight).toDouble, term.toInt) - } - }.sortBy(_.mkString("")) - val roundedLocalTopicSummary = localModel.describeTopics().map { case (terms, termWeights) => - // cut values to 3 digits after the decimal place - terms.zip(termWeights).map { case (term, weight) => - ("%.3f".format(weight).toDouble, term.toInt) - } - }.sortBy(_.mkString("")) - roundedTopicSummary.zip(roundedLocalTopicSummary).foreach { case (t1, t2) => - assert(t1 === t2) + val topicSummary = model.describeTopics().map { case (terms, termWeights) => + Vectors.sparse(tinyVocabSize, terms, termWeights) + }.sortBy(_.toString) + val localTopicSummary = localModel.describeTopics().map { case (terms, termWeights) => + Vectors.sparse(tinyVocabSize, terms, termWeights) + }.sortBy(_.toString) + topicSummary.zip(localTopicSummary).foreach { case (topics, topicsLocal) => + assert(topics ~== topicsLocal absTol 0.01) } // Check: per-doc topic distributions @@ -197,10 +190,12 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext { // verify the result, Note this generate the identical result as // [[https://github.com/Blei-Lab/onlineldavb]] - val topic1 = op.getLambda(0, ::).inner.toArray.map("%.4f".format(_)).mkString(", ") - val topic2 = op.getLambda(1, ::).inner.toArray.map("%.4f".format(_)).mkString(", ") - assert("1.1101, 1.2076, 1.3050, 0.8899, 0.7924, 0.6950" == topic1) - assert("0.8899, 0.7924, 0.6950, 1.1101, 1.2076, 1.3050" == topic2) + val topic1: Vector = Vectors.fromBreeze(op.getLambda(0, ::).t) + val topic2: Vector = Vectors.fromBreeze(op.getLambda(1, ::).t) + val expectedTopic1 = Vectors.dense(1.1101, 1.2076, 1.3050, 0.8899, 0.7924, 0.6950) + val expectedTopic2 = Vectors.dense(0.8899, 0.7924, 0.6950, 1.1101, 1.2076, 1.3050) + assert(topic1 ~== expectedTopic1 absTol 0.01) + assert(topic2 ~== expectedTopic2 absTol 0.01) } test("OnlineLDAOptimizer with toy data") { From 0dbd6963d589a8f6ad344273f3da7df680ada515 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Thu, 30 Jul 2015 15:39:46 -0700 Subject: [PATCH 37/50] [SPARK-9479] [STREAMING] [TESTS] Fix ReceiverTrackerSuite failure for maven build and other potential test failures in Streaming See https://issues.apache.org/jira/browse/SPARK-9479 for the failure cause. The PR includes the following changes: 1. Make ReceiverTrackerSuite create StreamingContext in the test body. 2. Fix places that don't stop StreamingContext. I verified no SparkContext was stopped in the shutdown hook locally after this fix. 3. Fix an issue that `ReceiverTracker.endpoint` may be null. 4. Make sure stopping SparkContext in non-main thread won't fail other tests. 
Author: zsxwing Closes #7797 from zsxwing/fix-ReceiverTrackerSuite and squashes the following commits: 3a4bb98 [zsxwing] Fix another potential NPE d7497df [zsxwing] Fix ReceiverTrackerSuite; make sure StreamingContext in tests is closed --- .../StreamingLogisticRegressionSuite.scala | 21 +++++-- .../clustering/StreamingKMeansSuite.scala | 17 ++++-- .../StreamingLinearRegressionSuite.scala | 21 +++++-- .../streaming/scheduler/ReceiverTracker.scala | 12 +++- .../apache/spark/streaming/JavaAPISuite.java | 1 + .../streaming/BasicOperationsSuite.scala | 58 ++++++++++--------- .../spark/streaming/InputStreamsSuite.scala | 38 ++++++------ .../spark/streaming/MasterFailureTest.scala | 8 ++- .../streaming/StreamingContextSuite.scala | 22 +++++-- .../streaming/StreamingListenerSuite.scala | 13 ++++- .../scheduler/ReceiverTrackerSuite.scala | 56 +++++++++--------- .../StreamingJobProgressListenerSuite.scala | 19 ++++-- 12 files changed, 183 insertions(+), 103 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala index fd653296c9d97..d7b291d5a6330 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala @@ -24,13 +24,22 @@ import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.streaming.dstream.DStream -import org.apache.spark.streaming.TestSuiteBase +import org.apache.spark.streaming.{StreamingContext, TestSuiteBase} class StreamingLogisticRegressionSuite extends SparkFunSuite with TestSuiteBase { // use longer wait time to ensure job completion override def maxWaitTimeMillis: Int = 30000 + var ssc: StreamingContext = _ + + override def afterFunction() { + super.afterFunction() + if (ssc != null) { + ssc.stop() + } + } + // Test if we can accurately learn B for Y = logistic(BX) on streaming data test("parameter accuracy") { @@ -50,7 +59,7 @@ class StreamingLogisticRegressionSuite extends SparkFunSuite with TestSuiteBase } // apply model training to input stream - val ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => { + ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => { model.trainOn(inputDStream) inputDStream.count() }) @@ -84,7 +93,7 @@ class StreamingLogisticRegressionSuite extends SparkFunSuite with TestSuiteBase // apply model training to input stream, storing the intermediate results // (we add a count to ensure the result is a DStream) - val ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => { + ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => { model.trainOn(inputDStream) inputDStream.foreachRDD(x => history.append(math.abs(model.latestModel().weights(0) - B))) inputDStream.count() @@ -118,7 +127,7 @@ class StreamingLogisticRegressionSuite extends SparkFunSuite with TestSuiteBase } // apply model predictions to test stream - val ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => { + ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => { model.predictOnValues(inputDStream.map(x => (x.label, x.features))) }) @@ -147,7 +156,7 @@ class StreamingLogisticRegressionSuite extends SparkFunSuite with TestSuiteBase } // train and predict - val ssc = 
setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => { + ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => { model.trainOn(inputDStream) model.predictOnValues(inputDStream.map(x => (x.label, x.features))) }) @@ -167,7 +176,7 @@ class StreamingLogisticRegressionSuite extends SparkFunSuite with TestSuiteBase .setNumIterations(10) val numBatches = 10 val emptyInput = Seq.empty[Seq[LabeledPoint]] - val ssc = setupStreams(emptyInput, + ssc = setupStreams(emptyInput, (inputDStream: DStream[LabeledPoint]) => { model.trainOn(inputDStream) model.predictOnValues(inputDStream.map(x => (x.label, x.features))) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala index ac01622b8a089..3645d29dccdb2 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.mllib.clustering import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.TestingUtils._ -import org.apache.spark.streaming.TestSuiteBase +import org.apache.spark.streaming.{StreamingContext, TestSuiteBase} import org.apache.spark.streaming.dstream.DStream import org.apache.spark.util.random.XORShiftRandom @@ -28,6 +28,15 @@ class StreamingKMeansSuite extends SparkFunSuite with TestSuiteBase { override def maxWaitTimeMillis: Int = 30000 + var ssc: StreamingContext = _ + + override def afterFunction() { + super.afterFunction() + if (ssc != null) { + ssc.stop() + } + } + test("accuracy for single center and equivalence to grand average") { // set parameters val numBatches = 10 @@ -46,7 +55,7 @@ class StreamingKMeansSuite extends SparkFunSuite with TestSuiteBase { val (input, centers) = StreamingKMeansDataGenerator(numPoints, numBatches, k, d, r, 42) // setup and run the model training - val ssc = setupStreams(input, (inputDStream: DStream[Vector]) => { + ssc = setupStreams(input, (inputDStream: DStream[Vector]) => { model.trainOn(inputDStream) inputDStream.count() }) @@ -82,7 +91,7 @@ class StreamingKMeansSuite extends SparkFunSuite with TestSuiteBase { val (input, centers) = StreamingKMeansDataGenerator(numPoints, numBatches, k, d, r, 42) // setup and run the model training - val ssc = setupStreams(input, (inputDStream: DStream[Vector]) => { + ssc = setupStreams(input, (inputDStream: DStream[Vector]) => { kMeans.trainOn(inputDStream) inputDStream.count() }) @@ -114,7 +123,7 @@ class StreamingKMeansSuite extends SparkFunSuite with TestSuiteBase { StreamingKMeansDataGenerator(numPoints, numBatches, k, d, r, 42, Array(Vectors.dense(0.0))) // setup and run the model training - val ssc = setupStreams(input, (inputDStream: DStream[Vector]) => { + ssc = setupStreams(input, (inputDStream: DStream[Vector]) => { kMeans.trainOn(inputDStream) inputDStream.count() }) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala index a2a4c5f6b8b70..34c07ed170816 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala @@ -22,14 +22,23 @@ import scala.collection.mutable.ArrayBuffer import 
org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.LinearDataGenerator +import org.apache.spark.streaming.{StreamingContext, TestSuiteBase} import org.apache.spark.streaming.dstream.DStream -import org.apache.spark.streaming.TestSuiteBase class StreamingLinearRegressionSuite extends SparkFunSuite with TestSuiteBase { // use longer wait time to ensure job completion override def maxWaitTimeMillis: Int = 20000 + var ssc: StreamingContext = _ + + override def afterFunction() { + super.afterFunction() + if (ssc != null) { + ssc.stop() + } + } + // Assert that two values are equal within tolerance epsilon def assertEqual(v1: Double, v2: Double, epsilon: Double) { def errorMessage = v1.toString + " did not equal " + v2.toString @@ -62,7 +71,7 @@ class StreamingLinearRegressionSuite extends SparkFunSuite with TestSuiteBase { } // apply model training to input stream - val ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => { + ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => { model.trainOn(inputDStream) inputDStream.count() }) @@ -98,7 +107,7 @@ class StreamingLinearRegressionSuite extends SparkFunSuite with TestSuiteBase { // apply model training to input stream, storing the intermediate results // (we add a count to ensure the result is a DStream) - val ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => { + ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => { model.trainOn(inputDStream) inputDStream.foreachRDD(x => history.append(math.abs(model.latestModel().weights(0) - 10.0))) inputDStream.count() @@ -129,7 +138,7 @@ class StreamingLinearRegressionSuite extends SparkFunSuite with TestSuiteBase { } // apply model predictions to test stream - val ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => { + ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => { model.predictOnValues(inputDStream.map(x => (x.label, x.features))) }) // collect the output as (true, estimated) tuples @@ -156,7 +165,7 @@ class StreamingLinearRegressionSuite extends SparkFunSuite with TestSuiteBase { } // train and predict - val ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => { + ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => { model.trainOn(inputDStream) model.predictOnValues(inputDStream.map(x => (x.label, x.features))) }) @@ -177,7 +186,7 @@ class StreamingLinearRegressionSuite extends SparkFunSuite with TestSuiteBase { val numBatches = 10 val nPoints = 100 val emptyInput = Seq.empty[Seq[LabeledPoint]] - val ssc = setupStreams(emptyInput, + ssc = setupStreams(emptyInput, (inputDStream: DStream[LabeledPoint]) => { model.trainOn(inputDStream) model.predictOnValues(inputDStream.map(x => (x.label, x.features))) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala index 6270137951b5a..e076fb5ea174b 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala @@ -223,7 +223,11 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false // Signal the receivers to delete old block data if (WriteAheadLogUtils.enableReceiverLog(ssc.conf)) { logInfo(s"Cleanup old received batch data: $cleanupThreshTime") - 
endpoint.send(CleanupOldBlocks(cleanupThreshTime)) + synchronized { + if (isTrackerStarted) { + endpoint.send(CleanupOldBlocks(cleanupThreshTime)) + } + } } } @@ -285,8 +289,10 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false } /** Update a receiver's maximum ingestion rate */ - def sendRateUpdate(streamUID: Int, newRate: Long): Unit = { - endpoint.send(UpdateReceiverRateLimit(streamUID, newRate)) + def sendRateUpdate(streamUID: Int, newRate: Long): Unit = synchronized { + if (isTrackerStarted) { + endpoint.send(UpdateReceiverRateLimit(streamUID, newRate)) + } } /** Add new blocks for the given stream */ diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java index a34f23475804a..e0718f73aa13f 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java @@ -1735,6 +1735,7 @@ public Integer call(String s) throws Exception { @SuppressWarnings("unchecked") @Test public void testContextGetOrCreate() throws InterruptedException { + ssc.stop(); final SparkConf conf = new SparkConf() .setMaster("local[2]") diff --git a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala index 08faeaa58f419..255376807c957 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala @@ -81,39 +81,41 @@ class BasicOperationsSuite extends TestSuiteBase { test("repartition (more partitions)") { val input = Seq(1 to 100, 101 to 200, 201 to 300) val operation = (r: DStream[Int]) => r.repartition(5) - val ssc = setupStreams(input, operation, 2) - val output = runStreamsWithPartitions(ssc, 3, 3) - assert(output.size === 3) - val first = output(0) - val second = output(1) - val third = output(2) - - assert(first.size === 5) - assert(second.size === 5) - assert(third.size === 5) - - assert(first.flatten.toSet.equals((1 to 100).toSet) ) - assert(second.flatten.toSet.equals((101 to 200).toSet)) - assert(third.flatten.toSet.equals((201 to 300).toSet)) + withStreamingContext(setupStreams(input, operation, 2)) { ssc => + val output = runStreamsWithPartitions(ssc, 3, 3) + assert(output.size === 3) + val first = output(0) + val second = output(1) + val third = output(2) + + assert(first.size === 5) + assert(second.size === 5) + assert(third.size === 5) + + assert(first.flatten.toSet.equals((1 to 100).toSet)) + assert(second.flatten.toSet.equals((101 to 200).toSet)) + assert(third.flatten.toSet.equals((201 to 300).toSet)) + } } test("repartition (fewer partitions)") { val input = Seq(1 to 100, 101 to 200, 201 to 300) val operation = (r: DStream[Int]) => r.repartition(2) - val ssc = setupStreams(input, operation, 5) - val output = runStreamsWithPartitions(ssc, 3, 3) - assert(output.size === 3) - val first = output(0) - val second = output(1) - val third = output(2) - - assert(first.size === 2) - assert(second.size === 2) - assert(third.size === 2) - - assert(first.flatten.toSet.equals((1 to 100).toSet)) - assert(second.flatten.toSet.equals( (101 to 200).toSet)) - assert(third.flatten.toSet.equals((201 to 300).toSet)) + withStreamingContext(setupStreams(input, operation, 5)) { ssc => + val output = runStreamsWithPartitions(ssc, 3, 3) + assert(output.size === 3) + val first = output(0) + val 
second = output(1) + val third = output(2) + + assert(first.size === 2) + assert(second.size === 2) + assert(third.size === 2) + + assert(first.flatten.toSet.equals((1 to 100).toSet)) + assert(second.flatten.toSet.equals((101 to 200).toSet)) + assert(third.flatten.toSet.equals((201 to 300).toSet)) + } } test("groupByKey") { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala index b74d67c63a788..ec2852d9a0206 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala @@ -325,27 +325,31 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { } test("test track the number of input stream") { - val ssc = new StreamingContext(conf, batchDuration) + withStreamingContext(new StreamingContext(conf, batchDuration)) { ssc => - class TestInputDStream extends InputDStream[String](ssc) { - def start() { } - def stop() { } - def compute(validTime: Time): Option[RDD[String]] = None - } + class TestInputDStream extends InputDStream[String](ssc) { + def start() {} - class TestReceiverInputDStream extends ReceiverInputDStream[String](ssc) { - def getReceiver: Receiver[String] = null - } + def stop() {} + + def compute(validTime: Time): Option[RDD[String]] = None + } + + class TestReceiverInputDStream extends ReceiverInputDStream[String](ssc) { + def getReceiver: Receiver[String] = null + } - // Register input streams - val receiverInputStreams = Array(new TestReceiverInputDStream, new TestReceiverInputDStream) - val inputStreams = Array(new TestInputDStream, new TestInputDStream, new TestInputDStream) + // Register input streams + val receiverInputStreams = Array(new TestReceiverInputDStream, new TestReceiverInputDStream) + val inputStreams = Array(new TestInputDStream, new TestInputDStream, new TestInputDStream) - assert(ssc.graph.getInputStreams().length == receiverInputStreams.length + inputStreams.length) - assert(ssc.graph.getReceiverInputStreams().length == receiverInputStreams.length) - assert(ssc.graph.getReceiverInputStreams() === receiverInputStreams) - assert(ssc.graph.getInputStreams().map(_.id) === Array.tabulate(5)(i => i)) - assert(receiverInputStreams.map(_.id) === Array(0, 1)) + assert(ssc.graph.getInputStreams().length == + receiverInputStreams.length + inputStreams.length) + assert(ssc.graph.getReceiverInputStreams().length == receiverInputStreams.length) + assert(ssc.graph.getReceiverInputStreams() === receiverInputStreams) + assert(ssc.graph.getInputStreams().map(_.id) === Array.tabulate(5)(i => i)) + assert(receiverInputStreams.map(_.id) === Array(0, 1)) + } } def testFileStream(newFilesOnly: Boolean) { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala b/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala index 6e9d4431090a2..0e64b57e0ffd8 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala @@ -244,7 +244,13 @@ object MasterFailureTest extends Logging { } catch { case e: Exception => logError("Error running streaming context", e) } - if (killingThread.isAlive) killingThread.interrupt() + if (killingThread.isAlive) { + killingThread.interrupt() + // SparkContext.stop will set SparkEnv.env to null. 
We need to make sure SparkContext is + // stopped before running the next test. Otherwise, it's possible that we set SparkEnv.env + // to null after the next test creates the new SparkContext and fail the test. + killingThread.join() + } ssc.stop() logInfo("Has been killed = " + killed) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala index 4bba9691f8aa5..84a5fbb3d95eb 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala @@ -120,7 +120,7 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeo val myConf = SparkContext.updatedConf(new SparkConf(false), master, appName) myConf.set("spark.streaming.checkpoint.directory", checkpointDirectory) - val ssc = new StreamingContext(myConf, batchDuration) + ssc = new StreamingContext(myConf, batchDuration) assert(ssc.checkpointDir != null) } @@ -369,16 +369,22 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeo } assert(exception.isInstanceOf[TestFailedDueToTimeoutException], "Did not wait for stop") + var t: Thread = null // test whether wait exits if context is stopped failAfter(10000 millis) { // 10 seconds because spark takes a long time to shutdown - new Thread() { + t = new Thread() { override def run() { Thread.sleep(500) ssc.stop() } - }.start() + } + t.start() ssc.awaitTermination() } + // SparkContext.stop will set SparkEnv.env to null. We need to make sure SparkContext is stopped + // before running the next test. Otherwise, it's possible that we set SparkEnv.env to null after + // the next test creates the new SparkContext and fail the test. + t.join() } test("awaitTermination after stop") { @@ -430,16 +436,22 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeo assert(ssc.awaitTerminationOrTimeout(500) === false) } + var t: Thread = null // test whether awaitTerminationOrTimeout() return true if context is stopped failAfter(10000 millis) { // 10 seconds because spark takes a long time to shutdown - new Thread() { + t = new Thread() { override def run() { Thread.sleep(500) ssc.stop() } - }.start() + } + t.start() assert(ssc.awaitTerminationOrTimeout(10000) === true) } + // SparkContext.stop will set SparkEnv.env to null. We need to make sure SparkContext is stopped + // before running the next test. Otherwise, it's possible that we set SparkEnv.env to null after + // the next test creates the new SparkContext and fail the test. 
+ t.join() } test("getOrCreate") { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala index 4bc1dd4a30fc4..d840c349bbbc4 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala @@ -36,13 +36,22 @@ class StreamingListenerSuite extends TestSuiteBase with Matchers { val input = (1 to 4).map(Seq(_)).toSeq val operation = (d: DStream[Int]) => d.map(x => x) + var ssc: StreamingContext = _ + + override def afterFunction() { + super.afterFunction() + if (ssc != null) { + ssc.stop() + } + } + // To make sure that the processing start and end times in collected // information are different for successive batches override def batchDuration: Duration = Milliseconds(100) override def actuallyWait: Boolean = true test("batch info reporting") { - val ssc = setupStreams(input, operation) + ssc = setupStreams(input, operation) val collector = new BatchInfoCollector ssc.addStreamingListener(collector) runStreams(ssc, input.size, input.size) @@ -107,7 +116,7 @@ class StreamingListenerSuite extends TestSuiteBase with Matchers { } test("receiver info reporting") { - val ssc = new StreamingContext("local[2]", "test", Milliseconds(1000)) + ssc = new StreamingContext("local[2]", "test", Milliseconds(1000)) val inputStream = ssc.receiverStream(new StreamingListenerSuiteReceiver) inputStream.foreachRDD(_.count) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala index aff8b53f752fa..afad5f16dbc71 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala @@ -29,36 +29,40 @@ import org.apache.spark.storage.StorageLevel /** Testsuite for receiver scheduling */ class ReceiverTrackerSuite extends TestSuiteBase { val sparkConf = new SparkConf().setMaster("local[8]").setAppName("test") - val ssc = new StreamingContext(sparkConf, Milliseconds(100)) - ignore("Receiver tracker - propagates rate limit") { - object ReceiverStartedWaiter extends StreamingListener { - @volatile - var started = false + test("Receiver tracker - propagates rate limit") { + withStreamingContext(new StreamingContext(sparkConf, Milliseconds(100))) { ssc => + object ReceiverStartedWaiter extends StreamingListener { + @volatile + var started = false - override def onReceiverStarted(receiverStarted: StreamingListenerReceiverStarted): Unit = { - started = true + override def onReceiverStarted(receiverStarted: StreamingListenerReceiverStarted): Unit = { + started = true + } } - } - - ssc.addStreamingListener(ReceiverStartedWaiter) - ssc.scheduler.listenerBus.start(ssc.sc) - SingletonTestRateReceiver.reset() - - val newRateLimit = 100L - val inputDStream = new RateLimitInputDStream(ssc) - val tracker = new ReceiverTracker(ssc) - tracker.start() - // we wait until the Receiver has registered with the tracker, - // otherwise our rate update is lost - eventually(timeout(5 seconds)) { - assert(ReceiverStartedWaiter.started) - } - tracker.sendRateUpdate(inputDStream.id, newRateLimit) - // this is an async message, we need to wait a bit for it to be processed - eventually(timeout(3 seconds)) { - assert(inputDStream.getCurrentRateLimit.get === 
newRateLimit) + ssc.addStreamingListener(ReceiverStartedWaiter) + ssc.scheduler.listenerBus.start(ssc.sc) + SingletonTestRateReceiver.reset() + + val newRateLimit = 100L + val inputDStream = new RateLimitInputDStream(ssc) + val tracker = new ReceiverTracker(ssc) + tracker.start() + try { + // we wait until the Receiver has registered with the tracker, + // otherwise our rate update is lost + eventually(timeout(5 seconds)) { + assert(ReceiverStartedWaiter.started) + } + tracker.sendRateUpdate(inputDStream.id, newRateLimit) + // this is an async message, we need to wait a bit for it to be processed + eventually(timeout(3 seconds)) { + assert(inputDStream.getCurrentRateLimit.get === newRateLimit) + } + } finally { + tracker.stop(false) + } } } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala index 0891309f956d2..995f1197ccdfd 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala @@ -22,15 +22,24 @@ import java.util.Properties import org.scalatest.Matchers import org.apache.spark.scheduler.SparkListenerJobStart +import org.apache.spark.streaming._ import org.apache.spark.streaming.dstream.DStream import org.apache.spark.streaming.scheduler._ -import org.apache.spark.streaming.{Duration, Time, Milliseconds, TestSuiteBase} class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { val input = (1 to 4).map(Seq(_)).toSeq val operation = (d: DStream[Int]) => d.map(x => x) + var ssc: StreamingContext = _ + + override def afterFunction() { + super.afterFunction() + if (ssc != null) { + ssc.stop() + } + } + private def createJobStart( batchTime: Time, outputOpId: Int, jobId: Int): SparkListenerJobStart = { val properties = new Properties() @@ -46,7 +55,7 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { test("onBatchSubmitted, onBatchStarted, onBatchCompleted, " + "onReceiverStarted, onReceiverError, onReceiverStopped") { - val ssc = setupStreams(input, operation) + ssc = setupStreams(input, operation) val listener = new StreamingJobProgressListener(ssc) val streamIdToInputInfo = Map( @@ -141,7 +150,7 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { } test("Remove the old completed batches when exceeding the limit") { - val ssc = setupStreams(input, operation) + ssc = setupStreams(input, operation) val limit = ssc.conf.getInt("spark.streaming.ui.retainedBatches", 1000) val listener = new StreamingJobProgressListener(ssc) @@ -158,7 +167,7 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { } test("out-of-order onJobStart and onBatchXXX") { - val ssc = setupStreams(input, operation) + ssc = setupStreams(input, operation) val limit = ssc.conf.getInt("spark.streaming.ui.retainedBatches", 1000) val listener = new StreamingJobProgressListener(ssc) @@ -209,7 +218,7 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { } test("detect memory leak") { - val ssc = setupStreams(input, operation) + ssc = setupStreams(input, operation) val listener = new StreamingJobProgressListener(ssc) val limit = ssc.conf.getInt("spark.streaming.ui.retainedBatches", 1000) From 7f7a319c4ce07f07a6bd68100cf0a4f1da66269e Mon Sep 17 00:00:00 2001 From: martinzapletal Date: Thu, 30 Jul 2015 
15:57:14 -0700 Subject: [PATCH 38/50] [SPARK-8671] [ML] Added isotonic regression to the pipeline API. Author: martinzapletal Closes #7517 from zapletal-martin/SPARK-8671-isotonic-regression-api and squashes the following commits: 8c435c1 [martinzapletal] Review https://github.com/apache/spark/pull/7517 feedback update. bebbb86 [martinzapletal] Merge remote-tracking branch 'upstream/master' into SPARK-8671-isotonic-regression-api b68efc0 [martinzapletal] Added tests for param validation. 07c12bd [martinzapletal] Comments and refactoring. 834fcf7 [martinzapletal] Merge remote-tracking branch 'upstream/master' into SPARK-8671-isotonic-regression-api b611fee [martinzapletal] SPARK-8671. Added first version of isotonic regression to pipeline API --- .../ml/regression/IsotonicRegression.scala | 144 +++++++++++++++++ .../regression/IsotonicRegressionSuite.scala | 148 ++++++++++++++++++ 2 files changed, 292 insertions(+) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala new file mode 100644 index 0000000000000..4ece8cf8cf0b6 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.regression + +import org.apache.spark.annotation.Experimental +import org.apache.spark.ml.PredictorParams +import org.apache.spark.ml.param.{Param, ParamMap, BooleanParam} +import org.apache.spark.ml.util.{SchemaUtils, Identifiable} +import org.apache.spark.mllib.regression.{IsotonicRegression => MLlibIsotonicRegression} +import org.apache.spark.mllib.regression.{IsotonicRegressionModel => MLlibIsotonicRegressionModel} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.types.{DoubleType, DataType} +import org.apache.spark.sql.{Row, DataFrame} +import org.apache.spark.storage.StorageLevel + +/** + * Params for isotonic regression. + */ +private[regression] trait IsotonicRegressionParams extends PredictorParams { + + /** + * Param for weight column name. + * TODO: Move weightCol to sharedParams. + * + * @group param + */ + final val weightCol: Param[String] = + new Param[String](this, "weightCol", "weight column name") + + /** @group getParam */ + final def getWeightCol: String = $(weightCol) + + /** + * Param for isotonic parameter. + * Isotonic (increasing) or antitonic (decreasing) sequence. 
+ * @group param + */ + final val isotonic: BooleanParam = + new BooleanParam(this, "isotonic", "isotonic (increasing) or antitonic (decreasing) sequence") + + /** @group getParam */ + final def getIsotonicParam: Boolean = $(isotonic) +} + +/** + * :: Experimental :: + * Isotonic regression. + * + * Currently implemented using parallelized pool adjacent violators algorithm. + * Only univariate (single feature) algorithm supported. + * + * Uses [[org.apache.spark.mllib.regression.IsotonicRegression]]. + */ +@Experimental +class IsotonicRegression(override val uid: String) + extends Regressor[Double, IsotonicRegression, IsotonicRegressionModel] + with IsotonicRegressionParams { + + def this() = this(Identifiable.randomUID("isoReg")) + + /** + * Set the isotonic parameter. + * Default is true. + * @group setParam + */ + def setIsotonicParam(value: Boolean): this.type = set(isotonic, value) + setDefault(isotonic -> true) + + /** + * Set weight column param. + * Default is weight. + * @group setParam + */ + def setWeightParam(value: String): this.type = set(weightCol, value) + setDefault(weightCol -> "weight") + + override private[ml] def featuresDataType: DataType = DoubleType + + override def copy(extra: ParamMap): IsotonicRegression = defaultCopy(extra) + + private[this] def extractWeightedLabeledPoints( + dataset: DataFrame): RDD[(Double, Double, Double)] = { + + dataset.select($(labelCol), $(featuresCol), $(weightCol)) + .map { case Row(label: Double, features: Double, weights: Double) => + (label, features, weights) + } + } + + override protected def train(dataset: DataFrame): IsotonicRegressionModel = { + SchemaUtils.checkColumnType(dataset.schema, $(weightCol), DoubleType) + // Extract columns from data. If dataset is persisted, do not persist oldDataset. + val instances = extractWeightedLabeledPoints(dataset) + val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE + if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK) + + val isotonicRegression = new MLlibIsotonicRegression().setIsotonic($(isotonic)) + val parentModel = isotonicRegression.run(instances) + + new IsotonicRegressionModel(uid, parentModel) + } +} + +/** + * :: Experimental :: + * Model fitted by IsotonicRegression. + * Predicts using a piecewise linear function. + * + * For detailed rules see [[org.apache.spark.mllib.regression.IsotonicRegressionModel.predict()]]. + * + * @param parentModel A [[org.apache.spark.mllib.regression.IsotonicRegressionModel]] + * model trained by [[org.apache.spark.mllib.regression.IsotonicRegression]]. 
+ */ +class IsotonicRegressionModel private[ml] ( + override val uid: String, + private[ml] val parentModel: MLlibIsotonicRegressionModel) + extends RegressionModel[Double, IsotonicRegressionModel] + with IsotonicRegressionParams { + + override def featuresDataType: DataType = DoubleType + + override protected def predict(features: Double): Double = { + parentModel.predict(features) + } + + override def copy(extra: ParamMap): IsotonicRegressionModel = { + copyValues(new IsotonicRegressionModel(uid, parentModel), extra) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala new file mode 100644 index 0000000000000..66e4b170bae80 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.regression + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.types.{DoubleType, StructField, StructType} +import org.apache.spark.sql.{DataFrame, Row} + +class IsotonicRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { + private val schema = StructType( + Array( + StructField("label", DoubleType), + StructField("features", DoubleType), + StructField("weight", DoubleType))) + + private val predictionSchema = StructType(Array(StructField("features", DoubleType))) + + private def generateIsotonicInput(labels: Seq[Double]): DataFrame = { + val data = Seq.tabulate(labels.size)(i => Row(labels(i), i.toDouble, 1d)) + val parallelData = sc.parallelize(data) + + sqlContext.createDataFrame(parallelData, schema) + } + + private def generatePredictionInput(features: Seq[Double]): DataFrame = { + val data = Seq.tabulate(features.size)(i => Row(features(i))) + + val parallelData = sc.parallelize(data) + sqlContext.createDataFrame(parallelData, predictionSchema) + } + + test("isotonic regression predictions") { + val dataset = generateIsotonicInput(Seq(1, 2, 3, 1, 6, 17, 16, 17, 18)) + val trainer = new IsotonicRegression().setIsotonicParam(true) + + val model = trainer.fit(dataset) + + val predictions = model + .transform(dataset) + .select("prediction").map { + case Row(pred) => pred + }.collect() + + assert(predictions === Array(1, 2, 2, 2, 6, 16.5, 16.5, 17, 18)) + + assert(model.parentModel.boundaries === Array(0, 1, 3, 4, 5, 6, 7, 8)) + assert(model.parentModel.predictions === Array(1, 2, 2, 6, 16.5, 16.5, 17.0, 18.0)) + assert(model.parentModel.isotonic) + } + + test("antitonic regression predictions") { + val dataset = generateIsotonicInput(Seq(7, 5, 3, 5, 1)) + 
val trainer = new IsotonicRegression().setIsotonicParam(false) + + val model = trainer.fit(dataset) + val features = generatePredictionInput(Seq(-2.0, -1.0, 0.5, 0.75, 1.0, 2.0, 9.0)) + + val predictions = model + .transform(features) + .select("prediction").map { + case Row(pred) => pred + }.collect() + + assert(predictions === Array(7, 7, 6, 5.5, 5, 4, 1)) + } + + test("params validation") { + val dataset = generateIsotonicInput(Seq(1, 2, 3)) + val ir = new IsotonicRegression + ParamsSuite.checkParams(ir) + val model = ir.fit(dataset) + ParamsSuite.checkParams(model) + } + + test("default params") { + val dataset = generateIsotonicInput(Seq(1, 2, 3)) + val ir = new IsotonicRegression() + assert(ir.getLabelCol === "label") + assert(ir.getFeaturesCol === "features") + assert(ir.getWeightCol === "weight") + assert(ir.getPredictionCol === "prediction") + assert(ir.getIsotonicParam === true) + + val model = ir.fit(dataset) + model.transform(dataset) + .select("label", "features", "prediction", "weight") + .collect() + + assert(model.getLabelCol === "label") + assert(model.getFeaturesCol === "features") + assert(model.getWeightCol === "weight") + assert(model.getPredictionCol === "prediction") + assert(model.getIsotonicParam === true) + assert(model.hasParent) + } + + test("set parameters") { + val isotonicRegression = new IsotonicRegression() + .setIsotonicParam(false) + .setWeightParam("w") + .setFeaturesCol("f") + .setLabelCol("l") + .setPredictionCol("p") + + assert(isotonicRegression.getIsotonicParam === false) + assert(isotonicRegression.getWeightCol === "w") + assert(isotonicRegression.getFeaturesCol === "f") + assert(isotonicRegression.getLabelCol === "l") + assert(isotonicRegression.getPredictionCol === "p") + } + + test("missing column") { + val dataset = generateIsotonicInput(Seq(1, 2, 3)) + + intercept[IllegalArgumentException] { + new IsotonicRegression().setWeightParam("w").fit(dataset) + } + + intercept[IllegalArgumentException] { + new IsotonicRegression().setFeaturesCol("f").fit(dataset) + } + + intercept[IllegalArgumentException] { + new IsotonicRegression().setLabelCol("l").fit(dataset) + } + + intercept[IllegalArgumentException] { + new IsotonicRegression().fit(dataset).setFeaturesCol("f").transform(dataset) + } + } +} From be7be6d4c7d978c20e601d1f5f56ecb3479814cb Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Thu, 30 Jul 2015 16:04:23 -0700 Subject: [PATCH 39/50] [SPARK-6684] [MLLIB] [ML] Add checkpointing to GBTs Add checkpointing to GradientBoostedTrees, GBTClassifier, GBTRegressor CC: mengxr Author: Joseph K. Bradley Closes #7804 from jkbradley/gbt-checkpoint3 and squashes the following commits: 3fbd7ba [Joseph K. Bradley] tiny fix b3e160c [Joseph K. Bradley] unset checkpoint dir after test 9cc3a04 [Joseph K. 
Bradley] added checkpointing to GBTs --- .../spark/mllib/clustering/LDAOptimizer.scala | 1 + .../mllib/tree/GradientBoostedTrees.scala | 48 +++++------ .../tree/configuration/BoostingStrategy.scala | 3 +- .../classification/GBTClassifierSuite.scala | 20 +++++ .../ml/regression/GBTRegressorSuite.scala | 20 ++++- .../tree/GradientBoostedTreesSuite.scala | 79 +++++++++++-------- 6 files changed, 114 insertions(+), 57 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala index 9dbec41efeada..d6f8b29a43dfd 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala @@ -144,6 +144,7 @@ final class EMLDAOptimizer extends LDAOptimizer { this.checkpointInterval = lda.getCheckpointInterval this.graphCheckpointer = new PeriodicGraphCheckpointer[TopicCounts, TokenCount]( checkpointInterval, graph.vertices.sparkContext) + this.graphCheckpointer.update(this.graph) this.globalTopicTotals = computeGlobalTopicTotals() this } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala index a835f96d5d0e3..9ce6faa137c41 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala @@ -20,6 +20,7 @@ package org.apache.spark.mllib.tree import org.apache.spark.Logging import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaRDD +import org.apache.spark.mllib.impl.PeriodicRDDCheckpointer import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.BoostingStrategy import org.apache.spark.mllib.tree.configuration.Algo._ @@ -184,22 +185,28 @@ object GradientBoostedTrees extends Logging { false } + // Prepare periodic checkpointers + val predErrorCheckpointer = new PeriodicRDDCheckpointer[(Double, Double)]( + treeStrategy.getCheckpointInterval, input.sparkContext) + val validatePredErrorCheckpointer = new PeriodicRDDCheckpointer[(Double, Double)]( + treeStrategy.getCheckpointInterval, input.sparkContext) + timer.stop("init") logDebug("##########") logDebug("Building tree 0") logDebug("##########") - var data = input // Initialize tree timer.start("building tree 0") - val firstTreeModel = new DecisionTree(treeStrategy).run(data) + val firstTreeModel = new DecisionTree(treeStrategy).run(input) val firstTreeWeight = 1.0 baseLearners(0) = firstTreeModel baseLearnerWeights(0) = firstTreeWeight var predError: RDD[(Double, Double)] = GradientBoostedTreesModel. computeInitialPredictionAndError(input, firstTreeWeight, firstTreeModel, loss) + predErrorCheckpointer.update(predError) logDebug("error of gbt = " + predError.values.mean()) // Note: A model of type regression is used since we require raw prediction @@ -207,35 +214,34 @@ object GradientBoostedTrees extends Logging { var validatePredError: RDD[(Double, Double)] = GradientBoostedTreesModel. 
computeInitialPredictionAndError(validationInput, firstTreeWeight, firstTreeModel, loss) + if (validate) validatePredErrorCheckpointer.update(validatePredError) var bestValidateError = if (validate) validatePredError.values.mean() else 0.0 var bestM = 1 - // pseudo-residual for second iteration - data = predError.zip(input).map { case ((pred, _), point) => - LabeledPoint(-loss.gradient(pred, point.label), point.features) - } - var m = 1 - while (m < numIterations) { + var doneLearning = false + while (m < numIterations && !doneLearning) { + // Update data with pseudo-residuals + val data = predError.zip(input).map { case ((pred, _), point) => + LabeledPoint(-loss.gradient(pred, point.label), point.features) + } + timer.start(s"building tree $m") logDebug("###################################################") logDebug("Gradient boosting tree iteration " + m) logDebug("###################################################") val model = new DecisionTree(treeStrategy).run(data) timer.stop(s"building tree $m") - // Create partial model + // Update partial model baseLearners(m) = model // Note: The setting of baseLearnerWeights is incorrect for losses other than SquaredError. // Technically, the weight should be optimized for the particular loss. // However, the behavior should be reasonable, though not optimal. baseLearnerWeights(m) = learningRate - // Note: A model of type regression is used since we require raw prediction - val partialModel = new GradientBoostedTreesModel( - Regression, baseLearners.slice(0, m + 1), - baseLearnerWeights.slice(0, m + 1)) predError = GradientBoostedTreesModel.updatePredictionError( input, predError, baseLearnerWeights(m), baseLearners(m), loss) + predErrorCheckpointer.update(predError) logDebug("error of gbt = " + predError.values.mean()) if (validate) { @@ -246,21 +252,15 @@ object GradientBoostedTrees extends Logging { validatePredError = GradientBoostedTreesModel.updatePredictionError( validationInput, validatePredError, baseLearnerWeights(m), baseLearners(m), loss) + validatePredErrorCheckpointer.update(validatePredError) val currentValidateError = validatePredError.values.mean() if (bestValidateError - currentValidateError < validationTol) { - return new GradientBoostedTreesModel( - boostingStrategy.treeStrategy.algo, - baseLearners.slice(0, bestM), - baseLearnerWeights.slice(0, bestM)) + doneLearning = true } else if (currentValidateError < bestValidateError) { - bestValidateError = currentValidateError - bestM = m + 1 + bestValidateError = currentValidateError + bestM = m + 1 } } - // Update data with pseudo-residuals - data = predError.zip(input).map { case ((pred, _), point) => - LabeledPoint(-loss.gradient(pred, point.label), point.features) - } m += 1 } @@ -269,6 +269,8 @@ object GradientBoostedTrees extends Logging { logInfo("Internal timing for DecisionTree:") logInfo(s"$timer") + predErrorCheckpointer.deleteAllCheckpoints() + validatePredErrorCheckpointer.deleteAllCheckpoints() if (persistedInput) input.unpersist() if (validate) { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala index 2d6b01524ff3d..9fd30c9b56319 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala @@ -36,7 +36,8 @@ import org.apache.spark.mllib.tree.loss.{LogLoss, SquaredError, Loss} * learning rate should be 
between in the interval (0, 1] * @param validationTol Useful when runWithValidation is used. If the error rate on the * validation input between two iterations is less than the validationTol - * then stop. Ignored when [[run]] is used. + * then stop. Ignored when + * [[org.apache.spark.mllib.tree.GradientBoostedTrees.run()]] is used. */ @Experimental case class BoostingStrategy( diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index 82c345491bb3c..a7bc77965fefd 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -28,6 +28,7 @@ import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame +import org.apache.spark.util.Utils /** @@ -76,6 +77,25 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext { } } + test("Checkpointing") { + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + sc.setCheckpointDir(path) + + val categoricalFeatures = Map.empty[Int, Int] + val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 2) + val gbt = new GBTClassifier() + .setMaxDepth(2) + .setLossType("logistic") + .setMaxIter(5) + .setStepSize(0.1) + .setCheckpointInterval(2) + val model = gbt.fit(df) + + sc.checkpointDir = None + Utils.deleteRecursively(tempDir) + } + // TODO: Reinstate test once runWithValidation is implemented SPARK-7132 /* test("runWithValidation stops early and performs better on a validation dataset") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala index 9682edcd9ba84..dbdce0c9dea54 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala @@ -25,7 +25,8 @@ import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.DataFrame +import org.apache.spark.util.Utils /** @@ -88,6 +89,23 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext { assert(predictions.min() < -1) } + test("Checkpointing") { + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + sc.setCheckpointDir(path) + + val df = sqlContext.createDataFrame(data) + val gbt = new GBTRegressor() + .setMaxDepth(2) + .setMaxIter(5) + .setStepSize(0.1) + .setCheckpointInterval(2) + val model = gbt.fit(df) + + sc.checkpointDir = None + Utils.deleteRecursively(tempDir) + } + // TODO: Reinstate test once runWithValidation is implemented SPARK-7132 /* test("runWithValidation stops early and performs better on a validation dataset") { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala index 2521b3342181a..6fc9e8df621df 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala +++ 
b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala @@ -166,43 +166,58 @@ class GradientBoostedTreesSuite extends SparkFunSuite with MLlibTestSparkContext val algos = Array(Regression, Regression, Classification) val losses = Array(SquaredError, AbsoluteError, LogLoss) - (algos zip losses) map { - case (algo, loss) => { - val treeStrategy = new Strategy(algo = algo, impurity = Variance, maxDepth = 2, - categoricalFeaturesInfo = Map.empty) - val boostingStrategy = - new BoostingStrategy(treeStrategy, loss, numIterations, validationTol = 0.0) - val gbtValidate = new GradientBoostedTrees(boostingStrategy) - .runWithValidation(trainRdd, validateRdd) - val numTrees = gbtValidate.numTrees - assert(numTrees !== numIterations) - - // Test that it performs better on the validation dataset. - val gbt = new GradientBoostedTrees(boostingStrategy).run(trainRdd) - val (errorWithoutValidation, errorWithValidation) = { - if (algo == Classification) { - val remappedRdd = validateRdd.map(x => new LabeledPoint(2 * x.label - 1, x.features)) - (loss.computeError(gbt, remappedRdd), loss.computeError(gbtValidate, remappedRdd)) - } else { - (loss.computeError(gbt, validateRdd), loss.computeError(gbtValidate, validateRdd)) - } - } - assert(errorWithValidation <= errorWithoutValidation) - - // Test that results from evaluateEachIteration comply with runWithValidation. - // Note that convergenceTol is set to 0.0 - val evaluationArray = gbt.evaluateEachIteration(validateRdd, loss) - assert(evaluationArray.length === numIterations) - assert(evaluationArray(numTrees) > evaluationArray(numTrees - 1)) - var i = 1 - while (i < numTrees) { - assert(evaluationArray(i) <= evaluationArray(i - 1)) - i += 1 + algos.zip(losses).foreach { case (algo, loss) => + val treeStrategy = new Strategy(algo = algo, impurity = Variance, maxDepth = 2, + categoricalFeaturesInfo = Map.empty) + val boostingStrategy = + new BoostingStrategy(treeStrategy, loss, numIterations, validationTol = 0.0) + val gbtValidate = new GradientBoostedTrees(boostingStrategy) + .runWithValidation(trainRdd, validateRdd) + val numTrees = gbtValidate.numTrees + assert(numTrees !== numIterations) + + // Test that it performs better on the validation dataset. + val gbt = new GradientBoostedTrees(boostingStrategy).run(trainRdd) + val (errorWithoutValidation, errorWithValidation) = { + if (algo == Classification) { + val remappedRdd = validateRdd.map(x => new LabeledPoint(2 * x.label - 1, x.features)) + (loss.computeError(gbt, remappedRdd), loss.computeError(gbtValidate, remappedRdd)) + } else { + (loss.computeError(gbt, validateRdd), loss.computeError(gbtValidate, validateRdd)) } } + assert(errorWithValidation <= errorWithoutValidation) + + // Test that results from evaluateEachIteration comply with runWithValidation. 
+ // Note that convergenceTol is set to 0.0 + val evaluationArray = gbt.evaluateEachIteration(validateRdd, loss) + assert(evaluationArray.length === numIterations) + assert(evaluationArray(numTrees) > evaluationArray(numTrees - 1)) + var i = 1 + while (i < numTrees) { + assert(evaluationArray(i) <= evaluationArray(i - 1)) + i += 1 + } } } + test("Checkpointing") { + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + sc.setCheckpointDir(path) + + val rdd = sc.parallelize(GradientBoostedTreesSuite.data, 2) + + val treeStrategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2, + categoricalFeaturesInfo = Map.empty, checkpointInterval = 2) + val boostingStrategy = new BoostingStrategy(treeStrategy, SquaredError, 5, 0.1) + + val gbt = GradientBoostedTrees.train(rdd, boostingStrategy) + + sc.checkpointDir = None + Utils.deleteRecursively(tempDir) + } + } private object GradientBoostedTreesSuite { From e7905a9395c1a002f50bab29e16a729e14d4ed6f Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Thu, 30 Jul 2015 16:15:43 -0700 Subject: [PATCH 40/50] [SPARK-9463] [ML] Expose model coefficients with names in SparkR RFormula Preview: ``` > summary(m) features coefficients 1 (Intercept) 1.6765001 2 Sepal_Length 0.3498801 3 Species.versicolor -0.9833885 4 Species.virginica -1.0075104 ``` Design doc from umbrella task: https://docs.google.com/document/d/10NZNSEurN2EdWM31uFYsgayIPfCFHiuIu3pCWrUmP_c/edit cc mengxr Author: Eric Liang Closes #7771 from ericl/summary and squashes the following commits: ccd54c3 [Eric Liang] second pass a5ca93b [Eric Liang] comments 2772111 [Eric Liang] clean up 70483ef [Eric Liang] fix test 7c247d4 [Eric Liang] Merge branch 'master' into summary 3c55024 [Eric Liang] working 8c539aa [Eric Liang] first pass --- R/pkg/NAMESPACE | 3 ++- R/pkg/R/mllib.R | 26 ++++++++++++++++++ R/pkg/inst/tests/test_mllib.R | 11 ++++++++ .../spark/ml/feature/OneHotEncoder.scala | 12 ++++----- .../apache/spark/ml/feature/RFormula.scala | 12 ++++++++- .../apache/spark/ml/r/SparkRWrappers.scala | 27 +++++++++++++++++-- .../ml/regression/LinearRegression.scala | 8 ++++-- .../spark/ml/feature/OneHotEncoderSuite.scala | 8 +++--- .../spark/ml/feature/RFormulaSuite.scala | 18 +++++++++++++ 9 files changed, 108 insertions(+), 17 deletions(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 7f7a8a2e4de24..a329e14f25aeb 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -12,7 +12,8 @@ export("print.jobj") # MLlib integration exportMethods("glm", - "predict") + "predict", + "summary") # Job group lifecycle management methods export("setJobGroup", diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R index 6a8bacaa552c6..efddcc1d8d71c 100644 --- a/R/pkg/R/mllib.R +++ b/R/pkg/R/mllib.R @@ -71,3 +71,29 @@ setMethod("predict", signature(object = "PipelineModel"), function(object, newData) { return(dataFrame(callJMethod(object@model, "transform", newData@sdf))) }) + +#' Get the summary of a model +#' +#' Returns the summary of a model produced by glm(), similarly to R's summary(). +#' +#' @param model A fitted MLlib model +#' @return a list with a 'coefficient' component, which is the matrix of coefficients. See +#' summary.glm for more information. 
+#' @rdname glm +#' @export +#' @examples +#'\dontrun{ +#' model <- glm(y ~ x, trainingData) +#' summary(model) +#'} +setMethod("summary", signature(object = "PipelineModel"), + function(object) { + features <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", + "getModelFeatures", object@model) + weights <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", + "getModelWeights", object@model) + coefficients <- as.matrix(unlist(weights)) + colnames(coefficients) <- c("Estimate") + rownames(coefficients) <- unlist(features) + return(list(coefficients = coefficients)) + }) diff --git a/R/pkg/inst/tests/test_mllib.R b/R/pkg/inst/tests/test_mllib.R index 3bef69324770a..f272de78ad4a6 100644 --- a/R/pkg/inst/tests/test_mllib.R +++ b/R/pkg/inst/tests/test_mllib.R @@ -48,3 +48,14 @@ test_that("dot minus and intercept vs native glm", { rVals <- predict(glm(Sepal.Width ~ . - Species + 0, data = iris), iris) expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) }) + +test_that("summary coefficients match with native glm", { + training <- createDataFrame(sqlContext, iris) + stats <- summary(glm(Sepal_Width ~ Sepal_Length + Species, data = training)) + coefs <- as.vector(stats$coefficients) + rCoefs <- as.vector(coef(glm(Sepal.Width ~ Sepal.Length + Species, data = iris))) + expect_true(all(abs(rCoefs - coefs) < 1e-6)) + expect_true(all( + as.character(stats$features) == + c("(Intercept)", "Sepal_Length", "Species__versicolor", "Species__virginica"))) +}) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala index 3825942795645..9c60d4084ec46 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala @@ -66,7 +66,6 @@ class OneHotEncoder(override val uid: String) extends Transformer def setOutputCol(value: String): this.type = set(outputCol, value) override def transformSchema(schema: StructType): StructType = { - val is = "_is_" val inputColName = $(inputCol) val outputColName = $(outputCol) @@ -79,17 +78,17 @@ class OneHotEncoder(override val uid: String) extends Transformer val outputAttrNames: Option[Array[String]] = inputAttr match { case nominal: NominalAttribute => if (nominal.values.isDefined) { - nominal.values.map(_.map(v => inputColName + is + v)) + nominal.values } else if (nominal.numValues.isDefined) { - nominal.numValues.map(n => Array.tabulate(n)(i => inputColName + is + i)) + nominal.numValues.map(n => Array.tabulate(n)(_.toString)) } else { None } case binary: BinaryAttribute => if (binary.values.isDefined) { - binary.values.map(_.map(v => inputColName + is + v)) + binary.values } else { - Some(Array.tabulate(2)(i => inputColName + is + i)) + Some(Array.tabulate(2)(_.toString)) } case _: NumericAttribute => throw new RuntimeException( @@ -123,7 +122,6 @@ class OneHotEncoder(override val uid: String) extends Transformer override def transform(dataset: DataFrame): DataFrame = { // schema transformation - val is = "_is_" val inputColName = $(inputCol) val outputColName = $(outputCol) val shouldDropLast = $(dropLast) @@ -142,7 +140,7 @@ class OneHotEncoder(override val uid: String) extends Transformer math.max(m0, m1) } ).toInt + 1 - val outputAttrNames = Array.tabulate(numAttrs)(i => inputColName + is + i) + val outputAttrNames = Array.tabulate(numAttrs)(_.toString) val filtered = if (shouldDropLast) outputAttrNames.dropRight(1) else outputAttrNames val outputAttrs: 
Array[Attribute] = filtered.map(name => BinaryAttribute.defaultAttr.withName(name)) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index 0b428d278d908..d1726917e4517 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -17,6 +17,7 @@ package org.apache.spark.ml.feature +import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.util.parsing.combinator.RegexParsers @@ -91,11 +92,20 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R // TODO(ekl) add support for feature interactions val encoderStages = ArrayBuffer[PipelineStage]() val tempColumns = ArrayBuffer[String]() + val takenNames = mutable.Set(dataset.columns: _*) val encodedTerms = resolvedFormula.terms.map { term => dataset.schema(term) match { case column if column.dataType == StringType => val indexCol = term + "_idx_" + uid - val encodedCol = term + "_onehot_" + uid + val encodedCol = { + var tmp = term + while (takenNames.contains(tmp)) { + tmp += "_" + } + tmp + } + takenNames.add(indexCol) + takenNames.add(encodedCol) encoderStages += new StringIndexer().setInputCol(term).setOutputCol(indexCol) encoderStages += new OneHotEncoder().setInputCol(indexCol).setOutputCol(encodedCol) tempColumns += indexCol diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala index 9f70592ccad7e..f5a022c31ed90 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala @@ -17,9 +17,10 @@ package org.apache.spark.ml.api.r +import org.apache.spark.ml.attribute._ import org.apache.spark.ml.feature.RFormula -import org.apache.spark.ml.classification.LogisticRegression -import org.apache.spark.ml.regression.LinearRegression +import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel} +import org.apache.spark.ml.regression.{LinearRegression, LinearRegressionModel} import org.apache.spark.ml.{Pipeline, PipelineModel} import org.apache.spark.sql.DataFrame @@ -44,4 +45,26 @@ private[r] object SparkRWrappers { val pipeline = new Pipeline().setStages(Array(formula, estimator)) pipeline.fit(df) } + + def getModelWeights(model: PipelineModel): Array[Double] = { + model.stages.last match { + case m: LinearRegressionModel => + Array(m.intercept) ++ m.weights.toArray + case _: LogisticRegressionModel => + throw new UnsupportedOperationException( + "No weights available for LogisticRegressionModel") // SPARK-9492 + } + } + + def getModelFeatures(model: PipelineModel): Array[String] = { + model.stages.last match { + case m: LinearRegressionModel => + val attrs = AttributeGroup.fromStructField( + m.summary.predictions.schema(m.summary.featuresCol)) + Array("(Intercept)") ++ attrs.attributes.get.map(_.name.get) + case _: LogisticRegressionModel => + throw new UnsupportedOperationException( + "No features names available for LogisticRegressionModel") // SPARK-9492 + } + } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 89718e0f3e15a..3b85ba001b128 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ 
b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -36,6 +36,7 @@ import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.functions.{col, udf} +import org.apache.spark.sql.types.StructField import org.apache.spark.storage.StorageLevel import org.apache.spark.util.StatCounter @@ -146,9 +147,10 @@ class LinearRegression(override val uid: String) val model = new LinearRegressionModel(uid, weights, intercept) val trainingSummary = new LinearRegressionTrainingSummary( - model.transform(dataset).select($(predictionCol), $(labelCol)), + model.transform(dataset), $(predictionCol), $(labelCol), + $(featuresCol), Array(0D)) return copyValues(model.setSummary(trainingSummary)) } @@ -221,9 +223,10 @@ class LinearRegression(override val uid: String) val model = copyValues(new LinearRegressionModel(uid, weights, intercept)) val trainingSummary = new LinearRegressionTrainingSummary( - model.transform(dataset).select($(predictionCol), $(labelCol)), + model.transform(dataset), $(predictionCol), $(labelCol), + $(featuresCol), objectiveHistory) model.setSummary(trainingSummary) } @@ -300,6 +303,7 @@ class LinearRegressionTrainingSummary private[regression] ( predictions: DataFrame, predictionCol: String, labelCol: String, + val featuresCol: String, val objectiveHistory: Array[Double]) extends LinearRegressionSummary(predictions, predictionCol, labelCol) { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala index 65846a846b7b4..321eeb843941c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala @@ -86,8 +86,8 @@ class OneHotEncoderSuite extends SparkFunSuite with MLlibTestSparkContext { val output = encoder.transform(df) val group = AttributeGroup.fromStructField(output.schema("encoded")) assert(group.size === 2) - assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("size_is_small").withIndex(0)) - assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("size_is_medium").withIndex(1)) + assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("small").withIndex(0)) + assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("medium").withIndex(1)) } test("input column without ML attribute") { @@ -98,7 +98,7 @@ class OneHotEncoderSuite extends SparkFunSuite with MLlibTestSparkContext { val output = encoder.transform(df) val group = AttributeGroup.fromStructField(output.schema("encoded")) assert(group.size === 2) - assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("index_is_0").withIndex(0)) - assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("index_is_1").withIndex(1)) + assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("0").withIndex(0)) + assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("1").withIndex(1)) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala index 8148c553e9051..6aed3243afce8 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite +import 
org.apache.spark.ml.attribute._ import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.MLlibTestSparkContext @@ -105,4 +106,21 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext { assert(result.schema.toString == resultSchema.toString) assert(result.collect() === expected.collect()) } + + test("attribute generation") { + val formula = new RFormula().setFormula("id ~ a + b") + val original = sqlContext.createDataFrame( + Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5)) + ).toDF("id", "a", "b") + val model = formula.fit(original) + val result = model.transform(original) + val attrs = AttributeGroup.fromStructField(result.schema("features")) + val expectedAttrs = new AttributeGroup( + "features", + Array( + new BinaryAttribute(Some("a__bar"), Some(1)), + new BinaryAttribute(Some("a__foo"), Some(2)), + new NumericAttribute(Some("b"), Some(3)))) + assert(attrs === expectedAttrs) + } } From 157840d1b14502a4f25cff53633c927998c6ada1 Mon Sep 17 00:00:00 2001 From: Hossein Date: Thu, 30 Jul 2015 16:16:17 -0700 Subject: [PATCH 41/50] [SPARK-8742] [SPARKR] Improve SparkR error messages for DataFrame API This patch improves SparkR error message reporting, especially with DataFrame API. When there is a user error (e.g., malformed SQL query), the message of the cause is sent back through the RPC and the R client reads it and returns it back to user. cc shivaram Author: Hossein Closes #7742 from falaki/SPARK-8742 and squashes the following commits: 4f643c9 [Hossein] Not logging exceptions in RBackendHandler 4a8005c [Hossein] Returning stack track of causing exception from RBackendHandler 5cf17f0 [Hossein] Adding unit test for error messages from SQLContext 2af75d5 [Hossein] Reading error message in case of failure and stoping with that message f479c99 [Hossein] Wrting exception cause message in JVM --- R/pkg/R/backend.R | 4 +++- R/pkg/inst/tests/test_sparkSQL.R | 5 +++++ .../scala/org/apache/spark/api/r/RBackendHandler.scala | 10 ++++++++-- 3 files changed, 16 insertions(+), 3 deletions(-) diff --git a/R/pkg/R/backend.R b/R/pkg/R/backend.R index 2fb6fae55f28c..49162838b8d1a 100644 --- a/R/pkg/R/backend.R +++ b/R/pkg/R/backend.R @@ -110,6 +110,8 @@ invokeJava <- function(isStatic, objId, methodName, ...) 
{ # TODO: check the status code to output error information returnStatus <- readInt(conn) - stopifnot(returnStatus == 0) + if (returnStatus != 0) { + stop(readString(conn)) + } readObject(conn) } diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index d5db97248c770..61c8a7ec7d837 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -1002,6 +1002,11 @@ test_that("crosstab() on a DataFrame", { expect_identical(expected, ordered) }) +test_that("SQL error message is returned from JVM", { + retError <- tryCatch(sql(sqlContext, "select * from blah"), error = function(e) e) + expect_equal(grepl("Table Not Found: blah", retError), TRUE) +}) + unlink(parquetPath) unlink(jsonPath) unlink(jsonPathNa) diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala index a5de10fe89c42..14dac4ed28ce3 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala @@ -69,8 +69,11 @@ private[r] class RBackendHandler(server: RBackend) case e: Exception => logError(s"Removing $objId failed", e) writeInt(dos, -1) + writeString(dos, s"Removing $objId failed: ${e.getMessage}") } - case _ => dos.writeInt(-1) + case _ => + dos.writeInt(-1) + writeString(dos, s"Error: unknown method $methodName") } } else { handleMethodCall(isStatic, objId, methodName, numArgs, dis, dos) @@ -146,8 +149,11 @@ private[r] class RBackendHandler(server: RBackend) } } catch { case e: Exception => - logError(s"$methodName on $objId failed", e) + logError(s"$methodName on $objId failed") writeInt(dos, -1) + // Writing the error message of the cause for the exception. This will be returned + // to user in the R process. + writeString(dos, Utils.exceptionString(e.getCause)) } } From 04c8409107710fc9a625ee513d68c149745539f3 Mon Sep 17 00:00:00 2001 From: Calvin Jia Date: Thu, 30 Jul 2015 16:32:40 -0700 Subject: [PATCH 42/50] [SPARK-9199] [CORE] Update Tachyon dependency from 0.6.4 -> 0.7.0 No new dependencies are added. The exclusion changes are due to the change in tachyon-client 0.7.0's project structure. There is no client side API change in Tachyon 0.7.0 so no code changes are required. 
Author: Calvin Jia Closes #7577 from calvinjia/SPARK-9199 and squashes the following commits: 4e81e40 [Calvin Jia] Update Tachyon dependency from 0.6.4 -> 0.7.0 --- core/pom.xml | 34 +++++----------------------------- make-distribution.sh | 2 +- 2 files changed, 6 insertions(+), 30 deletions(-) diff --git a/core/pom.xml b/core/pom.xml index 6fa87ec6a24af..202678779150b 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -286,7 +286,7 @@ org.tachyonproject tachyon-client - 0.6.4 + 0.7.0 org.apache.hadoop @@ -297,36 +297,12 @@ curator-recipes - org.eclipse.jetty - jetty-jsp + org.tachyonproject + tachyon-underfs-glusterfs - org.eclipse.jetty - jetty-webapp - - - org.eclipse.jetty - jetty-server - - - org.eclipse.jetty - jetty-servlet - - - junit - junit - - - org.powermock - powermock-module-junit4 - - - org.powermock - powermock-api-mockito - - - org.apache.curator - curator-test + org.tachyonproject + tachyon-underfs-s3 diff --git a/make-distribution.sh b/make-distribution.sh index cac7032bb2e87..4789b0e09cc8a 100755 --- a/make-distribution.sh +++ b/make-distribution.sh @@ -33,7 +33,7 @@ SPARK_HOME="$(cd "`dirname "$0"`"; pwd)" DISTDIR="$SPARK_HOME/dist" SPARK_TACHYON=false -TACHYON_VERSION="0.6.4" +TACHYON_VERSION="0.7.0" TACHYON_TGZ="tachyon-${TACHYON_VERSION}-bin.tar.gz" TACHYON_URL="https://github.com/amplab/tachyon/releases/download/v${TACHYON_VERSION}/${TACHYON_TGZ}" From 1afdeb7b458f86e2641f062fb9ddc00e9c5c7531 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Thu, 30 Jul 2015 16:44:02 -0700 Subject: [PATCH 43/50] [STREAMING] [TEST] [HOTFIX] Fixed Kinesis test to not throw weird errors when Kinesis tests are enabled without AWS keys If Kinesis tests are enabled by env ENABLE_KINESIS_TESTS = 1 but no AWS credentials are found, the desired behavior is the fail the test using with ``` Exception encountered when attempting to run a suite with class name: org.apache.spark.streaming.kinesis.KinesisBackedBlockRDDSuite *** ABORTED *** (3 seconds, 5 milliseconds) [info] java.lang.Exception: Kinesis tests enabled, but could get not AWS credentials ``` Instead KinesisStreamSuite fails with ``` [info] - basic operation *** FAILED *** (3 seconds, 35 milliseconds) [info] java.lang.IllegalArgumentException: requirement failed: Stream not yet created, call createStream() to create one [info] at scala.Predef$.require(Predef.scala:233) [info] at org.apache.spark.streaming.kinesis.KinesisTestUtils.streamName(KinesisTestUtils.scala:77) [info] at org.apache.spark.streaming.kinesis.KinesisTestUtils$$anonfun$deleteStream$1.apply(KinesisTestUtils.scala:150) [info] at org.apache.spark.streaming.kinesis.KinesisTestUtils$$anonfun$deleteStream$1.apply(KinesisTestUtils.scala:150) [info] at org.apache.spark.Logging$class.logWarning(Logging.scala:71) [info] at org.apache.spark.streaming.kinesis.KinesisTestUtils.logWarning(KinesisTestUtils.scala:39) [info] at org.apache.spark.streaming.kinesis.KinesisTestUtils.deleteStream(KinesisTestUtils.scala:150) [info] at org.apache.spark.streaming.kinesis.KinesisStreamSuite$$anonfun$3.apply$mcV$sp(KinesisStreamSuite.scala:111) [info] at org.apache.spark.streaming.kinesis.KinesisStreamSuite$$anonfun$3.apply(KinesisStreamSuite.scala:86) [info] at org.apache.spark.streaming.kinesis.KinesisStreamSuite$$anonfun$3.apply(KinesisStreamSuite.scala:86) ``` This is because attempting to delete a non-existent Kinesis stream throws uncaught exception. This PR fixes it. 
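For readers skimming the patch, here is a minimal sketch of the guard the fix introduces, with illustrative class and field names and assuming the AWS SDK v1 Kinesis client these tests already use; the actual change is in KinesisTestUtils below.

```scala
import com.amazonaws.services.kinesis.AmazonKinesisClient
import com.amazonaws.services.kinesis.model.ResourceNotFoundException

// Sketch only: remember whether createStream() actually ran, and make
// deleteStream() a no-op otherwise, so an aborted test (e.g. missing AWS
// credentials) never issues a delete for a stream that was never created.
class KinesisStreamFixture(client: AmazonKinesisClient) {
  @volatile private var streamCreated = false
  @volatile private var streamName: String = _

  def createStream(name: String): Unit = {
    client.createStream(name, 2) // shard count is arbitrary for a test stream
    streamName = name
    streamCreated = true
  }

  def deleteStream(): Unit = {
    try {
      if (streamCreated) client.deleteStream(streamName)
    } catch {
      case _: ResourceNotFoundException => // already gone; nothing to clean up
    }
  }
}
```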
Author: Tathagata Das Closes #7809 from tdas/kinesis-test-hotfix and squashes the following commits: 7c372e6 [Tathagata Das] Fixed test --- .../streaming/kinesis/KinesisTestUtils.scala | 27 ++++++++++--------- .../kinesis/KinesisStreamSuite.scala | 4 +-- 2 files changed, 16 insertions(+), 15 deletions(-) diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala index 0ff1b7ed0fd90..ca39358b75cb6 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala @@ -53,6 +53,8 @@ private class KinesisTestUtils( @volatile private var streamCreated = false + + @volatile private var _streamName: String = _ private lazy val kinesisClient = { @@ -115,21 +117,9 @@ private class KinesisTestUtils( shardIdToSeqNumbers.toMap } - def describeStream(streamNameToDescribe: String = streamName): Option[StreamDescription] = { - try { - val describeStreamRequest = new DescribeStreamRequest().withStreamName(streamNameToDescribe) - val desc = kinesisClient.describeStream(describeStreamRequest).getStreamDescription() - Some(desc) - } catch { - case rnfe: ResourceNotFoundException => - None - } - } - def deleteStream(): Unit = { try { - if (describeStream().nonEmpty) { - val deleteStreamRequest = new DeleteStreamRequest() + if (streamCreated) { kinesisClient.deleteStream(streamName) } } catch { @@ -149,6 +139,17 @@ private class KinesisTestUtils( } } + private def describeStream(streamNameToDescribe: String): Option[StreamDescription] = { + try { + val describeStreamRequest = new DescribeStreamRequest().withStreamName(streamNameToDescribe) + val desc = kinesisClient.describeStream(describeStreamRequest).getStreamDescription() + Some(desc) + } catch { + case rnfe: ResourceNotFoundException => + None + } + } + private def findNonExistentStreamName(): String = { var testStreamName: String = null do { diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala index f9c952b9468bb..b88c9c6478d56 100644 --- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala +++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala @@ -88,11 +88,11 @@ class KinesisStreamSuite extends KinesisFunSuite try { kinesisTestUtils.createStream() ssc = new StreamingContext(sc, Seconds(1)) - val aWSCredentials = KinesisTestUtils.getAWSCredentials() + val awsCredentials = KinesisTestUtils.getAWSCredentials() val stream = KinesisUtils.createStream(ssc, kinesisAppName, kinesisTestUtils.streamName, kinesisTestUtils.endpointUrl, kinesisTestUtils.regionName, InitialPositionInStream.LATEST, Seconds(10), StorageLevel.MEMORY_ONLY, - aWSCredentials.getAWSAccessKeyId, aWSCredentials.getAWSSecretKey) + awsCredentials.getAWSAccessKeyId, awsCredentials.getAWSSecretKey) val collected = new mutable.HashSet[Int] with mutable.SynchronizedSet[Int] stream.map { bytes => new String(bytes).toInt }.foreachRDD { rdd => From ca71cc8c8b2d64b7756ae697c06876cd18b536dc Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Thu, 30 Jul 2015 16:57:38 -0700 Subject: [PATCH 44/50] [SPARK-9408] [PYSPARK] [MLLIB] Refactor linalg.py to /linalg This is based on MechCoder 's PR 
https://github.com/apache/spark/pull/7731. Hopefully it could pass tests. MechCoder I tried to make minimal changes. If this passes Jenkins, we can merge this one first and then try to move `__init__.py` to `local.py` in a separate PR. Closes #7731 Author: Xiangrui Meng Closes #7746 from mengxr/SPARK-9408 and squashes the following commits: 0e05a3b [Xiangrui Meng] merge master 1135551 [Xiangrui Meng] add a comment for str(...) c48cae0 [Xiangrui Meng] update tests 173a805 [Xiangrui Meng] move linalg.py to linalg/__init__.py --- dev/sparktestsupport/modules.py | 2 +- python/pyspark/mllib/{linalg.py => linalg/__init__.py} | 0 python/pyspark/sql/types.py | 2 +- 3 files changed, 2 insertions(+), 2 deletions(-) rename python/pyspark/mllib/{linalg.py => linalg/__init__.py} (100%) diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 030d982e99106..44600cb9523c1 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -323,7 +323,7 @@ def contains_file(self, filename): "pyspark.mllib.evaluation", "pyspark.mllib.feature", "pyspark.mllib.fpm", - "pyspark.mllib.linalg", + "pyspark.mllib.linalg.__init__", "pyspark.mllib.random", "pyspark.mllib.recommendation", "pyspark.mllib.regression", diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg/__init__.py similarity index 100% rename from python/pyspark/mllib/linalg.py rename to python/pyspark/mllib/linalg/__init__.py diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 0976aea72c034..6f74b7162f7cc 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -648,7 +648,7 @@ def jsonValue(self): @classmethod def fromJson(cls, json): - pyUDT = str(json["pyClass"]) + pyUDT = str(json["pyClass"]) # convert unicode to str split = pyUDT.rfind(".") pyModule = pyUDT[:split] pyClass = pyUDT[split+1:] From df32669514afc0223ecdeca30fbfbe0b40baef3a Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Thu, 30 Jul 2015 17:16:03 -0700 Subject: [PATCH 45/50] [SPARK-7157][SQL] add sampleBy to DataFrame This was previously committed but then reverted due to test failures (see #6769). 
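As a quick illustration of the API this patch adds, a hedged usage sketch in Scala (it assumes a SQLContext named sqlContext is in scope; the expected counts come from the DataFrameStatSuite test added below):

```scala
import org.apache.spark.sql.functions.col

// Stratified sample without replacement: keep ~10% of the rows whose "key"
// is 0 and ~20% of those whose "key" is 1; strata not listed in the map are
// treated as fraction 0.0 and dropped.
val df = sqlContext.range(0, 100).select((col("id") % 3).as("key"))
val sampled = df.stat.sampleBy("key", Map(0 -> 0.1, 1 -> 0.2), seed = 0L)

sampled.groupBy("key").count().orderBy("key").show()
// With seed 0 the suite below expects counts of 5 for key 0 and 8 for key 1;
// key 2 never appears because its fraction defaults to zero.
```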
Author: Xiangrui Meng Closes #7755 from rxin/SPARK-7157 and squashes the following commits: fbf9044 [Xiangrui Meng] fix python test 542bd37 [Xiangrui Meng] update test 604fe6d [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-7157 f051afd [Xiangrui Meng] use udf instead of building expression f4e9425 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-7157 8fb990b [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-7157 103beb3 [Xiangrui Meng] add Java-friendly sampleBy 991f26f [Xiangrui Meng] fix seed 4a14834 [Xiangrui Meng] move sampleBy to stat 832f7cc [Xiangrui Meng] add sampleBy to DataFrame --- python/pyspark/sql/dataframe.py | 41 ++++++++++++++++++ .../spark/sql/DataFrameStatFunctions.scala | 42 +++++++++++++++++++ .../apache/spark/sql/JavaDataFrameSuite.java | 9 ++++ .../apache/spark/sql/DataFrameStatSuite.scala | 12 +++++- 4 files changed, 102 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index d76e051bd73a1..0f3480c239187 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -441,6 +441,42 @@ def sample(self, withReplacement, fraction, seed=None): rdd = self._jdf.sample(withReplacement, fraction, long(seed)) return DataFrame(rdd, self.sql_ctx) + @since(1.5) + def sampleBy(self, col, fractions, seed=None): + """ + Returns a stratified sample without replacement based on the + fraction given on each stratum. + + :param col: column that defines strata + :param fractions: + sampling fraction for each stratum. If a stratum is not + specified, we treat its fraction as zero. + :param seed: random seed + :return: a new DataFrame that represents the stratified sample + + >>> from pyspark.sql.functions import col + >>> dataset = sqlContext.range(0, 100).select((col("id") % 3).alias("key")) + >>> sampled = dataset.sampleBy("key", fractions={0: 0.1, 1: 0.2}, seed=0) + >>> sampled.groupBy("key").count().orderBy("key").show() + +---+-----+ + |key|count| + +---+-----+ + | 0| 3| + | 1| 8| + +---+-----+ + + """ + if not isinstance(col, str): + raise ValueError("col must be a string, but got %r" % type(col)) + if not isinstance(fractions, dict): + raise ValueError("fractions must be a dict but got %r" % type(fractions)) + for k, v in fractions.items(): + if not isinstance(k, (float, int, long, basestring)): + raise ValueError("key must be float, int, long, or string, but got %r" % type(k)) + fractions[k] = float(v) + seed = seed if seed is not None else random.randint(0, sys.maxsize) + return DataFrame(self._jdf.stat().sampleBy(col, self._jmap(fractions), seed), self.sql_ctx) + @since(1.4) def randomSplit(self, weights, seed=None): """Randomly splits this :class:`DataFrame` with the provided weights. 
@@ -1314,6 +1350,11 @@ def freqItems(self, cols, support=None): freqItems.__doc__ = DataFrame.freqItems.__doc__ + def sampleBy(self, col, fractions, seed=None): + return self.df.sampleBy(col, fractions, seed) + + sampleBy.__doc__ = DataFrame.sampleBy.__doc__ + def _test(): import doctest diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala index 4ec58082e7aef..2e68e358f2f1f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala @@ -17,6 +17,10 @@ package org.apache.spark.sql +import java.{util => ju, lang => jl} + +import scala.collection.JavaConverters._ + import org.apache.spark.annotation.Experimental import org.apache.spark.sql.execution.stat._ @@ -166,4 +170,42 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { def freqItems(cols: Seq[String]): DataFrame = { FrequentItems.singlePassFreqItems(df, cols, 0.01) } + + /** + * Returns a stratified sample without replacement based on the fraction given on each stratum. + * @param col column that defines strata + * @param fractions sampling fraction for each stratum. If a stratum is not specified, we treat + * its fraction as zero. + * @param seed random seed + * @tparam T stratum type + * @return a new [[DataFrame]] that represents the stratified sample + * + * @since 1.5.0 + */ + def sampleBy[T](col: String, fractions: Map[T, Double], seed: Long): DataFrame = { + require(fractions.values.forall(p => p >= 0.0 && p <= 1.0), + s"Fractions must be in [0, 1], but got $fractions.") + import org.apache.spark.sql.functions.{rand, udf} + val c = Column(col) + val r = rand(seed) + val f = udf { (stratum: Any, x: Double) => + x < fractions.getOrElse(stratum.asInstanceOf[T], 0.0) + } + df.filter(f(c, r)) + } + + /** + * Returns a stratified sample without replacement based on the fraction given on each stratum. + * @param col column that defines strata + * @param fractions sampling fraction for each stratum. If a stratum is not specified, we treat + * its fraction as zero. 
+ * @param seed random seed + * @tparam T stratum type + * @return a new [[DataFrame]] that represents the stratified sample + * + * @since 1.5.0 + */ + def sampleBy[T](col: String, fractions: ju.Map[T, jl.Double], seed: Long): DataFrame = { + sampleBy(col, fractions.asScala.toMap.asInstanceOf[Map[T, Double]], seed) + } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index 9e61d06f4036e..2c669bb59a0b5 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -226,4 +226,13 @@ public void testCovariance() { Double result = df.stat().cov("a", "b"); Assert.assertTrue(Math.abs(result) < 1e-6); } + + @Test + public void testSampleBy() { + DataFrame df = context.range(0, 100).select(col("id").mod(3).as("key")); + DataFrame sampled = df.stat().sampleBy("key", ImmutableMap.of(0, 0.1, 1, 0.2), 0L); + Row[] actual = sampled.groupBy("key").count().orderBy("key").collect(); + Row[] expected = new Row[] {RowFactory.create(0, 5), RowFactory.create(1, 8)}; + Assert.assertArrayEquals(expected, actual); + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 7ba4ba73e0cc9..07a675e64f527 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -21,9 +21,9 @@ import java.util.Random import org.scalatest.Matchers._ -import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.functions.col -class DataFrameStatSuite extends SparkFunSuite { +class DataFrameStatSuite extends QueryTest { private val sqlCtx = org.apache.spark.sql.test.TestSQLContext import sqlCtx.implicits._ @@ -130,4 +130,12 @@ class DataFrameStatSuite extends SparkFunSuite { val items2 = singleColResults.collect().head items2.getSeq[Double](0) should contain (-1.0) } + + test("sampleBy") { + val df = sqlCtx.range(0, 100).select((col("id") % 3).as("key")) + val sampled = df.stat.sampleBy("key", Map(0 -> 0.1, 1 -> 0.2), 0L) + checkAnswer( + sampled.groupBy("key").count().orderBy("key"), + Seq(Row(0, 5), Row(1, 8))) + } } From e7a0976e991f75a7bda99509e2b040daab965ae6 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 30 Jul 2015 17:17:27 -0700 Subject: [PATCH 46/50] [SPARK-9458][SPARK-9469][SQL] Code generate prefix computation in sorting & moves unsafe conversion out of TungstenSort. Author: Reynold Xin Closes #7803 from rxin/SPARK-9458 and squashes the following commits: 5b032dc [Reynold Xin] Fix string. b670dbb [Reynold Xin] [SPARK-9458][SPARK-9469][SQL] Code generate prefix computation in sorting & moves unsafe conversion out of TungstenSort. 
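Before the diff, a brief hedged sketch of what the 64-bit sort prefix buys (helper names here are illustrative; in the patch the prefix is produced by the new code-generated SortPrefix expression and compared through PrefixComparators): the sorter orders records by a cheap long-vs-long comparison and only loads the full row when two prefixes tie.

```scala
import com.google.common.primitives.UnsignedLongs

// Illustrative sketch, not Spark's internals.
// A string's prefix packs its first 8 UTF-8 bytes big-endian into a long, so
// unsigned-long order matches byte-wise order on that leading chunk.
def stringPrefix(s: String): Long = {
  val bytes = s.getBytes("UTF-8")
  var prefix = 0L
  var i = 0
  while (i < 8) {
    val b = if (i < bytes.length) bytes(i) & 0xffL else 0L // pad short keys with 0
    prefix = (prefix << 8) | b
    i += 1
  }
  prefix
}

// A double's prefix is just its raw bits; the comparator decodes them back,
// mirroring DoublePrefixComparator in this patch.
def doublePrefix(d: Double): Long = java.lang.Double.doubleToLongBits(d)

// Compare prefixes first; fall back to the full key only on a tie. (For
// ASCII keys String.compareTo agrees with the byte order used above; real
// code would use a binary comparator.)
def compareByPrefix(a: (String, Long), b: (String, Long)): Int = {
  val cmp = UnsignedLongs.compare(a._2, b._2)
  if (cmp != 0) cmp else a._1.compareTo(b._1)
}
```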
--- .../unsafe/sort/PrefixComparators.java | 49 ++++++++------ .../unsafe/sort/PrefixComparatorsSuite.scala | 22 ++----- .../execution/UnsafeExternalRowSorter.java | 27 ++++---- .../sql/catalyst/expressions/SortOrder.scala | 44 ++++++++++++- .../spark/sql/execution/SortPrefixUtils.scala | 64 +++---------------- .../spark/sql/execution/SparkStrategies.scala | 4 +- .../sql/execution/joins/HashedRelation.scala | 4 +- .../org/apache/spark/sql/execution/sort.scala | 64 ++++++++----------- .../execution/RowFormatConvertersSuite.scala | 11 ++-- ...ortSuite.scala => TungstenSortSuite.scala} | 10 +-- 10 files changed, 138 insertions(+), 161 deletions(-) rename sql/core/src/test/scala/org/apache/spark/sql/execution/{UnsafeExternalSortSuite.scala => TungstenSortSuite.scala} (87%) diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java index 600aff7d15d8a..4d7e5b3dfba6e 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java @@ -28,9 +28,11 @@ public class PrefixComparators { private PrefixComparators() {} public static final StringPrefixComparator STRING = new StringPrefixComparator(); - public static final IntegralPrefixComparator INTEGRAL = new IntegralPrefixComparator(); - public static final FloatPrefixComparator FLOAT = new FloatPrefixComparator(); + public static final StringPrefixComparatorDesc STRING_DESC = new StringPrefixComparatorDesc(); + public static final LongPrefixComparator LONG = new LongPrefixComparator(); + public static final LongPrefixComparatorDesc LONG_DESC = new LongPrefixComparatorDesc(); public static final DoublePrefixComparator DOUBLE = new DoublePrefixComparator(); + public static final DoublePrefixComparatorDesc DOUBLE_DESC = new DoublePrefixComparatorDesc(); public static final class StringPrefixComparator extends PrefixComparator { @Override @@ -38,50 +40,55 @@ public int compare(long aPrefix, long bPrefix) { return UnsignedLongs.compare(aPrefix, bPrefix); } - public long computePrefix(UTF8String value) { + public static long computePrefix(UTF8String value) { return value == null ? 0L : value.getPrefix(); } } - /** - * Prefix comparator for all integral types (boolean, byte, short, int, long). - */ - public static final class IntegralPrefixComparator extends PrefixComparator { + public static final class StringPrefixComparatorDesc extends PrefixComparator { + @Override + public int compare(long bPrefix, long aPrefix) { + return UnsignedLongs.compare(aPrefix, bPrefix); + } + } + + public static final class LongPrefixComparator extends PrefixComparator { @Override public int compare(long a, long b) { return (a < b) ? -1 : (a > b) ? 1 : 0; } + } - public final long NULL_PREFIX = Long.MIN_VALUE; + public static final class LongPrefixComparatorDesc extends PrefixComparator { + @Override + public int compare(long b, long a) { + return (a < b) ? -1 : (a > b) ? 
1 : 0; + } } - public static final class FloatPrefixComparator extends PrefixComparator { + public static final class DoublePrefixComparator extends PrefixComparator { @Override public int compare(long aPrefix, long bPrefix) { - float a = Float.intBitsToFloat((int) aPrefix); - float b = Float.intBitsToFloat((int) bPrefix); - return Utils.nanSafeCompareFloats(a, b); + double a = Double.longBitsToDouble(aPrefix); + double b = Double.longBitsToDouble(bPrefix); + return Utils.nanSafeCompareDoubles(a, b); } - public long computePrefix(float value) { - return Float.floatToIntBits(value) & 0xffffffffL; + public static long computePrefix(double value) { + return Double.doubleToLongBits(value); } - - public final long NULL_PREFIX = computePrefix(Float.NEGATIVE_INFINITY); } - public static final class DoublePrefixComparator extends PrefixComparator { + public static final class DoublePrefixComparatorDesc extends PrefixComparator { @Override - public int compare(long aPrefix, long bPrefix) { + public int compare(long bPrefix, long aPrefix) { double a = Double.longBitsToDouble(aPrefix); double b = Double.longBitsToDouble(bPrefix); return Utils.nanSafeCompareDoubles(a, b); } - public long computePrefix(double value) { + public static long computePrefix(double value) { return Double.doubleToLongBits(value); } - - public final long NULL_PREFIX = computePrefix(Double.NEGATIVE_INFINITY); } } diff --git a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala index cf53a8ad21c60..26a2e96edaaa2 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala @@ -29,8 +29,8 @@ class PrefixComparatorsSuite extends SparkFunSuite with PropertyChecks { def testPrefixComparison(s1: String, s2: String): Unit = { val utf8string1 = UTF8String.fromString(s1) val utf8string2 = UTF8String.fromString(s2) - val s1Prefix = PrefixComparators.STRING.computePrefix(utf8string1) - val s2Prefix = PrefixComparators.STRING.computePrefix(utf8string2) + val s1Prefix = PrefixComparators.StringPrefixComparator.computePrefix(utf8string1) + val s2Prefix = PrefixComparators.StringPrefixComparator.computePrefix(utf8string2) val prefixComparisonResult = PrefixComparators.STRING.compare(s1Prefix, s2Prefix) val cmp = UnsignedBytes.lexicographicalComparator().compare( @@ -55,27 +55,15 @@ class PrefixComparatorsSuite extends SparkFunSuite with PropertyChecks { forAll { (s1: String, s2: String) => testPrefixComparison(s1, s2) } } - test("float prefix comparator handles NaN properly") { - val nan1: Float = java.lang.Float.intBitsToFloat(0x7f800001) - val nan2: Float = java.lang.Float.intBitsToFloat(0x7fffffff) - assert(nan1.isNaN) - assert(nan2.isNaN) - val nan1Prefix = PrefixComparators.FLOAT.computePrefix(nan1) - val nan2Prefix = PrefixComparators.FLOAT.computePrefix(nan2) - assert(nan1Prefix === nan2Prefix) - val floatMaxPrefix = PrefixComparators.FLOAT.computePrefix(Float.MaxValue) - assert(PrefixComparators.FLOAT.compare(nan1Prefix, floatMaxPrefix) === 1) - } - test("double prefix comparator handles NaNs properly") { val nan1: Double = java.lang.Double.longBitsToDouble(0x7ff0000000000001L) val nan2: Double = java.lang.Double.longBitsToDouble(0x7fffffffffffffffL) assert(nan1.isNaN) assert(nan2.isNaN) - val nan1Prefix = PrefixComparators.DOUBLE.computePrefix(nan1) - val 
nan2Prefix = PrefixComparators.DOUBLE.computePrefix(nan2) + val nan1Prefix = PrefixComparators.DoublePrefixComparator.computePrefix(nan1) + val nan2Prefix = PrefixComparators.DoublePrefixComparator.computePrefix(nan2) assert(nan1Prefix === nan2Prefix) - val doubleMaxPrefix = PrefixComparators.DOUBLE.computePrefix(Double.MaxValue) + val doubleMaxPrefix = PrefixComparators.DoublePrefixComparator.computePrefix(Double.MaxValue) assert(PrefixComparators.DOUBLE.compare(nan1Prefix, doubleMaxPrefix) === 1) } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java index 4c3f2c6557140..68c49feae938e 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java @@ -48,7 +48,6 @@ final class UnsafeExternalRowSorter { private long numRowsInserted = 0; private final StructType schema; - private final UnsafeProjection unsafeProjection; private final PrefixComputer prefixComputer; private final UnsafeExternalSorter sorter; @@ -62,7 +61,6 @@ public UnsafeExternalRowSorter( PrefixComparator prefixComparator, PrefixComputer prefixComputer) throws IOException { this.schema = schema; - this.unsafeProjection = UnsafeProjection.create(schema); this.prefixComputer = prefixComputer; final SparkEnv sparkEnv = SparkEnv.get(); final TaskContext taskContext = TaskContext.get(); @@ -88,13 +86,12 @@ void setTestSpillFrequency(int frequency) { } @VisibleForTesting - void insertRow(InternalRow row) throws IOException { - UnsafeRow unsafeRow = unsafeProjection.apply(row); + void insertRow(UnsafeRow row) throws IOException { final long prefix = prefixComputer.computePrefix(row); sorter.insertRecord( - unsafeRow.getBaseObject(), - unsafeRow.getBaseOffset(), - unsafeRow.getSizeInBytes(), + row.getBaseObject(), + row.getBaseOffset(), + row.getSizeInBytes(), prefix ); numRowsInserted++; @@ -113,7 +110,7 @@ private void cleanupResources() { } @VisibleForTesting - Iterator sort() throws IOException { + Iterator sort() throws IOException { try { final UnsafeSorterIterator sortedIterator = sorter.getSortedIterator(); if (!sortedIterator.hasNext()) { @@ -121,7 +118,7 @@ Iterator sort() throws IOException { // here in order to prevent memory leaks. 
cleanupResources(); } - return new AbstractScalaRowIterator() { + return new AbstractScalaRowIterator() { private final int numFields = schema.length(); private UnsafeRow row = new UnsafeRow(); @@ -132,7 +129,7 @@ public boolean hasNext() { } @Override - public InternalRow next() { + public UnsafeRow next() { try { sortedIterator.loadNext(); row.pointTo( @@ -164,11 +161,11 @@ public InternalRow next() { } - public Iterator sort(Iterator inputIterator) throws IOException { - while (inputIterator.hasNext()) { - insertRow(inputIterator.next()); - } - return sort(); + public Iterator sort(Iterator inputIterator) throws IOException { + while (inputIterator.hasNext()) { + insertRow(inputIterator.next()); + } + return sort(); } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala index 3f436c0eb893c..9fe877f10fa08 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala @@ -17,7 +17,10 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} +import org.apache.spark.sql.types._ +import org.apache.spark.util.collection.unsafe.sort.PrefixComparators.DoublePrefixComparator abstract sealed class SortDirection case object Ascending extends SortDirection @@ -37,4 +40,43 @@ case class SortOrder(child: Expression, direction: SortDirection) override def nullable: Boolean = child.nullable override def toString: String = s"$child ${if (direction == Ascending) "ASC" else "DESC"}" + + def isAscending: Boolean = direction == Ascending +} + +/** + * An expression to generate a 64-bit long prefix used in sorting. + */ +case class SortPrefix(child: SortOrder) extends UnaryExpression { + + override def eval(input: InternalRow): Any = throw new UnsupportedOperationException + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val childCode = child.child.gen(ctx) + val input = childCode.primitive + val DoublePrefixCmp = classOf[DoublePrefixComparator].getName + + val (nullValue: Long, prefixCode: String) = child.child.dataType match { + case BooleanType => + (Long.MinValue, s"$input ? 
1L : 0L") + case _: IntegralType => + (Long.MinValue, s"(long) $input") + case FloatType | DoubleType => + (DoublePrefixComparator.computePrefix(Double.NegativeInfinity), + s"$DoublePrefixCmp.computePrefix((double)$input)") + case StringType => (0L, s"$input.getPrefix()") + case _ => (0L, "0L") + } + + childCode.code + + s""" + |long ${ev.primitive} = ${nullValue}L; + |boolean ${ev.isNull} = false; + |if (!${childCode.isNull}) { + | ${ev.primitive} = $prefixCode; + |} + """.stripMargin + } + + override def dataType: DataType = LongType } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala index 2dee3542d6101..a2145b185ce90 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala @@ -18,10 +18,8 @@ package org.apache.spark.sql.execution -import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.SortOrder import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.collection.unsafe.sort.{PrefixComparators, PrefixComparator} @@ -37,61 +35,15 @@ object SortPrefixUtils { def getPrefixComparator(sortOrder: SortOrder): PrefixComparator = { sortOrder.dataType match { - case StringType => PrefixComparators.STRING - case BooleanType | ByteType | ShortType | IntegerType | LongType => PrefixComparators.INTEGRAL - case FloatType => PrefixComparators.FLOAT - case DoubleType => PrefixComparators.DOUBLE + case StringType if sortOrder.isAscending => PrefixComparators.STRING + case StringType if !sortOrder.isAscending => PrefixComparators.STRING_DESC + case BooleanType | ByteType | ShortType | IntegerType | LongType if sortOrder.isAscending => + PrefixComparators.LONG + case BooleanType | ByteType | ShortType | IntegerType | LongType if !sortOrder.isAscending => + PrefixComparators.LONG_DESC + case FloatType | DoubleType if sortOrder.isAscending => PrefixComparators.DOUBLE + case FloatType | DoubleType if !sortOrder.isAscending => PrefixComparators.DOUBLE_DESC case _ => NoOpPrefixComparator } } - - def getPrefixComputer(sortOrder: SortOrder): InternalRow => Long = { - sortOrder.dataType match { - case StringType => (row: InternalRow) => { - PrefixComparators.STRING.computePrefix(sortOrder.child.eval(row).asInstanceOf[UTF8String]) - } - case BooleanType => - (row: InternalRow) => { - val exprVal = sortOrder.child.eval(row) - if (exprVal == null) PrefixComparators.INTEGRAL.NULL_PREFIX - else if (sortOrder.child.eval(row).asInstanceOf[Boolean]) 1 - else 0 - } - case ByteType => - (row: InternalRow) => { - val exprVal = sortOrder.child.eval(row) - if (exprVal == null) PrefixComparators.INTEGRAL.NULL_PREFIX - else sortOrder.child.eval(row).asInstanceOf[Byte] - } - case ShortType => - (row: InternalRow) => { - val exprVal = sortOrder.child.eval(row) - if (exprVal == null) PrefixComparators.INTEGRAL.NULL_PREFIX - else sortOrder.child.eval(row).asInstanceOf[Short] - } - case IntegerType => - (row: InternalRow) => { - val exprVal = sortOrder.child.eval(row) - if (exprVal == null) PrefixComparators.INTEGRAL.NULL_PREFIX - else sortOrder.child.eval(row).asInstanceOf[Int] - } - case LongType => - (row: InternalRow) => { - val exprVal = sortOrder.child.eval(row) - if (exprVal == null) PrefixComparators.INTEGRAL.NULL_PREFIX - else sortOrder.child.eval(row).asInstanceOf[Long] - } - case FloatType => 
(row: InternalRow) => { - val exprVal = sortOrder.child.eval(row) - if (exprVal == null) PrefixComparators.FLOAT.NULL_PREFIX - else PrefixComparators.FLOAT.computePrefix(sortOrder.child.eval(row).asInstanceOf[Float]) - } - case DoubleType => (row: InternalRow) => { - val exprVal = sortOrder.child.eval(row) - if (exprVal == null) PrefixComparators.DOUBLE.NULL_PREFIX - else PrefixComparators.DOUBLE.computePrefix(sortOrder.child.eval(row).asInstanceOf[Double]) - } - case _ => (row: InternalRow) => 0L - } - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 52a9b02d373c7..03d24a88d4ecd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -341,8 +341,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { */ def getSortOperator(sortExprs: Seq[SortOrder], global: Boolean, child: SparkPlan): SparkPlan = { if (sqlContext.conf.unsafeEnabled && sqlContext.conf.codegenEnabled && - UnsafeExternalSort.supportsSchema(child.schema)) { - execution.UnsafeExternalSort(sortExprs, global, child) + TungstenSort.supportsSchema(child.schema)) { + execution.TungstenSort(sortExprs, global, child) } else if (sqlContext.conf.externalSortEnabled) { execution.ExternalSort(sortExprs, global, child) } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index 26dbc911e9521..f88a45f48aee9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -229,7 +229,7 @@ private[joins] final class UnsafeHashedRelation( // write all the values as single byte array var totalSize = 0L var i = 0 - while (i < values.size) { + while (i < values.length) { totalSize += values(i).getSizeInBytes + 4 + 4 i += 1 } @@ -240,7 +240,7 @@ private[joins] final class UnsafeHashedRelation( out.writeInt(totalSize.toInt) out.write(key.getBytes) i = 0 - while (i < values.size) { + while (i < values.length) { // [num of fields] [num of bytes] [row bytes] // write the integer in native order, so they can be read by UNSAFE.getInt() if (ByteOrder.nativeOrder() == ByteOrder.BIG_ENDIAN) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala index f82208868c3e3..6d903ab23c57f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala @@ -17,16 +17,14 @@ package org.apache.spark.sql.execution -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors._ -import org.apache.spark.sql.catalyst.expressions.{Descending, BindReferences, Attribute, SortOrder} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.{UnspecifiedDistribution, OrderedDistribution, Distribution} import org.apache.spark.sql.types.StructType import org.apache.spark.util.CompletionIterator import org.apache.spark.util.collection.ExternalSorter -import org.apache.spark.util.collection.unsafe.sort.PrefixComparator 
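
The double prefix handling above is worth spelling out: DoublePrefixComparator.computePrefix stores the raw IEEE-754 bit pattern of the key in the 8-byte prefix, and compare() decodes both prefixes back to doubles before ordering them, so NaN handling is delegated to Utils.nanSafeCompareDoubles rather than to a signed comparison of the raw bits. A standalone Scala sketch of that behaviour — nanSafeCompare below is an assumed stand-in that only has to treat all NaNs as equal to one another and larger than any other value; the real Utils method may differ in detail:

    def computePrefix(value: Double): Long = java.lang.Double.doubleToLongBits(value)

    // Assumed semantics of a NaN-safe comparison: NaN == NaN, NaN > everything else.
    def nanSafeCompare(a: Double, b: Double): Int =
      if (a.isNaN && b.isNaN) 0
      else if (a.isNaN) 1
      else if (b.isNaN) -1
      else java.lang.Double.compare(a, b)

    def comparePrefixes(aPrefix: Long, bPrefix: Long): Int =
      nanSafeCompare(java.lang.Double.longBitsToDouble(aPrefix),
                     java.lang.Double.longBitsToDouble(bPrefix))

    // doubleToLongBits canonicalizes every NaN payload to one bit pattern, which is
    // why the two NaNs in the test above produce identical prefixes ...
    val nan1 = java.lang.Double.longBitsToDouble(0x7ff0000000000001L)
    val nan2 = java.lang.Double.longBitsToDouble(0x7fffffffffffffffL)
    assert(computePrefix(nan1) == computePrefix(nan2))
    // ... and why a NaN prefix compares greater than the prefix of Double.MaxValue.
    assert(comparePrefixes(computePrefix(nan1), computePrefix(Double.MaxValue)) == 1)

This is also why the separate float comparator can be dropped: widening a float to double preserves its ordering exactly, so FloatType and DoubleType columns can share the same prefix encoding.
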
//////////////////////////////////////////////////////////////////////////////////////////////////// // This file defines various sort operators. @@ -97,59 +95,53 @@ case class ExternalSort( * @param testSpillFrequency Method for configuring periodic spilling in unit tests. If set, will * spill every `frequency` records. */ -case class UnsafeExternalSort( +case class TungstenSort( sortOrder: Seq[SortOrder], global: Boolean, child: SparkPlan, testSpillFrequency: Int = 0) extends UnaryNode { - private[this] val schema: StructType = child.schema + override def outputsUnsafeRows: Boolean = true + override def canProcessUnsafeRows: Boolean = true + override def canProcessSafeRows: Boolean = false + + override def output: Seq[Attribute] = child.output + + override def outputOrdering: Seq[SortOrder] = sortOrder override def requiredChildDistribution: Seq[Distribution] = if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil - protected override def doExecute(): RDD[InternalRow] = attachTree(this, "sort") { - assert(codegenEnabled, "UnsafeExternalSort requires code generation to be enabled") - def doSort(iterator: Iterator[InternalRow]): Iterator[InternalRow] = { - val ordering = newOrdering(sortOrder, child.output) - val boundSortExpression = BindReferences.bindReference(sortOrder.head, child.output) - // Hack until we generate separate comparator implementations for ascending vs. descending - // (or choose to codegen them): - val prefixComparator = { - val comp = SortPrefixUtils.getPrefixComparator(boundSortExpression) - if (sortOrder.head.direction == Descending) { - new PrefixComparator { - override def compare(p1: Long, p2: Long): Int = -1 * comp.compare(p1, p2) - } - } else { - comp - } - } - val prefixComputer = { - val prefixComputer = SortPrefixUtils.getPrefixComputer(boundSortExpression) - new UnsafeExternalRowSorter.PrefixComputer { - override def computePrefix(row: InternalRow): Long = prefixComputer(row) + protected override def doExecute(): RDD[InternalRow] = { + val schema = child.schema + val childOutput = child.output + child.execute().mapPartitions({ iter => + val ordering = newOrdering(sortOrder, childOutput) + + // The comparator for comparing prefix + val boundSortExpression = BindReferences.bindReference(sortOrder.head, childOutput) + val prefixComparator = SortPrefixUtils.getPrefixComparator(boundSortExpression) + + // The generator for prefix + val prefixProjection = UnsafeProjection.create(Seq(SortPrefix(boundSortExpression))) + val prefixComputer = new UnsafeExternalRowSorter.PrefixComputer { + override def computePrefix(row: InternalRow): Long = { + prefixProjection.apply(row).getLong(0) } } + val sorter = new UnsafeExternalRowSorter(schema, ordering, prefixComparator, prefixComputer) if (testSpillFrequency > 0) { sorter.setTestSpillFrequency(testSpillFrequency) } - sorter.sort(iterator) - } - child.execute().mapPartitions(doSort, preservesPartitioning = true) + sorter.sort(iter.asInstanceOf[Iterator[UnsafeRow]]) + }, preservesPartitioning = true) } - override def output: Seq[Attribute] = child.output - - override def outputOrdering: Seq[SortOrder] = sortOrder - - override def outputsUnsafeRows: Boolean = true } -@DeveloperApi -object UnsafeExternalSort { +object TungstenSort { /** * Return true if UnsafeExternalSort can sort rows with the given schema, false otherwise. 
*/ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala index 7b75f755918c1..707cd9c6d939b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala @@ -18,8 +18,7 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.CatalystTypeConverters -import org.apache.spark.sql.catalyst.expressions.IsNull +import org.apache.spark.sql.catalyst.expressions.{Literal, IsNull} import org.apache.spark.sql.test.TestSQLContext class RowFormatConvertersSuite extends SparkPlanTest { @@ -31,7 +30,7 @@ class RowFormatConvertersSuite extends SparkPlanTest { private val outputsSafe = ExternalSort(Nil, false, PhysicalRDD(Seq.empty, null)) assert(!outputsSafe.outputsUnsafeRows) - private val outputsUnsafe = UnsafeExternalSort(Nil, false, PhysicalRDD(Seq.empty, null)) + private val outputsUnsafe = TungstenSort(Nil, false, PhysicalRDD(Seq.empty, null)) assert(outputsUnsafe.outputsUnsafeRows) test("planner should insert unsafe->safe conversions when required") { @@ -41,14 +40,14 @@ class RowFormatConvertersSuite extends SparkPlanTest { } test("filter can process unsafe rows") { - val plan = Filter(IsNull(null), outputsUnsafe) + val plan = Filter(IsNull(IsNull(Literal(1))), outputsUnsafe) val preparedPlan = TestSQLContext.prepareForExecution.execute(plan) - assert(getConverters(preparedPlan).isEmpty) + assert(getConverters(preparedPlan).size === 1) assert(preparedPlan.outputsUnsafeRows) } test("filter can process safe rows") { - val plan = Filter(IsNull(null), outputsSafe) + val plan = Filter(IsNull(IsNull(Literal(1))), outputsSafe) val preparedPlan = TestSQLContext.prepareForExecution.execute(plan) assert(getConverters(preparedPlan).isEmpty) assert(!preparedPlan.outputsUnsafeRows) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala similarity index 87% rename from sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala index 138636b0c65b8..450963547c798 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.types._ -class UnsafeExternalSortSuite extends SparkPlanTest with BeforeAndAfterAll { +class TungstenSortSuite extends SparkPlanTest with BeforeAndAfterAll { override def beforeAll(): Unit = { TestSQLContext.conf.setConf(SQLConf.CODEGEN_ENABLED, true) @@ -39,7 +39,7 @@ class UnsafeExternalSortSuite extends SparkPlanTest with BeforeAndAfterAll { test("sort followed by limit") { checkThatPlansAgree( (1 to 100).map(v => Tuple1(v)).toDF("a"), - (child: SparkPlan) => Limit(10, UnsafeExternalSort('a.asc :: Nil, true, child)), + (child: SparkPlan) => Limit(10, TungstenSort('a.asc :: Nil, true, child)), (child: SparkPlan) => Limit(10, Sort('a.asc :: Nil, global = true, child)), sortAnswers = false ) @@ -50,7 +50,7 @@ class UnsafeExternalSortSuite extends SparkPlanTest with BeforeAndAfterAll 
{ val stringLength = 1024 * 1024 * 2 checkThatPlansAgree( Seq(Tuple1("a" * stringLength), Tuple1("b" * stringLength)).toDF("a").repartition(1), - UnsafeExternalSort(sortOrder, global = true, _: SparkPlan, testSpillFrequency = 1), + TungstenSort(sortOrder, global = true, _: SparkPlan, testSpillFrequency = 1), Sort(sortOrder, global = true, _: SparkPlan), sortAnswers = false ) @@ -70,11 +70,11 @@ class UnsafeExternalSortSuite extends SparkPlanTest with BeforeAndAfterAll { TestSQLContext.sparkContext.parallelize(Random.shuffle(inputData).map(v => Row(v))), StructType(StructField("a", dataType, nullable = true) :: Nil) ) - assert(UnsafeExternalSort.supportsSchema(inputDf.schema)) + assert(TungstenSort.supportsSchema(inputDf.schema)) checkThatPlansAgree( inputDf, plan => ConvertToSafe( - UnsafeExternalSort(sortOrder, global = true, plan: SparkPlan, testSpillFrequency = 23)), + TungstenSort(sortOrder, global = true, plan: SparkPlan, testSpillFrequency = 23)), Sort(sortOrder, global = true, _: SparkPlan), sortAnswers = false ) From 0b1a464b6e061580a75b99a91b042069d76bbbfd Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 30 Jul 2015 17:18:32 -0700 Subject: [PATCH 47/50] [SPARK-9425] [SQL] support DecimalType in UnsafeRow This PR brings the support of DecimalType in UnsafeRow, for precision <= 18, it's settable, otherwise it's not settable. Author: Davies Liu Closes #7758 from davies/unsafe_decimal and squashes the following commits: 478b1ba [Davies Liu] address comments 536314c [Davies Liu] Merge branch 'master' of github.com:apache/spark into unsafe_decimal 7c2e77a [Davies Liu] fix JoinedRow 76d6fa4 [Davies Liu] fix tests 99d3151 [Davies Liu] Merge branch 'master' of github.com:apache/spark into unsafe_decimal d49c6ae [Davies Liu] support DecimalType in UnsafeRow --- .../expressions/SpecializedGetters.java | 2 +- .../UnsafeFixedWidthAggregationMap.java | 22 ++-- .../sql/catalyst/expressions/UnsafeRow.java | 53 +++++--- .../expressions/UnsafeRowWriters.java | 42 +++++++ .../sql/catalyst/CatalystTypeConverters.scala | 9 +- .../spark/sql/catalyst/InternalRow.scala | 4 +- .../sql/catalyst/expressions/Projection.scala | 7 +- .../expressions/codegen/CodeGenerator.scala | 9 +- .../codegen/GenerateUnsafeProjection.scala | 115 ++++++++++-------- .../spark/sql/catalyst/expressions/rows.scala | 3 +- .../org/apache/spark/sql/types/Decimal.scala | 6 +- .../spark/sql/types/GenericArrayData.scala | 2 +- .../sql/catalyst/expressions/CastSuite.scala | 5 +- .../expressions/DateExpressionsSuite.scala | 2 +- .../UnsafeFixedWidthAggregationMapSuite.scala | 8 +- .../expressions/UnsafeRowConverterSuite.scala | 17 +-- .../spark/sql/columnar/ColumnBuilder.scala | 2 +- .../spark/sql/columnar/ColumnStats.scala | 4 +- .../spark/sql/columnar/ColumnType.scala | 2 +- .../sql/execution/GeneratedAggregate.scala | 2 +- .../sql/execution/SparkSqlSerializer2.scala | 2 +- .../sql/parquet/ParquetTableSupport.scala | 4 +- .../spark/sql/columnar/ColumnStatsSuite.scala | 40 +++++- 23 files changed, 237 insertions(+), 125 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGetters.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGetters.java index f7cea13688876..e3d3ba7a9ccc0 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGetters.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGetters.java @@ -41,7 +41,7 @@ public interface SpecializedGetters { 
double getDouble(int ordinal); - Decimal getDecimal(int ordinal); + Decimal getDecimal(int ordinal, int precision, int scale); UTF8String getUTF8String(int ordinal); diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java index 03f4c3ed8e6bb..f3b462778dc10 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java @@ -20,6 +20,8 @@ import java.util.Iterator; import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.sql.types.DecimalType; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; import org.apache.spark.unsafe.PlatformDependent; @@ -61,26 +63,18 @@ public final class UnsafeFixedWidthAggregationMap { private final boolean enablePerfMetrics; - /** - * @return true if UnsafeFixedWidthAggregationMap supports grouping keys with the given schema, - * false otherwise. - */ - public static boolean supportsGroupKeySchema(StructType schema) { - for (StructField field: schema.fields()) { - if (!UnsafeRow.readableFieldTypes.contains(field.dataType())) { - return false; - } - } - return true; - } - /** * @return true if UnsafeFixedWidthAggregationMap supports aggregation buffers with the given * schema, false otherwise. */ public static boolean supportsAggregationBufferSchema(StructType schema) { for (StructField field: schema.fields()) { - if (!UnsafeRow.settableFieldTypes.contains(field.dataType())) { + if (field.dataType() instanceof DecimalType) { + DecimalType dt = (DecimalType) field.dataType(); + if (dt.precision() > Decimal.MAX_LONG_DIGITS()) { + return false; + } + } else if (!UnsafeRow.settableFieldTypes.contains(field.dataType())) { return false; } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index 6d684bac37573..e7088edced1a1 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -19,6 +19,8 @@ import java.io.IOException; import java.io.OutputStream; +import java.math.BigDecimal; +import java.math.BigInteger; import java.util.Arrays; import java.util.Collections; import java.util.HashSet; @@ -65,12 +67,7 @@ public static int calculateBitSetWidthInBytes(int numFields) { */ public static final Set settableFieldTypes; - /** - * Fields types can be read(but not set (e.g. set() will throw UnsupportedOperationException). 
- */ - public static final Set readableFieldTypes; - - // TODO: support DecimalType + // DecimalType(precision <= 18) is settable static { settableFieldTypes = Collections.unmodifiableSet( new HashSet<>( @@ -86,16 +83,6 @@ public static int calculateBitSetWidthInBytes(int numFields) { DateType, TimestampType }))); - - // We support get() on a superset of the types for which we support set(): - final Set _readableFieldTypes = new HashSet<>( - Arrays.asList(new DataType[]{ - StringType, - BinaryType, - CalendarIntervalType - })); - _readableFieldTypes.addAll(settableFieldTypes); - readableFieldTypes = Collections.unmodifiableSet(_readableFieldTypes); } ////////////////////////////////////////////////////////////////////////////// @@ -232,6 +219,21 @@ public void setFloat(int ordinal, float value) { PlatformDependent.UNSAFE.putFloat(baseObject, getFieldOffset(ordinal), value); } + @Override + public void setDecimal(int ordinal, Decimal value, int precision) { + assertIndexIsValid(ordinal); + if (value == null) { + setNullAt(ordinal); + } else { + if (precision <= Decimal.MAX_LONG_DIGITS()) { + setLong(ordinal, value.toUnscaledLong()); + } else { + // TODO(davies): support update decimal (hold a bounded space even it's null) + throw new UnsupportedOperationException(); + } + } + } + @Override public Object get(int ordinal) { throw new UnsupportedOperationException(); @@ -256,7 +258,8 @@ public Object get(int ordinal, DataType dataType) { } else if (dataType instanceof DoubleType) { return getDouble(ordinal); } else if (dataType instanceof DecimalType) { - return getDecimal(ordinal); + DecimalType dt = (DecimalType) dataType; + return getDecimal(ordinal, dt.precision(), dt.scale()); } else if (dataType instanceof DateType) { return getInt(ordinal); } else if (dataType instanceof TimestampType) { @@ -322,6 +325,22 @@ public double getDouble(int ordinal) { return PlatformDependent.UNSAFE.getDouble(baseObject, getFieldOffset(ordinal)); } + @Override + public Decimal getDecimal(int ordinal, int precision, int scale) { + assertIndexIsValid(ordinal); + if (isNullAt(ordinal)) { + return null; + } + if (precision <= Decimal.MAX_LONG_DIGITS()) { + return Decimal.apply(getLong(ordinal), precision, scale); + } else { + byte[] bytes = getBinary(ordinal); + BigInteger bigInteger = new BigInteger(bytes); + BigDecimal javaDecimal = new BigDecimal(bigInteger, scale); + return Decimal.apply(new scala.math.BigDecimal(javaDecimal), precision, scale); + } + } + @Override public UTF8String getUTF8String(int ordinal) { assertIndexIsValid(ordinal); diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java index c3259e21c4a78..f43a285cd6cad 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions; import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.types.Decimal; import org.apache.spark.unsafe.PlatformDependent; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.types.ByteArray; @@ -30,6 +31,47 @@ */ public class UnsafeRowWriters { + /** Writer for Decimal with precision under 18. 
*/ + public static class CompactDecimalWriter { + + public static int getSize(Decimal input) { + return 0; + } + + public static int write(UnsafeRow target, int ordinal, int cursor, Decimal input) { + target.setLong(ordinal, input.toUnscaledLong()); + return 0; + } + } + + /** Writer for Decimal with precision larger than 18. */ + public static class DecimalWriter { + + public static int getSize(Decimal input) { + // bounded size + return 16; + } + + public static int write(UnsafeRow target, int ordinal, int cursor, Decimal input) { + final long offset = target.getBaseOffset() + cursor; + final byte[] bytes = input.toJavaBigDecimal().unscaledValue().toByteArray(); + final int numBytes = bytes.length; + assert(numBytes <= 16); + + // zero-out the bytes + PlatformDependent.UNSAFE.putLong(target.getBaseObject(), offset, 0L); + PlatformDependent.UNSAFE.putLong(target.getBaseObject(), offset + 8, 0L); + + // Write the bytes to the variable length portion. + PlatformDependent.copyMemory(bytes, PlatformDependent.BYTE_ARRAY_OFFSET, + target.getBaseObject(), offset, numBytes); + + // Set the fixed length portion. + target.setLong(ordinal, (((long) cursor) << 32) | ((long) numBytes)); + return 16; + } + } + /** Writer for UTF8String. */ public static class UTF8StringWriter { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala index 22452c0f201ef..7ca20fe97fbef 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala @@ -68,7 +68,7 @@ object CatalystTypeConverters { case StringType => StringConverter case DateType => DateConverter case TimestampType => TimestampConverter - case dt: DecimalType => BigDecimalConverter + case dt: DecimalType => new DecimalConverter(dt) case BooleanType => BooleanConverter case ByteType => ByteConverter case ShortType => ShortConverter @@ -306,7 +306,8 @@ object CatalystTypeConverters { DateTimeUtils.toJavaTimestamp(row.getLong(column)) } - private object BigDecimalConverter extends CatalystTypeConverter[Any, JavaBigDecimal, Decimal] { + private class DecimalConverter(dataType: DecimalType) + extends CatalystTypeConverter[Any, JavaBigDecimal, Decimal] { override def toCatalystImpl(scalaValue: Any): Decimal = scalaValue match { case d: BigDecimal => Decimal(d) case d: JavaBigDecimal => Decimal(d) @@ -314,9 +315,11 @@ object CatalystTypeConverters { } override def toScala(catalystValue: Decimal): JavaBigDecimal = catalystValue.toJavaBigDecimal override def toScalaImpl(row: InternalRow, column: Int): JavaBigDecimal = - row.getDecimal(column).toJavaBigDecimal + row.getDecimal(column, dataType.precision, dataType.scale).toJavaBigDecimal } + private object BigDecimalConverter extends DecimalConverter(DecimalType.SYSTEM_DEFAULT) + private abstract class PrimitiveConverter[T] extends CatalystTypeConverter[T, Any, Any] { final override def toScala(catalystValue: Any): Any = catalystValue final override def toCatalystImpl(scalaValue: T): Any = scalaValue diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala index 486ba036548c8..b19bf4386b0ba 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala @@ -58,8 +58,8 @@ abstract class InternalRow extends Serializable with SpecializedGetters { override def getBinary(ordinal: Int): Array[Byte] = getAs[Array[Byte]](ordinal, BinaryType) - override def getDecimal(ordinal: Int): Decimal = - getAs[Decimal](ordinal, DecimalType.SYSTEM_DEFAULT) + override def getDecimal(ordinal: Int, precision: Int, scale: Int): Decimal = + getAs[Decimal](ordinal, DecimalType(precision, scale)) override def getInterval(ordinal: Int): CalendarInterval = getAs[CalendarInterval](ordinal, CalendarIntervalType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index b3beb7e28f208..7c7664e4c1a91 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, GenerateMutableProjection} -import org.apache.spark.sql.types.{StructType, DataType} +import org.apache.spark.sql.types.{Decimal, StructType, DataType} import org.apache.spark.unsafe.types.UTF8String /** @@ -225,6 +225,11 @@ class JoinedRow extends InternalRow { override def getFloat(i: Int): Float = if (i < row1.numFields) row1.getFloat(i) else row2.getFloat(i - row1.numFields) + override def getDecimal(i: Int, precision: Int, scale: Int): Decimal = { + if (i < row1.numFields) row1.getDecimal(i, precision, scale) + else row2.getDecimal(i - row1.numFields, precision, scale) + } + override def getStruct(i: Int, numFields: Int): InternalRow = { if (i < row1.numFields) { row1.getStruct(i, numFields) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index c39e0df6fae2a..60e2863f7bbb0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -106,6 +106,7 @@ class CodeGenContext { val jt = javaType(dataType) dataType match { case _ if isPrimitiveType(jt) => s"$getter.get${primitiveTypeName(jt)}($ordinal)" + case t: DecimalType => s"$getter.getDecimal($ordinal, ${t.precision}, ${t.scale})" case StringType => s"$getter.getUTF8String($ordinal)" case BinaryType => s"$getter.getBinary($ordinal)" case CalendarIntervalType => s"$getter.getInterval($ordinal)" @@ -120,10 +121,10 @@ class CodeGenContext { */ def setColumn(row: String, dataType: DataType, ordinal: Int, value: String): String = { val jt = javaType(dataType) - if (isPrimitiveType(jt)) { - s"$row.set${primitiveTypeName(jt)}($ordinal, $value)" - } else { - s"$row.update($ordinal, $value)" + dataType match { + case _ if isPrimitiveType(jt) => s"$row.set${primitiveTypeName(jt)}($ordinal, $value)" + case t: DecimalType => s"$row.setDecimal($ordinal, $value, ${t.precision})" + case _ => s"$row.update($ordinal, $value)" } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index a662357fb6cf9..1d223986d9441 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -35,6 +35,8 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro private val BinaryWriter = classOf[UnsafeRowWriters.BinaryWriter].getName private val IntervalWriter = classOf[UnsafeRowWriters.IntervalWriter].getName private val StructWriter = classOf[UnsafeRowWriters.StructWriter].getName + private val CompactDecimalWriter = classOf[UnsafeRowWriters.CompactDecimalWriter].getName + private val DecimalWriter = classOf[UnsafeRowWriters.DecimalWriter].getName /** Returns true iff we support this data type. */ def canSupport(dataType: DataType): Boolean = dataType match { @@ -42,9 +44,64 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro case _: CalendarIntervalType => true case t: StructType => t.toSeq.forall(field => canSupport(field.dataType)) case NullType => true + case t: DecimalType => true case _ => false } + def genAdditionalSize(dt: DataType, ev: GeneratedExpressionCode): String = dt match { + case t: DecimalType if t.precision > Decimal.MAX_LONG_DIGITS => + s" + (${ev.isNull} ? 0 : $DecimalWriter.getSize(${ev.primitive}))" + case StringType => + s" + (${ev.isNull} ? 0 : $StringWriter.getSize(${ev.primitive}))" + case BinaryType => + s" + (${ev.isNull} ? 0 : $BinaryWriter.getSize(${ev.primitive}))" + case CalendarIntervalType => + s" + (${ev.isNull} ? 0 : 16)" + case _: StructType => + s" + (${ev.isNull} ? 0 : $StructWriter.getSize(${ev.primitive}))" + case _ => "" + } + + def genFieldWriter( + ctx: CodeGenContext, + fieldType: DataType, + ev: GeneratedExpressionCode, + primitive: String, + index: Int, + cursor: String): String = fieldType match { + case _ if ctx.isPrimitiveType(fieldType) => + s"${ctx.setColumn(primitive, fieldType, index, ev.primitive)}" + case t: DecimalType if t.precision <= Decimal.MAX_LONG_DIGITS => + s""" + // make sure Decimal object has the same scale as DecimalType + if (${ev.primitive}.changePrecision(${t.precision}, ${t.scale})) { + $CompactDecimalWriter.write($primitive, $index, $cursor, ${ev.primitive}); + } else { + $primitive.setNullAt($index); + } + """ + case t: DecimalType if t.precision > Decimal.MAX_LONG_DIGITS => + s""" + // make sure Decimal object has the same scale as DecimalType + if (${ev.primitive}.changePrecision(${t.precision}, ${t.scale})) { + $cursor += $DecimalWriter.write($primitive, $index, $cursor, ${ev.primitive}); + } else { + $primitive.setNullAt($index); + } + """ + case StringType => + s"$cursor += $StringWriter.write($primitive, $index, $cursor, ${ev.primitive})" + case BinaryType => + s"$cursor += $BinaryWriter.write($primitive, $index, $cursor, ${ev.primitive})" + case CalendarIntervalType => + s"$cursor += $IntervalWriter.write($primitive, $index, $cursor, ${ev.primitive})" + case t: StructType => + s"$cursor += $StructWriter.write($primitive, $index, $cursor, ${ev.primitive})" + case NullType => "" + case _ => + throw new UnsupportedOperationException(s"Not supported DataType: $fieldType") + } + /** * Generates the code to create an [[UnsafeRow]] object based on the input expressions. 
* @param ctx context for code generation @@ -69,36 +126,12 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val allExprs = exprs.map(_.code).mkString("\n") val fixedSize = 8 * exprs.length + UnsafeRow.calculateBitSetWidthInBytes(exprs.length) - val additionalSize = expressions.zipWithIndex.map { case (e, i) => - e.dataType match { - case StringType => - s" + (${exprs(i).isNull} ? 0 : $StringWriter.getSize(${exprs(i).primitive}))" - case BinaryType => - s" + (${exprs(i).isNull} ? 0 : $BinaryWriter.getSize(${exprs(i).primitive}))" - case CalendarIntervalType => - s" + (${exprs(i).isNull} ? 0 : 16)" - case _: StructType => - s" + (${exprs(i).isNull} ? 0 : $StructWriter.getSize(${exprs(i).primitive}))" - case _ => "" - } + val additionalSize = expressions.zipWithIndex.map { + case (e, i) => genAdditionalSize(e.dataType, exprs(i)) }.mkString("") val writers = expressions.zipWithIndex.map { case (e, i) => - val update = e.dataType match { - case dt if ctx.isPrimitiveType(dt) => - s"${ctx.setColumn(ret, dt, i, exprs(i).primitive)}" - case StringType => - s"$cursor += $StringWriter.write($ret, $i, $cursor, ${exprs(i).primitive})" - case BinaryType => - s"$cursor += $BinaryWriter.write($ret, $i, $cursor, ${exprs(i).primitive})" - case CalendarIntervalType => - s"$cursor += $IntervalWriter.write($ret, $i, $cursor, ${exprs(i).primitive})" - case t: StructType => - s"$cursor += $StructWriter.write($ret, $i, $cursor, ${exprs(i).primitive})" - case NullType => "" - case _ => - throw new UnsupportedOperationException(s"Not supported DataType: ${e.dataType}") - } + val update = genFieldWriter(ctx, e.dataType, exprs(i), ret, i, cursor) s"""if (${exprs(i).isNull}) { $ret.setNullAt($i); } else { @@ -168,35 +201,11 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val fixedSize = 8 * exprs.length + UnsafeRow.calculateBitSetWidthInBytes(exprs.length) val additionalSize = schema.toSeq.map(_.dataType).zip(exprs).map { case (dt, ev) => - dt match { - case StringType => - s" + (${ev.isNull} ? 0 : $StringWriter.getSize(${ev.primitive}))" - case BinaryType => - s" + (${ev.isNull} ? 0 : $BinaryWriter.getSize(${ev.primitive}))" - case CalendarIntervalType => - s" + (${ev.isNull} ? 0 : 16)" - case _: StructType => - s" + (${ev.isNull} ? 
0 : $StructWriter.getSize(${ev.primitive}))" - case _ => "" - } + genAdditionalSize(dt, ev) }.mkString("") val writers = schema.toSeq.map(_.dataType).zip(exprs).zipWithIndex.map { case ((dt, ev), i) => - val update = dt match { - case _ if ctx.isPrimitiveType(dt) => - s"${ctx.setColumn(primitive, dt, i, exprs(i).primitive)}" - case StringType => - s"$cursor += $StringWriter.write($primitive, $i, $cursor, ${exprs(i).primitive})" - case BinaryType => - s"$cursor += $BinaryWriter.write($primitive, $i, $cursor, ${exprs(i).primitive})" - case CalendarIntervalType => - s"$cursor += $IntervalWriter.write($primitive, $i, $cursor, ${exprs(i).primitive})" - case t: StructType => - s"$cursor += $StructWriter.write($primitive, $i, $cursor, ${exprs(i).primitive})" - case NullType => "" - case _ => - throw new UnsupportedOperationException(s"Not supported DataType: $dt") - } + val update = genFieldWriter(ctx, dt, ev, primitive, i, cursor) s""" if (${exprs(i).isNull}) { $primitive.setNullAt($i); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index b7c4ece4a16fe..df6ea586c87ba 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.types.{DataType, StructType, AtomicType} +import org.apache.spark.sql.types.{Decimal, DataType, StructType, AtomicType} import org.apache.spark.unsafe.types.UTF8String /** @@ -39,6 +39,7 @@ abstract class MutableRow extends InternalRow { def setShort(i: Int, value: Short): Unit = { update(i, value) } def setByte(i: Int, value: Byte): Unit = { update(i, value) } def setFloat(i: Int, value: Float): Unit = { update(i, value) } + def setDecimal(i: Int, value: Decimal, precision: Int) { update(i, value) } def setString(i: Int, value: String): Unit = { update(i, UTF8String.fromString(value)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index bc689810bc292..c0155eeb450a6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -188,6 +188,10 @@ final class Decimal extends Ordered[Decimal] with Serializable { * @return true if successful, false if overflow would occur */ def changePrecision(precision: Int, scale: Int): Boolean = { + // fast path for UnsafeProjection + if (precision == this.precision && scale == this.scale) { + return true + } // First, update our longVal if we can, or transfer over to using a BigDecimal if (decimalVal.eq(null)) { if (scale < _scale) { @@ -224,7 +228,7 @@ final class Decimal extends Ordered[Decimal] with Serializable { decimalVal = newVal } else { // We're still using Longs, but we should check whether we match the new precision - val p = POW_10(math.min(_precision, MAX_LONG_DIGITS)) + val p = POW_10(math.min(precision, MAX_LONG_DIGITS)) if (longVal <= -p || longVal >= p) { // Note that we shouldn't have been able to fix this by switching to BigDecimal return false diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala index 7992ba947c069..35ace673fb3da 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala @@ -43,7 +43,7 @@ class GenericArrayData(array: Array[Any]) extends ArrayData { override def getDouble(ordinal: Int): Double = getAs(ordinal) - override def getDecimal(ordinal: Int): Decimal = getAs(ordinal) + override def getDecimal(ordinal: Int, precision: Int, scale: Int): Decimal = getAs(ordinal) override def getUTF8String(ordinal: Int): UTF8String = getAs(ordinal) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index 4f35b653d73c0..1ad70733eae03 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -242,10 +242,9 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(cast(123L, DecimalType.USER_DEFAULT), Decimal(123)) checkEvaluation(cast(123L, DecimalType(3, 0)), Decimal(123)) - checkEvaluation(cast(123L, DecimalType(3, 1)), Decimal(123.0)) + checkEvaluation(cast(123L, DecimalType(3, 1)), null) - // TODO: Fix the following bug and re-enable it. - // checkEvaluation(cast(123L, DecimalType(2, 0)), null) + checkEvaluation(cast(123L, DecimalType(2, 0)), null) } test("cast from boolean") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala index fd1d6c1d25497..887e43621a941 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions -import java.sql.{Timestamp, Date} +import java.sql.{Date, Timestamp} import java.text.SimpleDateFormat import java.util.Calendar diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala index 6a907290f2dbe..c6b4c729de2f9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala @@ -55,13 +55,13 @@ class UnsafeFixedWidthAggregationMapSuite } test("supported schemas") { + assert(supportsAggregationBufferSchema( + StructType(StructField("x", DecimalType.USER_DEFAULT) :: Nil))) + assert(!supportsAggregationBufferSchema( + StructType(StructField("x", DecimalType.SYSTEM_DEFAULT) :: Nil))) assert(!supportsAggregationBufferSchema(StructType(StructField("x", StringType) :: Nil))) - assert(supportsGroupKeySchema(StructType(StructField("x", StringType) :: Nil))) - assert( !supportsAggregationBufferSchema(StructType(StructField("x", ArrayType(IntegerType)) :: Nil))) - assert( - !supportsGroupKeySchema(StructType(StructField("x", ArrayType(IntegerType)) :: Nil))) } test("empty map") { diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala index b7bc17f89e82f..a0e1701339ea7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala @@ -46,7 +46,6 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(unsafeRow.getLong(1) === 1) assert(unsafeRow.getInt(2) === 2) - // We can copy UnsafeRows as long as they don't reference ObjectPools val unsafeRowCopy = unsafeRow.copy() assert(unsafeRowCopy.getLong(0) === 0) assert(unsafeRowCopy.getLong(1) === 1) @@ -122,8 +121,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { FloatType, DoubleType, StringType, - BinaryType - // DecimalType.Default, + BinaryType, + DecimalType.USER_DEFAULT // ArrayType(IntegerType) ) val converter = UnsafeProjection.create(fieldTypes) @@ -150,7 +149,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(createdFromNull.getDouble(7) === 0.0d) assert(createdFromNull.getUTF8String(8) === null) assert(createdFromNull.getBinary(9) === null) - // assert(createdFromNull.get(10) === null) + assert(createdFromNull.getDecimal(10, 10, 0) === null) // assert(createdFromNull.get(11) === null) // If we have an UnsafeRow with columns that are initially non-null and we null out those @@ -168,7 +167,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { r.setDouble(7, 700) r.update(8, UTF8String.fromString("hello")) r.update(9, "world".getBytes) - // r.update(10, Decimal(10)) + r.setDecimal(10, Decimal(10), 10) // r.update(11, Array(11)) r } @@ -184,7 +183,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(setToNullAfterCreation.getDouble(7) === rowWithNoNullColumns.getDouble(7)) assert(setToNullAfterCreation.getString(8) === rowWithNoNullColumns.getString(8)) assert(setToNullAfterCreation.getBinary(9) === rowWithNoNullColumns.getBinary(9)) - // assert(setToNullAfterCreation.get(10) === rowWithNoNullColumns.get(10)) + assert(setToNullAfterCreation.getDecimal(10, 10, 0) === + rowWithNoNullColumns.getDecimal(10, 10, 0)) // assert(setToNullAfterCreation.get(11) === rowWithNoNullColumns.get(11)) for (i <- fieldTypes.indices) { @@ -203,7 +203,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { setToNullAfterCreation.setDouble(7, 700) // setToNullAfterCreation.update(8, UTF8String.fromString("hello")) // setToNullAfterCreation.update(9, "world".getBytes) - // setToNullAfterCreation.update(10, Decimal(10)) + setToNullAfterCreation.setDecimal(10, Decimal(10), 10) // setToNullAfterCreation.update(11, Array(11)) assert(setToNullAfterCreation.isNullAt(0) === rowWithNoNullColumns.isNullAt(0)) @@ -216,7 +216,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(setToNullAfterCreation.getDouble(7) === rowWithNoNullColumns.getDouble(7)) // assert(setToNullAfterCreation.getString(8) === rowWithNoNullColumns.getString(8)) // assert(setToNullAfterCreation.get(9) === rowWithNoNullColumns.get(9)) - // assert(setToNullAfterCreation.get(10) === rowWithNoNullColumns.get(10)) + assert(setToNullAfterCreation.getDecimal(10, 10, 0) === + rowWithNoNullColumns.getDecimal(10, 10, 0)) // assert(setToNullAfterCreation.get(11) === rowWithNoNullColumns.get(11)) } diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala index 454b7b91a63f5..1620fc401ba6e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala @@ -114,7 +114,7 @@ private[sql] class FixedDecimalColumnBuilder( precision: Int, scale: Int) extends NativeColumnBuilder( - new FixedDecimalColumnStats, + new FixedDecimalColumnStats(precision, scale), FIXED_DECIMAL(precision, scale)) // TODO (lian) Add support for array, struct and map diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala index 32a84b2676e07..af1a8ecca9b57 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala @@ -234,14 +234,14 @@ private[sql] class BinaryColumnStats extends ColumnStats { InternalRow(null, null, nullCount, count, sizeInBytes) } -private[sql] class FixedDecimalColumnStats extends ColumnStats { +private[sql] class FixedDecimalColumnStats(precision: Int, scale: Int) extends ColumnStats { protected var upper: Decimal = null protected var lower: Decimal = null override def gatherStats(row: InternalRow, ordinal: Int): Unit = { super.gatherStats(row, ordinal) if (!row.isNullAt(ordinal)) { - val value = row.getDecimal(ordinal) + val value = row.getDecimal(ordinal, precision, scale) if (upper == null || value.compareTo(upper) > 0) upper = value if (lower == null || value.compareTo(lower) < 0) lower = value sizeInBytes += FIXED_DECIMAL.defaultSize diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala index 2863f6c230a9d..30f8fe320db3d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala @@ -392,7 +392,7 @@ private[sql] case class FIXED_DECIMAL(precision: Int, scale: Int) } override def getField(row: InternalRow, ordinal: Int): Decimal = { - row.getDecimal(ordinal) + row.getDecimal(ordinal, precision, scale) } override def setField(row: MutableRow, ordinal: Int, value: Decimal): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala index b85aada9d9d4c..d851eae3fcc71 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala @@ -202,7 +202,7 @@ case class GeneratedAggregate( val schemaSupportsUnsafe: Boolean = { UnsafeFixedWidthAggregationMap.supportsAggregationBufferSchema(aggregationBufferSchema) && - UnsafeFixedWidthAggregationMap.supportsGroupKeySchema(groupKeySchema) + UnsafeProjection.canSupport(groupKeySchema) } child.execute().mapPartitions { iter => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala index c808442a4849b..e5bbd0aaed0a5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala @@ -298,7 
+298,7 @@ private[sql] object SparkSqlSerializer2 { out.writeByte(NULL) } else { out.writeByte(NOT_NULL) - val value = row.getDecimal(i) + val value = row.getDecimal(i, decimal.precision, decimal.scale) val javaBigDecimal = value.toJavaBigDecimal // First, write out the unscaled value. val bytes: Array[Byte] = javaBigDecimal.unscaledValue().toByteArray diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala index 79dd16b7b0c39..ec8da38a3d427 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala @@ -293,8 +293,8 @@ private[parquet] class MutableRowWriteSupport extends RowWriteSupport { writer.addBinary(Binary.fromByteArray(record.getUTF8String(index).getBytes)) case BinaryType => writer.addBinary(Binary.fromByteArray(record.getBinary(index))) - case DecimalType.Fixed(precision, _) => - writeDecimal(record.getDecimal(index), precision) + case DecimalType.Fixed(precision, scale) => + writeDecimal(record.getDecimal(index, precision, scale), precision) case _ => sys.error(s"Unsupported datatype $ctype, cannot write to consumer") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala index 4499a7207031d..66014ddca0596 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala @@ -34,8 +34,7 @@ class ColumnStatsSuite extends SparkFunSuite { testColumnStats(classOf[DoubleColumnStats], DOUBLE, InternalRow(Double.MaxValue, Double.MinValue, 0)) testColumnStats(classOf[StringColumnStats], STRING, InternalRow(null, null, 0)) - testColumnStats(classOf[FixedDecimalColumnStats], - FIXED_DECIMAL(15, 10), InternalRow(null, null, 0)) + testDecimalColumnStats(InternalRow(null, null, 0)) def testColumnStats[T <: AtomicType, U <: ColumnStats]( columnStatsClass: Class[U], @@ -52,7 +51,7 @@ class ColumnStatsSuite extends SparkFunSuite { } test(s"$columnStatsName: non-empty") { - import ColumnarTestUtils._ + import org.apache.spark.sql.columnar.ColumnarTestUtils._ val columnStats = columnStatsClass.newInstance() val rows = Seq.fill(10)(makeRandomRow(columnType)) ++ Seq.fill(10)(makeNullRow(1)) @@ -73,4 +72,39 @@ class ColumnStatsSuite extends SparkFunSuite { } } } + + def testDecimalColumnStats[T <: AtomicType, U <: ColumnStats](initialStatistics: InternalRow) { + + val columnStatsName = classOf[FixedDecimalColumnStats].getSimpleName + val columnType = FIXED_DECIMAL(15, 10) + + test(s"$columnStatsName: empty") { + val columnStats = new FixedDecimalColumnStats(15, 10) + columnStats.collectedStatistics.toSeq.zip(initialStatistics.toSeq).foreach { + case (actual, expected) => assert(actual === expected) + } + } + + test(s"$columnStatsName: non-empty") { + import org.apache.spark.sql.columnar.ColumnarTestUtils._ + + val columnStats = new FixedDecimalColumnStats(15, 10) + val rows = Seq.fill(10)(makeRandomRow(columnType)) ++ Seq.fill(10)(makeNullRow(1)) + rows.foreach(columnStats.gatherStats(_, 0)) + + val values = rows.take(10).map(_.get(0, columnType.dataType).asInstanceOf[T#InternalType]) + val ordering = columnType.dataType.ordering.asInstanceOf[Ordering[T#InternalType]] + val stats = columnStats.collectedStatistics + + 
assertResult(values.min(ordering), "Wrong lower bound")(stats.genericGet(0)) + assertResult(values.max(ordering), "Wrong upper bound")(stats.genericGet(1)) + assertResult(10, "Wrong null count")(stats.genericGet(2)) + assertResult(20, "Wrong row count")(stats.genericGet(3)) + assertResult(stats.genericGet(4), "Wrong size in bytes") { + rows.map { row => + if (row.isNullAt(0)) 4 else columnType.actualSize(row, 0) + }.sum + } + } + } } From 351eda0e2fd47c183c4298469970032097ad07a0 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 30 Jul 2015 17:22:51 -0700 Subject: [PATCH 48/50] [SPARK-6319][SQL] Throw AnalysisException when using BinaryType on Join and Aggregate JIRA: https://issues.apache.org/jira/browse/SPARK-6319 Spark SQL uses plain byte arrays to represent binary values. However, the arrays are compared by reference rather than by value. Thus, we should not use BinaryType on Join and Aggregate in the current implementation. Author: Liang-Chi Hsieh Closes #7787 from viirya/agg_no_binary_type and squashes the following commits: 4f76cac [Liang-Chi Hsieh] Throw AnalysisException when using BinaryType on Join and Aggregate. --- .../sql/catalyst/analysis/CheckAnalysis.scala | 20 +++++++++++++++++++ .../spark/sql/DataFrameAggregateSuite.scala | 11 +++++++++- .../org/apache/spark/sql/JoinSuite.scala | 9 +++++++++ 3 files changed, 39 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index a373714832962..0ebc3d180a780 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -87,6 +87,18 @@ trait CheckAnalysis { s"join condition '${condition.prettyString}' " + s"of type ${condition.dataType.simpleString} is not a boolean.") + case j @ Join(_, _, _, Some(condition)) => + def checkValidJoinConditionExprs(expr: Expression): Unit = expr match { + case p: Predicate => + p.asInstanceOf[Expression].children.foreach(checkValidJoinConditionExprs) + case e if e.dataType.isInstanceOf[BinaryType] => + failAnalysis(s"expression ${e.prettyString} in join condition " + + s"'${condition.prettyString}' can't be binary type.") + case _ => // OK + } + + checkValidJoinConditionExprs(condition) + case Aggregate(groupingExprs, aggregateExprs, child) => def checkValidAggregateExpression(expr: Expression): Unit = expr match { case _: AggregateExpression => // OK @@ -100,7 +112,15 @@ trait CheckAnalysis { case e => e.children.foreach(checkValidAggregateExpression) } + def checkValidGroupingExprs(expr: Expression): Unit = expr.dataType match { + case BinaryType => + failAnalysis(s"grouping expression '${expr.prettyString}' in aggregate can " + + s"not be binary type.") + case _ => // OK + } + aggregateExprs.foreach(checkValidAggregateExpression) + aggregateExprs.foreach(checkValidGroupingExprs) case Sort(orders, _, _) => orders.foreach { order => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index b26d3ab253a1d..228ece8065151 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql import org.apache.spark.sql.TestData._ import
org.apache.spark.sql.functions._ -import org.apache.spark.sql.types.DecimalType +import org.apache.spark.sql.types.{BinaryType, DecimalType} class DataFrameAggregateSuite extends QueryTest { @@ -191,4 +191,13 @@ class DataFrameAggregateSuite extends QueryTest { Row(null)) } + test("aggregation can't work on binary type") { + val df = Seq(1, 1, 2, 2).map(i => Tuple1(i.toString)).toDF("c").select($"c" cast BinaryType) + intercept[AnalysisException] { + df.groupBy("c").agg(count("*")) + } + intercept[AnalysisException] { + df.distinct + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 666f26bf620e1..27c08f64649ee 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -22,6 +22,7 @@ import org.scalatest.BeforeAndAfterEach import org.apache.spark.sql.TestData._ import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.execution.joins._ +import org.apache.spark.sql.types.BinaryType class JoinSuite extends QueryTest with BeforeAndAfterEach { @@ -489,4 +490,12 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { Row(3, 2) :: Nil) } + + test("Join can't work on binary type") { + val left = Seq(1, 1, 2, 2).map(i => Tuple1(i.toString)).toDF("c").select($"c" cast BinaryType) + val right = Seq(1, 1, 2, 2).map(i => Tuple1(i.toString)).toDF("d").select($"d" cast BinaryType) + intercept[AnalysisException] { + left.join(right, ($"left.N" === $"right.N"), "full") + } + } } From 65fa4181c35135080870c1e4c1f904ada3a8cf59 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Thu, 30 Jul 2015 17:26:18 -0700 Subject: [PATCH 49/50] [SPARK-9077] [MLLIB] Improve error message for decision trees when numExamples < maxCategoriesPerFeature Improve error message when number of examples is less than arity of high-arity categorical feature CC jkbradley is this about what you had in mind? I know it's a starter, but was on my list to close out in the short term. Author: Sean Owen Closes #7800 from srowen/SPARK-9077 and squashes the following commits: b8f6cdb [Sean Owen] Improve error message when number of examples is less than arity of high-arity categorical feature --- .../spark/mllib/tree/impl/DecisionTreeMetadata.scala | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala index 380291ac22bd3..9fe264656ede7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala @@ -128,9 +128,13 @@ private[spark] object DecisionTreeMetadata extends Logging { // based on the number of training examples. 
if (strategy.categoricalFeaturesInfo.nonEmpty) { val maxCategoriesPerFeature = strategy.categoricalFeaturesInfo.values.max + val maxCategory = + strategy.categoricalFeaturesInfo.find(_._2 == maxCategoriesPerFeature).get._1 require(maxCategoriesPerFeature <= maxPossibleBins, - s"DecisionTree requires maxBins (= $maxPossibleBins) >= max categories " + - s"in categorical features (= $maxCategoriesPerFeature)") + s"DecisionTree requires maxBins (= $maxPossibleBins) to be at least as large as the " + + s"number of values in each categorical feature, but categorical feature $maxCategory " + + s"has $maxCategoriesPerFeature values. Consider removing this and other categorical " + + "features with a large number of values, or add more training examples.") } val unorderedFeatures = new mutable.HashSet[Int]() From 3c66ff727d4b47220e1ff363cea215189ed64f36 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 30 Jul 2015 17:38:48 -0700 Subject: [PATCH 50/50] [SPARK-9489] Remove unnecessary compatibility and requirements checks from Exchange While reviewing yhuai's patch for SPARK-2205 (#7773), I noticed that Exchange's `compatible` check may be incorrectly returning `false` in many cases. As far as I know, this is not actually a problem because the `compatible`, `meetsRequirements`, and `needsAnySort` checks are serving only as short-circuit performance optimizations that are not necessary for correctness. In order to reduce code complexity, I think that we should remove these checks and unconditionally rewrite the operator's children. This should be safe because we rewrite the tree in a single bottom-up pass. Author: Josh Rosen Closes #7807 from JoshRosen/SPARK-9489 and squashes the following commits: 9d76ce9 [Josh Rosen] [SPARK-9489] Remove compatibleWith, meetsRequirements, and needsAnySort checks from Exchange --- .../plans/physical/partitioning.scala | 35 --------- .../apache/spark/sql/execution/Exchange.scala | 76 +++++-------------- 2 files changed, 17 insertions(+), 94 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 2dcfa19fec383..f4d1dbaf28efe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -86,14 +86,6 @@ sealed trait Partitioning { */ def satisfies(required: Distribution): Boolean - /** - * Returns true iff all distribution guarantees made by this partitioning can also be made - * for the `other` specified partitioning. - * For example, two [[HashPartitioning HashPartitioning]]s are - * only compatible if the `numPartitions` of them is the same. - */ - def compatibleWith(other: Partitioning): Boolean - /** Returns the expressions that are used to key the partitioning.
*/ def keyExpressions: Seq[Expression] } @@ -104,11 +96,6 @@ case class UnknownPartitioning(numPartitions: Int) extends Partitioning { case _ => false } - override def compatibleWith(other: Partitioning): Boolean = other match { - case UnknownPartitioning(_) => true - case _ => false - } - override def keyExpressions: Seq[Expression] = Nil } @@ -117,11 +104,6 @@ case object SinglePartition extends Partitioning { override def satisfies(required: Distribution): Boolean = true - override def compatibleWith(other: Partitioning): Boolean = other match { - case SinglePartition => true - case _ => false - } - override def keyExpressions: Seq[Expression] = Nil } @@ -130,11 +112,6 @@ case object BroadcastPartitioning extends Partitioning { override def satisfies(required: Distribution): Boolean = true - override def compatibleWith(other: Partitioning): Boolean = other match { - case SinglePartition => true - case _ => false - } - override def keyExpressions: Seq[Expression] = Nil } @@ -159,12 +136,6 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) case _ => false } - override def compatibleWith(other: Partitioning): Boolean = other match { - case BroadcastPartitioning => true - case h: HashPartitioning if h == this => true - case _ => false - } - override def keyExpressions: Seq[Expression] = expressions } @@ -199,11 +170,5 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int) case _ => false } - override def compatibleWith(other: Partitioning): Boolean = other match { - case BroadcastPartitioning => true - case r: RangePartitioning if r == this => true - case _ => false - } - override def keyExpressions: Seq[Expression] = ordering.map(_.child) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 70e5031fb63c0..6bd57f010a990 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -202,41 +202,6 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[ def apply(plan: SparkPlan): SparkPlan = plan.transformUp { case operator: SparkPlan => - // True iff every child's outputPartitioning satisfies the corresponding - // required data distribution. - def meetsRequirements: Boolean = - operator.requiredChildDistribution.zip(operator.children).forall { - case (required, child) => - val valid = child.outputPartitioning.satisfies(required) - logDebug( - s"${if (valid) "Valid" else "Invalid"} distribution," + - s"required: $required current: ${child.outputPartitioning}") - valid - } - - // True iff any of the children are incorrectly sorted. - def needsAnySort: Boolean = - operator.requiredChildOrdering.zip(operator.children).exists { - case (required, child) => required.nonEmpty && required != child.outputOrdering - } - - // True iff outputPartitionings of children are compatible with each other. - // It is possible that every child satisfies its required data distribution - // but two children have incompatible outputPartitionings. For example, - // A dataset is range partitioned by "a.asc" (RangePartitioning) and another - // dataset is hash partitioned by "a" (HashPartitioning). Tuples in these two - // datasets are both clustered by "a", but these two outputPartitionings are not - // compatible. - // TODO: ASSUMES TRANSITIVITY? 
- def compatible: Boolean = - operator.children - .map(_.outputPartitioning) - .sliding(2) - .forall { - case Seq(a) => true - case Seq(a, b) => a.compatibleWith(b) - } - // Adds Exchange or Sort operators as required def addOperatorsIfNecessary( partitioning: Partitioning, @@ -269,33 +234,26 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[ addSortIfNecessary(addShuffleIfNecessary(child)) } - if (meetsRequirements && compatible && !needsAnySort) { - operator - } else { - // At least one child does not satisfies its required data distribution or - // at least one child's outputPartitioning is not compatible with another child's - // outputPartitioning. In this case, we need to add Exchange operators. - val requirements = - (operator.requiredChildDistribution, operator.requiredChildOrdering, operator.children) + val requirements = + (operator.requiredChildDistribution, operator.requiredChildOrdering, operator.children) - val fixedChildren = requirements.zipped.map { - case (AllTuples, rowOrdering, child) => - addOperatorsIfNecessary(SinglePartition, rowOrdering, child) - case (ClusteredDistribution(clustering), rowOrdering, child) => - addOperatorsIfNecessary(HashPartitioning(clustering, numPartitions), rowOrdering, child) - case (OrderedDistribution(ordering), rowOrdering, child) => - addOperatorsIfNecessary(RangePartitioning(ordering, numPartitions), rowOrdering, child) + val fixedChildren = requirements.zipped.map { + case (AllTuples, rowOrdering, child) => + addOperatorsIfNecessary(SinglePartition, rowOrdering, child) + case (ClusteredDistribution(clustering), rowOrdering, child) => + addOperatorsIfNecessary(HashPartitioning(clustering, numPartitions), rowOrdering, child) + case (OrderedDistribution(ordering), rowOrdering, child) => + addOperatorsIfNecessary(RangePartitioning(ordering, numPartitions), rowOrdering, child) - case (UnspecifiedDistribution, Seq(), child) => - child - case (UnspecifiedDistribution, rowOrdering, child) => - sqlContext.planner.BasicOperators.getSortOperator(rowOrdering, global = false, child) + case (UnspecifiedDistribution, Seq(), child) => + child + case (UnspecifiedDistribution, rowOrdering, child) => + sqlContext.planner.BasicOperators.getSortOperator(rowOrdering, global = false, child) - case (dist, ordering, _) => - sys.error(s"Don't know how to ensure $dist with ordering $ordering") - } - - operator.withNewChildren(fixedChildren) + case (dist, ordering, _) => + sys.error(s"Don't know how to ensure $dist with ordering $ordering") } + + operator.withNewChildren(fixedChildren) } }
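
Editor's sketch for the FIXED_DECIMAL patches earlier in this series: every reader was changed from row.getDecimal(ordinal) to row.getDecimal(ordinal, precision, scale) because a fixed-precision decimal is stored as an unscaled value that is only meaningful once its scale (and precision) is known. The sketch below is illustrative only: it uses plain java.math.BigDecimal instead of Spark's Decimal, the names are made up, and whether the serializer above also writes the scale next to the unscaled bytes is not visible in the diff, so that part is assumed.

object DecimalRoundTripSketch {
  import java.math.{BigDecimal => JBigDecimal, BigInteger}

  // Mirrors the unscaledValue().toByteArray call visible in the serializer diff;
  // carrying the scale alongside the bytes is an assumption for this illustration.
  def write(d: JBigDecimal): (Array[Byte], Int) =
    (d.unscaledValue().toByteArray, d.scale)

  // The value can only be rebuilt when the scale travels with the bytes, which is
  // the same reason precision and scale are now threaded into getDecimal.
  def read(unscaledBytes: Array[Byte], scale: Int): JBigDecimal =
    new JBigDecimal(new BigInteger(unscaledBytes), scale)

  def main(args: Array[String]): Unit = {
    val original = new JBigDecimal("12345.6789012345") // fits FIXED_DECIMAL(15, 10)
    val (bytes, scale) = write(original)
    assert(read(bytes, scale).compareTo(original) == 0)
    println(s"round-tripped $original with scale $scale")
  }
}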
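
The SPARK-6319 patch above ([PATCH 48/50]) rejects BinaryType in join conditions and aggregations because JVM byte arrays compare by reference, not by value. A minimal Scala sketch of the underlying pitfall, using ordinary collections rather than Spark (the object name is made up):

object BinaryKeyPitfallSketch {
  def main(args: Array[String]): Unit = {
    val a = "spark".getBytes("UTF-8")
    val b = "spark".getBytes("UTF-8")

    // Same contents, but array equality (and hashCode) is identity-based on the JVM.
    println(a == b)                        // false
    println(java.util.Arrays.equals(a, b)) // true

    // Grouping on raw byte-array keys therefore splits equal values into
    // separate groups, which is the behaviour the new analysis check forbids.
    val grouped = Seq(a, b).groupBy(identity)
    println(grouped.size)                  // 2, not 1
  }
}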
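
The SPARK-9077 patch above ([PATCH 49/50]) only rewords an error string, but the failure it describes is easy to trigger: once the usable number of bins is capped below a categorical feature's arity, training cannot proceed. The standalone sketch below restates that check; capping maxPossibleBins by the example count is inferred from the surrounding comment ("based on the number of training examples"), and the names are illustrative rather than Spark's actual internals.

object MaxBinsCheckSketch {
  // Assumed cap: the number of usable bins cannot exceed the number of training examples.
  def validate(categoricalFeaturesInfo: Map[Int, Int], maxBins: Int, numExamples: Long): Unit = {
    val maxPossibleBins = math.min(maxBins.toLong, numExamples).toInt
    if (categoricalFeaturesInfo.nonEmpty) {
      val maxCategoriesPerFeature = categoricalFeaturesInfo.values.max
      val maxCategory = categoricalFeaturesInfo.find(_._2 == maxCategoriesPerFeature).get._1
      require(maxCategoriesPerFeature <= maxPossibleBins,
        s"maxBins (= $maxPossibleBins) must be at least as large as the number of values in " +
        s"each categorical feature, but feature $maxCategory has $maxCategoriesPerFeature values.")
    }
  }

  def main(args: Array[String]): Unit = {
    // Feature 1 has 32 categories, but only 10 training examples are available.
    val outcome = scala.util.Try(validate(Map(0 -> 2, 1 -> 32), maxBins = 32, numExamples = 10L))
    println(outcome) // Failure(java.lang.IllegalArgumentException: requirement failed: maxBins ...)
  }
}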
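
The SPARK-9489 patch above ([PATCH 50/50]) argues that the removed compatible, meetsRequirements, and needsAnySort checks were only short-circuits: EnsureRequirements rewrites the plan in a single bottom-up pass and already decides per child whether an exchange or sort is needed, so visiting every operator unconditionally yields the same plan. The toy sketch below captures that shape with a made-up plan type (not Spark's SparkPlan) and a single shared child requirement for brevity:

object EnsureRequirementsSketch {
  // A made-up operator: it advertises an output partitioning and may require a
  // particular partitioning from its children.
  final case class Op(
      name: String,
      partitioning: String,
      requiredChildPartitioning: Option[String],
      children: Seq[Op])

  // Bottom-up rewrite: children are fixed before their parent is inspected.
  // The per-child check below already skips children that satisfy the
  // requirement, so an outer short-circuit cannot change the result.
  def ensure(op: Op): Op = {
    val fixedChildren = op.children.map(ensure).map { child =>
      op.requiredChildPartitioning match {
        case Some(required) if child.partitioning != required =>
          Op(s"Exchange($required)", required, None, Seq(child))
        case _ => child
      }
    }
    op.copy(children = fixedChildren)
  }

  def main(args: Array[String]): Unit = {
    val scan = Op("Scan", "unknown", None, Nil)
    val agg = Op("Aggregate", "hash(a)", Some("hash(a)"), Seq(scan))
    // The scan gains an Exchange(hash(a)) parent; an already-satisfied child is left alone.
    println(ensure(agg))
  }
}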