diff --git a/examples/src/main/python/ml/cross_validator.py b/examples/src/main/python/ml/cross_validator.py
new file mode 100644
index 0000000000000..f0ca97c724940
--- /dev/null
+++ b/examples/src/main/python/ml/cross_validator.py
@@ -0,0 +1,96 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from __future__ import print_function
+
+from pyspark import SparkContext
+from pyspark.ml import Pipeline
+from pyspark.ml.classification import LogisticRegression
+from pyspark.ml.evaluation import BinaryClassificationEvaluator
+from pyspark.ml.feature import HashingTF, Tokenizer
+from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
+from pyspark.sql import Row, SQLContext
+
+"""
+A simple example demonstrating model selection using CrossValidator.
+This example also demonstrates how Pipelines are Estimators.
+Run with:
+
+ bin/spark-submit examples/src/main/python/ml/cross_validator.py
+"""
+
+if __name__ == "__main__":
+ sc = SparkContext(appName="CrossValidatorExample")
+ sqlContext = SQLContext(sc)
+
+ # Prepare training documents, which are labeled.
+ LabeledDocument = Row("id", "text", "label")
+ training = sc.parallelize([(0, "a b c d e spark", 1.0),
+ (1, "b d", 0.0),
+ (2, "spark f g h", 1.0),
+ (3, "hadoop mapreduce", 0.0),
+ (4, "b spark who", 1.0),
+ (5, "g d a y", 0.0),
+ (6, "spark fly", 1.0),
+ (7, "was mapreduce", 0.0),
+ (8, "e spark program", 1.0),
+ (9, "a e c l", 0.0),
+ (10, "spark compile", 1.0),
+ (11, "hadoop software", 0.0)
+ ]) \
+ .map(lambda x: LabeledDocument(*x)).toDF()
+
+ # Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr.
+ tokenizer = Tokenizer(inputCol="text", outputCol="words")
+ hashingTF = HashingTF(inputCol=tokenizer.getOutputCol(), outputCol="features")
+ lr = LogisticRegression(maxIter=10)
+ pipeline = Pipeline(stages=[tokenizer, hashingTF, lr])
+
+ # We now treat the Pipeline as an Estimator, wrapping it in a CrossValidator instance.
+ # This will allow us to jointly choose parameters for all Pipeline stages.
+ # A CrossValidator requires an Estimator, a set of Estimator ParamMaps, and an Evaluator.
+ # We use a ParamGridBuilder to construct a grid of parameters to search over.
+ # With 3 values for hashingTF.numFeatures and 2 values for lr.regParam,
+ # this grid will have 3 x 2 = 6 parameter settings for CrossValidator to choose from.
+ paramGrid = ParamGridBuilder() \
+ .addGrid(hashingTF.numFeatures, [10, 100, 1000]) \
+ .addGrid(lr.regParam, [0.1, 0.01]) \
+ .build()
+
+ crossval = CrossValidator(estimator=pipeline,
+ estimatorParamMaps=paramGrid,
+ evaluator=BinaryClassificationEvaluator(),
+ numFolds=2) # use 3+ folds in practice
+
+ # Run cross-validation, and choose the best set of parameters.
+ cvModel = crossval.fit(training)
+
+ # Prepare test documents, which are unlabeled.
+ Document = Row("id", "text")
+ test = sc.parallelize([(4, "spark i j k"),
+ (5, "l m n"),
+ (6, "mapreduce spark"),
+ (7, "apache hadoop")]) \
+ .map(lambda x: Document(*x)).toDF()
+
+ # Make predictions on test documents. cvModel uses the best model found during cross-validation.
+ prediction = cvModel.transform(test)
+ selected = prediction.select("id", "text", "probability", "prediction")
+ for row in selected.collect():
+ print(row)
+
+ sc.stop()
diff --git a/examples/src/main/python/ml/simple_params_example.py b/examples/src/main/python/ml/simple_params_example.py
index 3933d59b52cd1..a9f29dab2d602 100644
--- a/examples/src/main/python/ml/simple_params_example.py
+++ b/examples/src/main/python/ml/simple_params_example.py
@@ -41,8 +41,8 @@
# prepare training data.
# We create an RDD of LabeledPoints and convert them into a DataFrame.
- # Spark DataFrames can automatically infer the schema from named tuples
- # and LabeledPoint implements __reduce__ to behave like a named tuple.
+ # A LabeledPoint is an object with two fields named label and features,
+ # and Spark SQL recognizes these fields and infers the schema accordingly.
training = sc.parallelize([
LabeledPoint(1.0, DenseVector([0.0, 1.1, 0.1])),
LabeledPoint(0.0, DenseVector([2.0, 1.0, -1.0])),
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionExample.scala
new file mode 100644
index 0000000000000..b54466fd48bc5
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionExample.scala
@@ -0,0 +1,142 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.ml
+
+import scala.collection.mutable
+import scala.language.reflectiveCalls
+
+import scopt.OptionParser
+
+import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.examples.mllib.AbstractParams
+import org.apache.spark.ml.{Pipeline, PipelineStage}
+import org.apache.spark.ml.regression.{LinearRegression, LinearRegressionModel}
+import org.apache.spark.sql.DataFrame
+
+/**
+ * An example runner for linear regression with elastic-net (mixing L1/L2) regularization.
+ * Run with
+ * {{{
+ * bin/run-example ml.LinearRegressionExample [options]
+ * }}}
+ * A synthetic dataset can be found at `data/mllib/sample_linear_regression_data.txt`, on which a
+ * model can be trained with
+ * {{{
+ * bin/run-example ml.LinearRegressionExample --regParam 0.15 --elasticNetParam 1.0 \
+ * data/mllib/sample_linear_regression_data.txt
+ * }}}
+ * If you use it as a template to create your own app, please use `spark-submit` to submit your app.
+ */
+object LinearRegressionExample {
+
+ case class Params(
+ input: String = null,
+ testInput: String = "",
+ dataFormat: String = "libsvm",
+ regParam: Double = 0.0,
+ elasticNetParam: Double = 0.0,
+ maxIter: Int = 100,
+ tol: Double = 1E-6,
+ fracTest: Double = 0.2) extends AbstractParams[Params]
+
+ def main(args: Array[String]) {
+ val defaultParams = Params()
+
+ val parser = new OptionParser[Params]("LinearRegressionExample") {
+ head("LinearRegressionExample: an example Linear Regression with Elastic-Net app.")
+ opt[Double]("regParam")
+ .text(s"regularization parameter, default: ${defaultParams.regParam}")
+ .action((x, c) => c.copy(regParam = x))
+ opt[Double]("elasticNetParam")
+ .text(s"ElasticNet mixing parameter. For alpha = 0, the penalty is an L2 penalty. " +
+ s"For alpha = 1, it is an L1 penalty. For 0 < alpha < 1, the penalty is a combination of " +
+ s"L1 and L2, default: ${defaultParams.elasticNetParam}")
+ .action((x, c) => c.copy(elasticNetParam = x))
+ opt[Int]("maxIter")
+ .text(s"maximum number of iterations, default: ${defaultParams.maxIter}")
+ .action((x, c) => c.copy(maxIter = x))
+ opt[Double]("tol")
+ .text(s"the convergence tolerance of iterations, Smaller value will lead " +
+ s"to higher accuracy with the cost of more iterations, default: ${defaultParams.tol}")
+ .action((x, c) => c.copy(tol = x))
+ opt[Double]("fracTest")
+ .text(s"fraction of data to hold out for testing. If given option testInput, " +
+ s"this option is ignored. default: ${defaultParams.fracTest}")
+ .action((x, c) => c.copy(fracTest = x))
+ opt[String]("testInput")
+ .text(s"input path to test dataset. If given, option fracTest is ignored." +
+ s" default: ${defaultParams.testInput}")
+ .action((x, c) => c.copy(testInput = x))
+ opt[String]("dataFormat")
+ .text("data format: libsvm (default), dense (deprecated in Spark v1.1)")
+ .action((x, c) => c.copy(dataFormat = x))
+ arg[String]("")
+ .text("input path to labeled examples")
+ .required()
+ .action((x, c) => c.copy(input = x))
+ checkConfig { params =>
+ if (params.fracTest < 0 || params.fracTest >= 1) {
+ failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1).")
+ } else {
+ success
+ }
+ }
+ }
+
+ parser.parse(args, defaultParams).map { params =>
+ run(params)
+ }.getOrElse {
+ sys.exit(1)
+ }
+ }
+
+ def run(params: Params) {
+ val conf = new SparkConf().setAppName(s"LinearRegressionExample with $params")
+ val sc = new SparkContext(conf)
+
+ println(s"LinearRegressionExample with parameters:\n$params")
+
+ // Load training and test data and cache it.
+ val (training: DataFrame, test: DataFrame) = DecisionTreeExample.loadDatasets(sc, params.input,
+ params.dataFormat, params.testInput, "regression", params.fracTest)
+
+ val lir = new LinearRegression()
+ .setFeaturesCol("features")
+ .setLabelCol("label")
+ .setRegParam(params.regParam)
+ .setElasticNetParam(params.elasticNetParam)
+ .setMaxIter(params.maxIter)
+ .setTol(params.tol)
+
+ // Train the model
+ val startTime = System.nanoTime()
+ val lirModel = lir.fit(training)
+ val elapsedTime = (System.nanoTime() - startTime) / 1e9
+ println(s"Training time: $elapsedTime seconds")
+
+ // Print the weights and intercept for linear regression.
+ println(s"Weights: ${lirModel.weights} Intercept: ${lirModel.intercept}")
+
+ println("Training data results:")
+ DecisionTreeExample.evaluateRegressionModel(lirModel, training, "label")
+ println("Test data results:")
+ DecisionTreeExample.evaluateRegressionModel(lirModel, test, "label")
+
+ sc.stop()
+ }
+}
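
For readers who want to try the new elastic-net linear regression outside of this runner, here is a minimal spark-shell sketch (Spark 1.4 API; the path and parameter values are illustrative, and `sc`/`sqlContext` are assumed to be the shell-provided contexts):

```scala
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.mllib.util.MLUtils

// LabeledPoint is a case class with `label` and `features` fields, so toDF()
// yields the column names the estimator expects by default.
import sqlContext.implicits._
val training = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_linear_regression_data.txt").toDF()

val lir = new LinearRegression()
  .setRegParam(0.15)
  .setElasticNetParam(1.0)  // 1.0 = pure L1 (lasso); 0.0 = pure L2 (ridge)
  .setMaxIter(100)
  .setTol(1e-6)

val model = lir.fit(training)
println(s"Weights: ${model.weights} Intercept: ${model.intercept}")
```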
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionExample.scala
new file mode 100644
index 0000000000000..b12f833ce94c8
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionExample.scala
@@ -0,0 +1,159 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.ml
+
+import scala.collection.mutable
+import scala.language.reflectiveCalls
+
+import scopt.OptionParser
+
+import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.examples.mllib.AbstractParams
+import org.apache.spark.ml.{Pipeline, PipelineStage}
+import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel}
+import org.apache.spark.ml.feature.StringIndexer
+import org.apache.spark.sql.DataFrame
+
+/**
+ * An example runner for logistic regression with elastic-net (mixing L1/L2) regularization.
+ * Run with
+ * {{{
+ * bin/run-example ml.LogisticRegressionExample [options]
+ * }}}
+ * A synthetic dataset can be found at `data/mllib/sample_libsvm_data.txt`, on which a
+ * model can be trained with
+ * {{{
+ * bin/run-example ml.LogisticRegressionExample --regParam 0.3 --elasticNetParam 0.8 \
+ * data/mllib/sample_libsvm_data.txt
+ * }}}
+ * If you use it as a template to create your own app, please use `spark-submit` to submit your app.
+ */
+object LogisticRegressionExample {
+
+ case class Params(
+ input: String = null,
+ testInput: String = "",
+ dataFormat: String = "libsvm",
+ regParam: Double = 0.0,
+ elasticNetParam: Double = 0.0,
+ maxIter: Int = 100,
+ fitIntercept: Boolean = true,
+ tol: Double = 1E-6,
+ fracTest: Double = 0.2) extends AbstractParams[Params]
+
+ def main(args: Array[String]) {
+ val defaultParams = Params()
+
+ val parser = new OptionParser[Params]("LogisticRegressionExample") {
+ head("LogisticRegressionExample: an example Logistic Regression with Elastic-Net app.")
+ opt[Double]("regParam")
+ .text(s"regularization parameter, default: ${defaultParams.regParam}")
+ .action((x, c) => c.copy(regParam = x))
+ opt[Double]("elasticNetParam")
+ .text(s"ElasticNet mixing parameter. For alpha = 0, the penalty is an L2 penalty. " +
+ s"For alpha = 1, it is an L1 penalty. For 0 < alpha < 1, the penalty is a combination of " +
+ s"L1 and L2, default: ${defaultParams.elasticNetParam}")
+ .action((x, c) => c.copy(elasticNetParam = x))
+ opt[Int]("maxIter")
+ .text(s"maximum number of iterations, default: ${defaultParams.maxIter}")
+ .action((x, c) => c.copy(maxIter = x))
+ opt[Boolean]("fitIntercept")
+ .text(s"whether to fit an intercept term, default: ${defaultParams.fitIntercept}")
+ .action((x, c) => c.copy(fitIntercept = x))
+ opt[Double]("tol")
+ .text(s"the convergence tolerance of iterations, Smaller value will lead " +
+ s"to higher accuracy with the cost of more iterations, default: ${defaultParams.tol}")
+ .action((x, c) => c.copy(tol = x))
+ opt[Double]("fracTest")
+ .text(s"fraction of data to hold out for testing. If given option testInput, " +
+ s"this option is ignored. default: ${defaultParams.fracTest}")
+ .action((x, c) => c.copy(fracTest = x))
+ opt[String]("testInput")
+ .text(s"input path to test dataset. If given, option fracTest is ignored." +
+ s" default: ${defaultParams.testInput}")
+ .action((x, c) => c.copy(testInput = x))
+ opt[String]("dataFormat")
+ .text("data format: libsvm (default), dense (deprecated in Spark v1.1)")
+ .action((x, c) => c.copy(dataFormat = x))
+ arg[String]("")
+ .text("input path to labeled examples")
+ .required()
+ .action((x, c) => c.copy(input = x))
+ checkConfig { params =>
+ if (params.fracTest < 0 || params.fracTest >= 1) {
+ failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1).")
+ } else {
+ success
+ }
+ }
+ }
+
+ parser.parse(args, defaultParams).map { params =>
+ run(params)
+ }.getOrElse {
+ sys.exit(1)
+ }
+ }
+
+ def run(params: Params) {
+ val conf = new SparkConf().setAppName(s"LogisticRegressionExample with $params")
+ val sc = new SparkContext(conf)
+
+ println(s"LogisticRegressionExample with parameters:\n$params")
+
+ // Load training and test data and cache it.
+ val (training: DataFrame, test: DataFrame) = DecisionTreeExample.loadDatasets(sc, params.input,
+ params.dataFormat, params.testInput, "classification", params.fracTest)
+
+ // Set up Pipeline
+ val stages = new mutable.ArrayBuffer[PipelineStage]()
+
+ val labelIndexer = new StringIndexer()
+ .setInputCol("labelString")
+ .setOutputCol("indexedLabel")
+ stages += labelIndexer
+
+ val lor = new LogisticRegression()
+ .setFeaturesCol("features")
+ .setLabelCol("indexedLabel")
+ .setRegParam(params.regParam)
+ .setElasticNetParam(params.elasticNetParam)
+ .setMaxIter(params.maxIter)
+ .setTol(params.tol)
+
+ stages += lor
+ val pipeline = new Pipeline().setStages(stages.toArray)
+
+ // Fit the Pipeline
+ val startTime = System.nanoTime()
+ val pipelineModel = pipeline.fit(training)
+ val elapsedTime = (System.nanoTime() - startTime) / 1e9
+ println(s"Training time: $elapsedTime seconds")
+
+ val lorModel = pipelineModel.stages.last.asInstanceOf[LogisticRegressionModel]
+ // Print the weights and intercept for logistic regression.
+ println(s"Weights: ${lorModel.weights} Intercept: ${lorModel.intercept}")
+
+ println("Training data results:")
+ DecisionTreeExample.evaluateClassificationModel(pipelineModel, training, "indexedLabel")
+ println("Test data results:")
+ DecisionTreeExample.evaluateClassificationModel(pipelineModel, test, "indexedLabel")
+
+ sc.stop()
+ }
+}
diff --git a/launcher/pom.xml b/launcher/pom.xml
index ebfa7685eaa18..cc177d23dff77 100644
--- a/launcher/pom.xml
+++ b/launcher/pom.xml
@@ -29,7 +29,7 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-launcher_2.10</artifactId>
<packaging>jar</packaging>
- <name>Spark Launcher Project</name>
+ <name>Spark Project Launcher</name>
<url>http://spark.apache.org/</url>
<properties>
<sbt.project.name>launcher</sbt.project.name>
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index d13109d9da4c0..f136bcee9cf2b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -74,7 +74,7 @@ class LogisticRegression(override val uid: String)
setDefault(elasticNetParam -> 0.0)
/**
- * Set the maximal number of iterations.
+ * Set the maximum number of iterations.
* Default is 100.
* @group setParam
*/
@@ -90,7 +90,11 @@ class LogisticRegression(override val uid: String)
def setTol(value: Double): this.type = set(tol, value)
setDefault(tol -> 1E-6)
- /** @group setParam */
+ /**
+ * Whether to fit an intercept term.
+ * Default is true.
+ * @group setParam
+ */
def setFitIntercept(value: Boolean): this.type = set(fitIntercept, value)
setDefault(fitIntercept -> true)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
index 7b726da388075..825f9ed1b54b2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
@@ -131,6 +131,7 @@ final class OneVsRestModel private[ml] (
// output label and label metadata as prediction
val labelUdf = callUDF(label, DoubleType, col(accColName))
aggregatedDataset.withColumn($(predictionCol), labelUdf.as($(predictionCol), labelMetadata))
+ .drop(accColName)
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala
index 3ae1833390152..1e758cb775de7 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala
@@ -41,7 +41,7 @@ class ElementwiseProduct(override val uid: String)
* the vector to multiply with input vectors
* @group param
*/
- val scalingVec: Param[Vector] = new Param(this, "scalingVector", "vector for hadamard product")
+ val scalingVec: Param[Vector] = new Param(this, "scalingVec", "vector for hadamard product")
/** @group setParam */
def setScalingVec(value: Vector): this.type = set(scalingVec, value)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
index 1ffb5eddc36bd..8ffbcf0d8bc71 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
@@ -33,7 +33,7 @@ private[shared] object SharedParamsCodeGen {
val params = Seq(
ParamDesc[Double]("regParam", "regularization parameter (>= 0)",
isValid = "ParamValidators.gtEq(0)"),
- ParamDesc[Int]("maxIter", "max number of iterations (>= 0)",
+ ParamDesc[Int]("maxIter", "maximum number of iterations (>= 0)",
isValid = "ParamValidators.gtEq(0)"),
ParamDesc[String]("featuresCol", "features column name", Some("\"features\"")),
ParamDesc[String]("labelCol", "label column name", Some("\"label\"")),
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
index ed08417bd4df8..a0c8ccdac9ad9 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
@@ -45,10 +45,10 @@ private[ml] trait HasRegParam extends Params {
private[ml] trait HasMaxIter extends Params {
/**
- * Param for max number of iterations (>= 0).
+ * Param for maximum number of iterations (>= 0).
* @group param
*/
- final val maxIter: IntParam = new IntParam(this, "maxIter", "max number of iterations (>= 0)", ParamValidators.gtEq(0))
+ final val maxIter: IntParam = new IntParam(this, "maxIter", "maximum number of iterations (>= 0)", ParamValidators.gtEq(0))
/** @group getParam */
final def getMaxIter: Int = $(maxIter)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index fe2a71a331694..70cd8e9e87fae 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -83,7 +83,7 @@ class LinearRegression(override val uid: String)
setDefault(elasticNetParam -> 0.0)
/**
- * Set the maximal number of iterations.
+ * Set the maximum number of iterations.
* Default is 100.
* @group setParam
*/
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/ElementwiseProduct.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ElementwiseProduct.scala
index b0985baf9b278..d67fe6c3ee4f8 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/ElementwiseProduct.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ElementwiseProduct.scala
@@ -25,10 +25,10 @@ import org.apache.spark.mllib.linalg._
* Outputs the Hadamard product (i.e., the element-wise product) of each input vector with a
* provided "weight" vector. In other words, it scales each column of the dataset by a scalar
* multiplier.
- * @param scalingVector The values used to scale the reference vector's individual components.
+ * @param scalingVec The values used to scale the reference vector's individual components.
*/
@Experimental
-class ElementwiseProduct(val scalingVector: Vector) extends VectorTransformer {
+class ElementwiseProduct(val scalingVec: Vector) extends VectorTransformer {
/**
* Does the hadamard product transformation.
@@ -37,15 +37,15 @@ class ElementwiseProduct(val scalingVector: Vector) extends VectorTransformer {
* @return transformed vector.
*/
override def transform(vector: Vector): Vector = {
- require(vector.size == scalingVector.size,
- s"vector sizes do not match: Expected ${scalingVector.size} but found ${vector.size}")
+ require(vector.size == scalingVec.size,
+ s"vector sizes do not match: Expected ${scalingVec.size} but found ${vector.size}")
vector match {
case dv: DenseVector =>
val values: Array[Double] = dv.values.clone()
- val dim = scalingVector.size
+ val dim = scalingVec.size
var i = 0
while (i < dim) {
- values(i) *= scalingVector(i)
+ values(i) *= scalingVec(i)
i += 1
}
Vectors.dense(values)
@@ -54,7 +54,7 @@ class ElementwiseProduct(val scalingVector: Vector) extends VectorTransformer {
val dim = values.length
var i = 0
while (i < dim) {
- values(i) *= scalingVector(indices(i))
+ values(i) *= scalingVec(indices(i))
i += 1
}
Vectors.sparse(size, indices, values)
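
A quick sketch of the rename from a caller's point of view (values are illustrative); the change is visible to any code that passed the constructor argument by name:

```scala
import org.apache.spark.mllib.feature.ElementwiseProduct
import org.apache.spark.mllib.linalg.Vectors

// The constructor parameter is now `scalingVec` (was `scalingVector`).
val transformer = new ElementwiseProduct(scalingVec = Vectors.dense(0.0, 1.0, 2.0))
val transformed = transformer.transform(Vectors.dense(1.0, 2.0, 3.0))
println(transformed)  // [0.0,2.0,6.0]
```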
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java
index da2218056307e..599e9cfd23ad4 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java
@@ -55,9 +55,9 @@ public void tearDown() {
@Test
public void hashingTF() {
JavaRDD jrdd = jsc.parallelize(Lists.newArrayList(
- RowFactory.create(0, "Hi I heard about Spark"),
- RowFactory.create(0, "I wish Java could use case classes"),
- RowFactory.create(1, "Logistic regression models are neat")
+ RowFactory.create(0.0, "Hi I heard about Spark"),
+ RowFactory.create(0.0, "I wish Java could use case classes"),
+ RowFactory.create(1.0, "Logistic regression models are neat")
));
StructType schema = new StructType(new StructField[]{
new StructField("label", DataTypes.DoubleType, false, Metadata.empty()),
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
index f439f3261f06f..1d04ccb509057 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
@@ -93,6 +93,15 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext {
val datasetWithLabelMetadata = dataset.select(labelWithMetadata, features)
ova.fit(datasetWithLabelMetadata)
}
+
+ test("SPARK-8049: OneVsRest shouldn't output temp columns") {
+ val logReg = new LogisticRegression()
+ .setMaxIter(1)
+ val ovr = new OneVsRest()
+ .setClassifier(logReg)
+ val output = ovr.fit(dataset).transform(dataset)
+ assert(output.schema.fieldNames.toSet === Set("label", "features", "prediction"))
+ }
}
private class MockLogisticRegression(uid: String) extends LogisticRegression(uid) {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
index f80e7749098a5..96094d7a099aa 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
@@ -27,7 +27,7 @@ class ParamsSuite extends SparkFunSuite {
import solver.{maxIter, inputCol}
assert(maxIter.name === "maxIter")
- assert(maxIter.doc === "max number of iterations (>= 0)")
+ assert(maxIter.doc === "maximum number of iterations (>= 0)")
assert(maxIter.parent === uid)
assert(maxIter.toString === s"${uid}__maxIter")
assert(!maxIter.isValid(-1))
@@ -36,7 +36,7 @@ class ParamsSuite extends SparkFunSuite {
solver.setMaxIter(5)
assert(solver.explainParam(maxIter) ===
- "maxIter: max number of iterations (>= 0) (default: 10, current: 5)")
+ "maxIter: maximum number of iterations (>= 0) (default: 10, current: 5)")
assert(inputCol.toString === s"${uid}__inputCol")
@@ -120,7 +120,7 @@ class ParamsSuite extends SparkFunSuite {
intercept[NoSuchElementException](solver.getInputCol)
assert(solver.explainParam(maxIter) ===
- "maxIter: max number of iterations (>= 0) (default: 10, current: 100)")
+ "maxIter: maximum number of iterations (>= 0) (default: 10, current: 100)")
assert(solver.explainParams() ===
Seq(inputCol, maxIter).map(solver.explainParam).mkString("\n"))
diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py
index 8dc5039f587f0..1ecec5b126505 100644
--- a/python/pyspark/sql/column.py
+++ b/python/pyspark/sql/column.py
@@ -315,6 +315,14 @@ def between(self, lowerBound, upperBound):
"""
A boolean expression that is evaluated to true if the value of this
expression is between the given columns.
+
+ >>> df.select(df.name, df.age.between(2, 4)).show()
+ +-----+--------------------------+
+ | name|((age >= 2) && (age <= 4))|
+ +-----+--------------------------+
+ |Alice| true|
+ | Bob| false|
+ +-----+--------------------------+
"""
return (self >= lowerBound) & (self <= upperBound)
@@ -328,12 +336,20 @@ def when(self, condition, value):
:param condition: a boolean :class:`Column` expression.
:param value: a literal value, or a :class:`Column` expression.
+
+ >>> from pyspark.sql import functions as F
+ >>> df.select(df.name, F.when(df.age > 4, 1).when(df.age < 3, -1).otherwise(0)).show()
+ +-----+--------------------------------------------------------+
+ | name|CASE WHEN (age > 4) THEN 1 WHEN (age < 3) THEN -1 ELSE 0|
+ +-----+--------------------------------------------------------+
+ |Alice| -1|
+ | Bob| 1|
+ +-----+--------------------------------------------------------+
"""
- sc = SparkContext._active_spark_context
if not isinstance(condition, Column):
raise TypeError("condition should be a Column")
v = value._jc if isinstance(value, Column) else value
- jc = sc._jvm.functions.when(condition._jc, v)
+ jc = self._jc.when(condition._jc, v)
return Column(jc)
@since(1.4)
@@ -345,9 +361,18 @@ def otherwise(self, value):
See :func:`pyspark.sql.functions.when` for example usage.
:param value: a literal value, or a :class:`Column` expression.
+
+ >>> from pyspark.sql import functions as F
+ >>> df.select(df.name, F.when(df.age > 3, 1).otherwise(0)).show()
+ +-----+---------------------------------+
+ | name|CASE WHEN (age > 3) THEN 1 ELSE 0|
+ +-----+---------------------------------+
+ |Alice| 0|
+ | Bob| 1|
+ +-----+---------------------------------+
"""
v = value._jc if isinstance(value, Column) else value
- jc = self._jc.otherwise(value)
+ jc = self._jc.otherwise(v)
return Column(jc)
@since(1.4)
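
The Python fix delegates chained `when` calls to the underlying JVM Column instead of restarting a new CASE expression through `functions.when`; the equivalent chained expression on the Scala side looks roughly like this (a spark-shell sketch with an illustrative DataFrame matching the doctest above):

```scala
import org.apache.spark.sql.functions.when

// Two illustrative rows, mirroring the doctest data.
val df = sqlContext.createDataFrame(Seq((2, "Alice"), (5, "Bob"))).toDF("age", "name")
df.select(df("name"), when(df("age") > 4, 1).when(df("age") < 3, -1).otherwise(0)).show()
```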
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
index 1c0ddb5093d17..2e7b4c236d8f8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
@@ -18,7 +18,10 @@
package org.apache.spark.sql.catalyst
import java.lang.{Iterable => JavaIterable}
+import java.math.{BigDecimal => JavaBigDecimal}
+import java.sql.Date
import java.util.{Map => JavaMap}
+import javax.annotation.Nullable
import scala.collection.mutable.HashMap
@@ -34,197 +37,338 @@ object CatalystTypeConverters {
// Since the map values can be mutable, we explicitly import scala.collection.Map at here.
import scala.collection.Map
+ private def isPrimitive(dataType: DataType): Boolean = {
+ dataType match {
+ case BooleanType => true
+ case ByteType => true
+ case ShortType => true
+ case IntegerType => true
+ case LongType => true
+ case FloatType => true
+ case DoubleType => true
+ case _ => false
+ }
+ }
+
+ private def getConverterForType(dataType: DataType): CatalystTypeConverter[Any, Any, Any] = {
+ val converter = dataType match {
+ case udt: UserDefinedType[_] => UDTConverter(udt)
+ case arrayType: ArrayType => ArrayConverter(arrayType.elementType)
+ case mapType: MapType => MapConverter(mapType.keyType, mapType.valueType)
+ case structType: StructType => StructConverter(structType)
+ case StringType => StringConverter
+ case DateType => DateConverter
+ case dt: DecimalType => BigDecimalConverter
+ case BooleanType => BooleanConverter
+ case ByteType => ByteConverter
+ case ShortType => ShortConverter
+ case IntegerType => IntConverter
+ case LongType => LongConverter
+ case FloatType => FloatConverter
+ case DoubleType => DoubleConverter
+ case _ => IdentityConverter
+ }
+ converter.asInstanceOf[CatalystTypeConverter[Any, Any, Any]]
+ }
+
/**
- * Converts Scala objects to catalyst rows / types. This method is slow, and for batch
- * conversion you should be using converter produced by createToCatalystConverter.
- * Note: This is always called after schemaFor has been called.
- * This ordering is important for UDT registration.
+ * Converts a Scala type to its Catalyst equivalent (and vice versa).
+ *
+ * @tparam ScalaInputType The type of Scala values that can be converted to Catalyst.
+ * @tparam ScalaOutputType The type of Scala values returned when converting Catalyst to Scala.
+ * @tparam CatalystType The internal Catalyst type used to represent values of this Scala type.
*/
- def convertToCatalyst(a: Any, dataType: DataType): Any = (a, dataType) match {
- // Check UDT first since UDTs can override other types
- case (obj, udt: UserDefinedType[_]) =>
- udt.serialize(obj)
-
- case (o: Option[_], _) =>
- o.map(convertToCatalyst(_, dataType)).orNull
-
- case (s: Seq[_], arrayType: ArrayType) =>
- s.map(convertToCatalyst(_, arrayType.elementType))
-
- case (jit: JavaIterable[_], arrayType: ArrayType) => {
- val iter = jit.iterator
- var listOfItems: List[Any] = List()
- while (iter.hasNext) {
- val item = iter.next()
- listOfItems :+= convertToCatalyst(item, arrayType.elementType)
+ private abstract class CatalystTypeConverter[ScalaInputType, ScalaOutputType, CatalystType]
+ extends Serializable {
+
+ /**
+ * Converts a Scala type to its Catalyst equivalent while automatically handling nulls
+ * and Options.
+ */
+ final def toCatalyst(@Nullable maybeScalaValue: Any): CatalystType = {
+ if (maybeScalaValue == null) {
+ null.asInstanceOf[CatalystType]
+ } else if (maybeScalaValue.isInstanceOf[Option[ScalaInputType]]) {
+ val opt = maybeScalaValue.asInstanceOf[Option[ScalaInputType]]
+ if (opt.isDefined) {
+ toCatalystImpl(opt.get)
+ } else {
+ null.asInstanceOf[CatalystType]
+ }
+ } else {
+ toCatalystImpl(maybeScalaValue.asInstanceOf[ScalaInputType])
}
- listOfItems
}
- case (s: Array[_], arrayType: ArrayType) =>
- s.toSeq.map(convertToCatalyst(_, arrayType.elementType))
+ /**
+ * Given a Catalyst row, convert the value at column `column` to its Scala equivalent.
+ */
+ final def toScala(row: Row, column: Int): ScalaOutputType = {
+ if (row.isNullAt(column)) null.asInstanceOf[ScalaOutputType] else toScalaImpl(row, column)
+ }
+
+ /**
+ * Convert a Catalyst value to its Scala equivalent.
+ */
+ def toScala(@Nullable catalystValue: CatalystType): ScalaOutputType
+
+ /**
+ * Converts a Scala value to its Catalyst equivalent.
+ * @param scalaValue the Scala value, guaranteed not to be null.
+ * @return the Catalyst value.
+ */
+ protected def toCatalystImpl(scalaValue: ScalaInputType): CatalystType
+
+ /**
+ * Given a Catalyst row, convert the value at column `column` to its Scala equivalent.
+ * This method will only be called on non-null columns.
+ */
+ protected def toScalaImpl(row: Row, column: Int): ScalaOutputType
+ }
- case (m: Map[_, _], mapType: MapType) =>
- m.map { case (k, v) =>
- convertToCatalyst(k, mapType.keyType) -> convertToCatalyst(v, mapType.valueType)
- }
+ private object IdentityConverter extends CatalystTypeConverter[Any, Any, Any] {
+ override def toCatalystImpl(scalaValue: Any): Any = scalaValue
+ override def toScala(catalystValue: Any): Any = catalystValue
+ override def toScalaImpl(row: Row, column: Int): Any = row(column)
+ }
- case (jmap: JavaMap[_, _], mapType: MapType) =>
- val iter = jmap.entrySet.iterator
- var listOfEntries: List[(Any, Any)] = List()
- while (iter.hasNext) {
- val entry = iter.next()
- listOfEntries :+= (convertToCatalyst(entry.getKey, mapType.keyType),
- convertToCatalyst(entry.getValue, mapType.valueType))
+ private case class UDTConverter(
+ udt: UserDefinedType[_]) extends CatalystTypeConverter[Any, Any, Any] {
+ override def toCatalystImpl(scalaValue: Any): Any = udt.serialize(scalaValue)
+ override def toScala(catalystValue: Any): Any = udt.deserialize(catalystValue)
+ override def toScalaImpl(row: Row, column: Int): Any = toScala(row(column))
+ }
+
+ /** Converter for arrays, sequences, and Java iterables. */
+ private case class ArrayConverter(
+ elementType: DataType) extends CatalystTypeConverter[Any, Seq[Any], Seq[Any]] {
+
+ private[this] val elementConverter = getConverterForType(elementType)
+
+ override def toCatalystImpl(scalaValue: Any): Seq[Any] = {
+ scalaValue match {
+ case a: Array[_] => a.toSeq.map(elementConverter.toCatalyst)
+ case s: Seq[_] => s.map(elementConverter.toCatalyst)
+ case i: JavaIterable[_] =>
+ val iter = i.iterator
+ var convertedIterable: List[Any] = List()
+ while (iter.hasNext) {
+ val item = iter.next()
+ convertedIterable :+= elementConverter.toCatalyst(item)
+ }
+ convertedIterable
}
- listOfEntries.toMap
-
- case (p: Product, structType: StructType) =>
- val ar = new Array[Any](structType.size)
- val iter = p.productIterator
- var idx = 0
- while (idx < structType.size) {
- ar(idx) = convertToCatalyst(iter.next(), structType.fields(idx).dataType)
- idx += 1
+ }
+
+ override def toScala(catalystValue: Seq[Any]): Seq[Any] = {
+ if (catalystValue == null) {
+ null
+ } else {
+ catalystValue.asInstanceOf[Seq[_]].map(elementConverter.toScala)
}
- new GenericRowWithSchema(ar, structType)
+ }
- case (d: String, _) =>
- UTF8String(d)
+ override def toScalaImpl(row: Row, column: Int): Seq[Any] =
+ toScala(row(column).asInstanceOf[Seq[Any]])
+ }
+
+ private case class MapConverter(
+ keyType: DataType,
+ valueType: DataType)
+ extends CatalystTypeConverter[Any, Map[Any, Any], Map[Any, Any]] {
- case (d: BigDecimal, _) =>
- Decimal(d)
+ private[this] val keyConverter = getConverterForType(keyType)
+ private[this] val valueConverter = getConverterForType(valueType)
- case (d: java.math.BigDecimal, _) =>
- Decimal(d)
+ override def toCatalystImpl(scalaValue: Any): Map[Any, Any] = scalaValue match {
+ case m: Map[_, _] =>
+ m.map { case (k, v) =>
+ keyConverter.toCatalyst(k) -> valueConverter.toCatalyst(v)
+ }
- case (d: java.sql.Date, _) =>
- DateUtils.fromJavaDate(d)
+ case jmap: JavaMap[_, _] =>
+ val iter = jmap.entrySet.iterator
+ val convertedMap: HashMap[Any, Any] = HashMap()
+ while (iter.hasNext) {
+ val entry = iter.next()
+ val key = keyConverter.toCatalyst(entry.getKey)
+ convertedMap(key) = valueConverter.toCatalyst(entry.getValue)
+ }
+ convertedMap
+ }
- case (r: Row, structType: StructType) =>
- val converters = structType.fields.map {
- f => (item: Any) => convertToCatalyst(item, f.dataType)
+ override def toScala(catalystValue: Map[Any, Any]): Map[Any, Any] = {
+ if (catalystValue == null) {
+ null
+ } else {
+ catalystValue.map { case (k, v) =>
+ keyConverter.toScala(k) -> valueConverter.toScala(v)
+ }
}
- convertRowWithConverters(r, structType, converters)
+ }
- case (other, _) =>
- other
+ override def toScalaImpl(row: Row, column: Int): Map[Any, Any] =
+ toScala(row(column).asInstanceOf[Map[Any, Any]])
}
- /**
- * Creates a converter function that will convert Scala objects to the specified catalyst type.
- * Typical use case would be converting a collection of rows that have the same schema. You will
- * call this function once to get a converter, and apply it to every row.
- */
- private[sql] def createToCatalystConverter(dataType: DataType): Any => Any = {
- def extractOption(item: Any): Any = item match {
- case opt: Option[_] => opt.orNull
- case other => other
- }
+ private case class StructConverter(
+ structType: StructType) extends CatalystTypeConverter[Any, Row, Row] {
- dataType match {
- // Check UDT first since UDTs can override other types
- case udt: UserDefinedType[_] =>
- (item) => extractOption(item) match {
- case null => null
- case other => udt.serialize(other)
- }
+ private[this] val converters = structType.fields.map { f => getConverterForType(f.dataType) }
- case arrayType: ArrayType =>
- val elementConverter = createToCatalystConverter(arrayType.elementType)
- (item: Any) => {
- extractOption(item) match {
- case a: Array[_] => a.toSeq.map(elementConverter)
- case s: Seq[_] => s.map(elementConverter)
- case i: JavaIterable[_] => {
- val iter = i.iterator
- var convertedIterable: List[Any] = List()
- while (iter.hasNext) {
- val item = iter.next()
- convertedIterable :+= elementConverter(item)
- }
- convertedIterable
- }
- case null => null
- }
+ override def toCatalystImpl(scalaValue: Any): Row = scalaValue match {
+ case row: Row =>
+ val ar = new Array[Any](row.size)
+ var idx = 0
+ while (idx < row.size) {
+ ar(idx) = converters(idx).toCatalyst(row(idx))
+ idx += 1
}
-
- case mapType: MapType =>
- val keyConverter = createToCatalystConverter(mapType.keyType)
- val valueConverter = createToCatalystConverter(mapType.valueType)
- (item: Any) => {
- extractOption(item) match {
- case m: Map[_, _] =>
- m.map { case (k, v) =>
- keyConverter(k) -> valueConverter(v)
- }
-
- case jmap: JavaMap[_, _] =>
- val iter = jmap.entrySet.iterator
- val convertedMap: HashMap[Any, Any] = HashMap()
- while (iter.hasNext) {
- val entry = iter.next()
- convertedMap(keyConverter(entry.getKey)) = valueConverter(entry.getValue)
- }
- convertedMap
-
- case null => null
- }
+ new GenericRowWithSchema(ar, structType)
+
+ case p: Product =>
+ val ar = new Array[Any](structType.size)
+ val iter = p.productIterator
+ var idx = 0
+ while (idx < structType.size) {
+ ar(idx) = converters(idx).toCatalyst(iter.next())
+ idx += 1
}
+ new GenericRowWithSchema(ar, structType)
+ }
- case structType: StructType =>
- val converters = structType.fields.map(f => createToCatalystConverter(f.dataType))
- (item: Any) => {
- extractOption(item) match {
- case r: Row =>
- convertRowWithConverters(r, structType, converters)
-
- case p: Product =>
- val ar = new Array[Any](structType.size)
- val iter = p.productIterator
- var idx = 0
- while (idx < structType.size) {
- ar(idx) = converters(idx)(iter.next())
- idx += 1
- }
- new GenericRowWithSchema(ar, structType)
-
- case null =>
- null
- }
+ override def toScala(row: Row): Row = {
+ if (row == null) {
+ null
+ } else {
+ val ar = new Array[Any](row.size)
+ var idx = 0
+ while (idx < row.size) {
+ ar(idx) = converters(idx).toScala(row, idx)
+ idx += 1
}
-
- case dateType: DateType => (item: Any) => extractOption(item) match {
- case d: java.sql.Date => DateUtils.fromJavaDate(d)
- case other => other
+ new GenericRowWithSchema(ar, structType)
}
+ }
- case dataType: StringType => (item: Any) => extractOption(item) match {
- case s: String => UTF8String(s)
- case other => other
- }
+ override def toScalaImpl(row: Row, column: Int): Row = toScala(row(column).asInstanceOf[Row])
+ }
+
+ private object StringConverter extends CatalystTypeConverter[Any, String, Any] {
+ override def toCatalystImpl(scalaValue: Any): UTF8String = scalaValue match {
+ case str: String => UTF8String(str)
+ case utf8: UTF8String => utf8
+ }
+ override def toScala(catalystValue: Any): String = catalystValue match {
+ case null => null
+ case str: String => str
+ case utf8: UTF8String => utf8.toString()
+ }
+ override def toScalaImpl(row: Row, column: Int): String = row(column).toString
+ }
+
+ private object DateConverter extends CatalystTypeConverter[Date, Date, Any] {
+ override def toCatalystImpl(scalaValue: Date): Int = DateUtils.fromJavaDate(scalaValue)
+ override def toScala(catalystValue: Any): Date =
+ if (catalystValue == null) null else DateUtils.toJavaDate(catalystValue.asInstanceOf[Int])
+ override def toScalaImpl(row: Row, column: Int): Date = toScala(row.getInt(column))
+ }
+
+ private object BigDecimalConverter extends CatalystTypeConverter[Any, JavaBigDecimal, Decimal] {
+ override def toCatalystImpl(scalaValue: Any): Decimal = scalaValue match {
+ case d: BigDecimal => Decimal(d)
+ case d: JavaBigDecimal => Decimal(d)
+ case d: Decimal => d
+ }
+ override def toScala(catalystValue: Decimal): JavaBigDecimal = catalystValue.toJavaBigDecimal
+ override def toScalaImpl(row: Row, column: Int): JavaBigDecimal = row.get(column) match {
+ case d: JavaBigDecimal => d
+ case d: Decimal => d.toJavaBigDecimal
+ }
+ }
+
+ private abstract class PrimitiveConverter[T] extends CatalystTypeConverter[T, Any, Any] {
+ final override def toScala(catalystValue: Any): Any = catalystValue
+ final override def toCatalystImpl(scalaValue: T): Any = scalaValue
+ }
+
+ private object BooleanConverter extends PrimitiveConverter[Boolean] {
+ override def toScalaImpl(row: Row, column: Int): Boolean = row.getBoolean(column)
+ }
+
+ private object ByteConverter extends PrimitiveConverter[Byte] {
+ override def toScalaImpl(row: Row, column: Int): Byte = row.getByte(column)
+ }
+
+ private object ShortConverter extends PrimitiveConverter[Short] {
+ override def toScalaImpl(row: Row, column: Int): Short = row.getShort(column)
+ }
+
+ private object IntConverter extends PrimitiveConverter[Int] {
+ override def toScalaImpl(row: Row, column: Int): Int = row.getInt(column)
+ }
+
+ private object LongConverter extends PrimitiveConverter[Long] {
+ override def toScalaImpl(row: Row, column: Int): Long = row.getLong(column)
+ }
+
+ private object FloatConverter extends PrimitiveConverter[Float] {
+ override def toScalaImpl(row: Row, column: Int): Float = row.getFloat(column)
+ }
- case _ =>
- (item: Any) => extractOption(item) match {
- case d: BigDecimal => Decimal(d)
- case d: java.math.BigDecimal => Decimal(d)
- case other => other
+ private object DoubleConverter extends PrimitiveConverter[Double] {
+ override def toScalaImpl(row: Row, column: Int): Double = row.getDouble(column)
+ }
+
+ /**
+ * Converts Scala objects to catalyst rows / types. This method is slow, and for batch
+ * conversion you should be using converter produced by createToCatalystConverter.
+ * Note: This is always called after schemaFor has been called.
+ * This ordering is important for UDT registration.
+ */
+ def convertToCatalyst(scalaValue: Any, dataType: DataType): Any = {
+ getConverterForType(dataType).toCatalyst(scalaValue)
+ }
+
+ /**
+ * Creates a converter function that will convert Scala objects to the specified Catalyst type.
+ * Typical use case would be converting a collection of rows that have the same schema. You will
+ * call this function once to get a converter, and apply it to every row.
+ */
+ private[sql] def createToCatalystConverter(dataType: DataType): Any => Any = {
+ if (isPrimitive(dataType)) {
+ // Although the `else` branch here is capable of handling inbound conversion of primitives,
+ // we add some special-case handling for those types here. The motivation for this relates to
+ // Java method invocation costs: if we have rows that consist entirely of primitive columns,
+ // then returning the same conversion function for all of the columns means that the call site
+ // will be monomorphic instead of polymorphic. In microbenchmarks, this actually resulted in
+ // a measurable performance impact. Note that this optimization will be unnecessary if we
+ // use code generation to construct Scala Row -> Catalyst Row converters.
+ def convert(maybeScalaValue: Any): Any = {
+ if (maybeScalaValue.isInstanceOf[Option[Any]]) {
+ maybeScalaValue.asInstanceOf[Option[Any]].orNull
+ } else {
+ maybeScalaValue
}
+ }
+ convert
+ } else {
+ getConverterForType(dataType).toCatalyst
}
}
/**
- * Converts Scala objects to catalyst rows / types.
+ * Converts Scala objects to Catalyst rows / types.
*
* Note: This should be called before do evaluation on Row
* (It does not support UDT)
* This is used to create an RDD or test results with correct types for Catalyst.
*/
def convertToCatalyst(a: Any): Any = a match {
- case s: String => UTF8String(s)
- case d: java.sql.Date => DateUtils.fromJavaDate(d)
- case d: BigDecimal => Decimal(d)
- case d: java.math.BigDecimal => Decimal(d)
+ case s: String => StringConverter.toCatalyst(s)
+ case d: Date => DateConverter.toCatalyst(d)
+ case d: BigDecimal => BigDecimalConverter.toCatalyst(d)
+ case d: JavaBigDecimal => BigDecimalConverter.toCatalyst(d)
case seq: Seq[Any] => seq.map(convertToCatalyst)
case r: Row => Row(r.toSeq.map(convertToCatalyst): _*)
case arr: Array[Any] => arr.toSeq.map(convertToCatalyst).toArray
@@ -238,33 +382,8 @@ object CatalystTypeConverters {
* This method is slow, and for batch conversion you should be using converter
* produced by createToScalaConverter.
*/
- def convertToScala(a: Any, dataType: DataType): Any = (a, dataType) match {
- // Check UDT first since UDTs can override other types
- case (d, udt: UserDefinedType[_]) =>
- udt.deserialize(d)
-
- case (s: Seq[_], arrayType: ArrayType) =>
- s.map(convertToScala(_, arrayType.elementType))
-
- case (m: Map[_, _], mapType: MapType) =>
- m.map { case (k, v) =>
- convertToScala(k, mapType.keyType) -> convertToScala(v, mapType.valueType)
- }
-
- case (r: Row, s: StructType) =>
- convertRowToScala(r, s)
-
- case (d: Decimal, _: DecimalType) =>
- d.toJavaBigDecimal
-
- case (i: Int, DateType) =>
- DateUtils.toJavaDate(i)
-
- case (s: UTF8String, StringType) =>
- s.toString()
-
- case (other, _) =>
- other
+ def convertToScala(catalystValue: Any, dataType: DataType): Any = {
+ getConverterForType(dataType).toScala(catalystValue)
}
/**
@@ -272,82 +391,7 @@ object CatalystTypeConverters {
* Typical use case would be converting a collection of rows that have the same schema. You will
* call this function once to get a converter, and apply it to every row.
*/
- private[sql] def createToScalaConverter(dataType: DataType): Any => Any = dataType match {
- // Check UDT first since UDTs can override other types
- case udt: UserDefinedType[_] =>
- (item: Any) => if (item == null) null else udt.deserialize(item)
-
- case arrayType: ArrayType =>
- val elementConverter = createToScalaConverter(arrayType.elementType)
- (item: Any) => if (item == null) null else item.asInstanceOf[Seq[_]].map(elementConverter)
-
- case mapType: MapType =>
- val keyConverter = createToScalaConverter(mapType.keyType)
- val valueConverter = createToScalaConverter(mapType.valueType)
- (item: Any) => if (item == null) {
- null
- } else {
- item.asInstanceOf[Map[_, _]].map { case (k, v) =>
- keyConverter(k) -> valueConverter(v)
- }
- }
-
- case s: StructType =>
- val converters = s.fields.map(f => createToScalaConverter(f.dataType))
- (item: Any) => {
- if (item == null) {
- null
- } else {
- convertRowWithConverters(item.asInstanceOf[Row], s, converters)
- }
- }
-
- case _: DecimalType =>
- (item: Any) => item match {
- case d: Decimal => d.toJavaBigDecimal
- case other => other
- }
-
- case DateType =>
- (item: Any) => item match {
- case i: Int => DateUtils.toJavaDate(i)
- case other => other
- }
-
- case StringType =>
- (item: Any) => item match {
- case s: UTF8String => s.toString()
- case other => other
- }
-
- case other =>
- (item: Any) => item
- }
-
- def convertRowToScala(r: Row, schema: StructType): Row = {
- val ar = new Array[Any](r.size)
- var idx = 0
- while (idx < r.size) {
- ar(idx) = convertToScala(r(idx), schema.fields(idx).dataType)
- idx += 1
- }
- new GenericRowWithSchema(ar, schema)
- }
-
- /**
- * Converts a row by applying the provided set of converter functions. It is used for both
- * toScala and toCatalyst conversions.
- */
- private[sql] def convertRowWithConverters(
- row: Row,
- schema: StructType,
- converters: Array[Any => Any]): Row = {
- val ar = new Array[Any](row.size)
- var idx = 0
- while (idx < row.size) {
- ar(idx) = converters(idx)(row(idx))
- idx += 1
- }
- new GenericRowWithSchema(ar, schema)
+ private[sql] def createToScalaConverter(dataType: DataType): Any => Any = {
+ getConverterForType(dataType).toScala
}
}
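
As a rough illustration of the new design, each DataType now maps to a single reusable converter rather than re-matching on (value, dataType) for every call; a minimal sketch, assuming it is compiled somewhere under org.apache.spark.sql since the factory methods are private[sql]:

```scala
import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.types._

// Build the converters once, then reuse them for every value.
val dataType = ArrayType(DateType)
val toCatalyst = CatalystTypeConverters.createToCatalystConverter(dataType)
val toScala = CatalystTypeConverters.createToScalaConverter(dataType)

val input = Seq(java.sql.Date.valueOf("2015-06-04"), null)
val internal = toCatalyst(input)   // dates become Ints (days since epoch); nulls pass through
println(toScala(internal))         // the original java.sql.Date (and the null) come back
```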
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
index 634138010fd21..b6191eafba71b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
@@ -71,12 +71,23 @@ case class UserDefinedGenerator(
children: Seq[Expression])
extends Generator {
+ @transient private[this] var inputRow: InterpretedProjection = _
+ @transient private[this] var convertToScala: (Row) => Row = _
+
+ private def initializeConverters(): Unit = {
+ inputRow = new InterpretedProjection(children)
+ convertToScala = {
+ val inputSchema = StructType(children.map(e => StructField(e.simpleString, e.dataType, true)))
+ CatalystTypeConverters.createToScalaConverter(inputSchema)
+ }.asInstanceOf[(Row => Row)]
+ }
+
override def eval(input: Row): TraversableOnce[Row] = {
- // TODO(davies): improve this
+ if (inputRow == null) {
+ initializeConverters()
+ }
// Convert the objects into Scala Type before calling function, we need schema to support UDT
- val inputSchema = StructType(children.map(e => StructField(e.simpleString, e.dataType, true)))
- val inputRow = new InterpretedProjection(children)
- function(CatalystTypeConverters.convertToScala(inputRow(input), inputSchema).asInstanceOf[Row])
+ function(convertToScala(inputRow(input)))
}
override def toString: String = s"UserDefinedGenerator(${children.mkString(",")})"
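
UserDefinedGenerator backs DataFrame.explode with a user-supplied function, so the lazily-initialized projection and converter above are exercised by something like the following spark-shell sketch (names are illustrative):

```scala
// Each input line is split into words; the Scala converter created in
// initializeConverters() turns the projected Catalyst row back into Scala
// values before the user function is invoked.
case class Doc(line: String)
val df = sqlContext.createDataFrame(Seq(Doc("a b"), Doc("c d e")))
df.explode("line", "word") { line: String => line.split(" ").toSeq }.show()
```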
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala
new file mode 100644
index 0000000000000..df0f04563edcf
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.types._
+
+class CatalystTypeConvertersSuite extends SparkFunSuite {
+
+ private val simpleTypes: Seq[DataType] = Seq(
+ StringType,
+ DateType,
+ BooleanType,
+ ByteType,
+ ShortType,
+ IntegerType,
+ LongType,
+ FloatType,
+ DoubleType)
+
+ test("null handling in rows") {
+ val schema = StructType(simpleTypes.map(t => StructField(t.getClass.getName, t)))
+ val convertToCatalyst = CatalystTypeConverters.createToCatalystConverter(schema)
+ val convertToScala = CatalystTypeConverters.createToScalaConverter(schema)
+
+ val scalaRow = Row.fromSeq(Seq.fill(simpleTypes.length)(null))
+ assert(convertToScala(convertToCatalyst(scalaRow)) === scalaRow)
+ }
+
+ test("null handling for individual values") {
+ for (dataType <- simpleTypes) {
+ assert(CatalystTypeConverters.createToScalaConverter(dataType)(null) === null)
+ }
+ }
+
+ test("option handling in convertToCatalyst") {
+ // convertToCatalyst doesn't handle unboxing from Options. This is inconsistent with
+ // createToCatalystConverter but it may not actually matter as this is only called internally
+ // in a handful of places where we don't expect to receive Options.
+ assert(CatalystTypeConverters.convertToCatalyst(Some(123)) === Some(123))
+ }
+
+ test("option handling in createToCatalystConverter") {
+ assert(CatalystTypeConverters.createToCatalystConverter(IntegerType)(Some(123)) === 123)
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala
index e439a18ac43aa..824ae36968c32 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala
@@ -190,7 +190,7 @@ private[sql] class ParquetRelation2(
}
}
- override def dataSchema: StructType = metadataCache.dataSchema
+ override def dataSchema: StructType = maybeDataSchema.getOrElse(metadataCache.dataSchema)
override private[sql] def refresh(): Unit = {
super.refresh()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala
index 3132067d562f6..71f016b1f14de 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala
@@ -30,9 +30,10 @@ import org.apache.spark._
import org.apache.spark.mapred.SparkHadoopMapRedUtil
import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil
import org.apache.spark.sql.catalyst.CatalystTypeConverters
+import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateProjection
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.plans.logical.{Project, LogicalPlan}
import org.apache.spark.sql.execution.RunnableCommand
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.{DataFrame, SQLConf, SQLContext, SaveMode}
@@ -94,10 +95,19 @@ private[sql] case class InsertIntoHadoopFsRelation(
// We create a DataFrame by applying the schema of the relation to the data to make sure
// we are writing data based on the expected schema.
- val df = sqlContext.createDataFrame(
- DataFrame(sqlContext, query).queryExecution.toRdd,
- relation.schema,
- needsConversion = false)
+ val df = {
+ // For a partitioned relation, relation.schema's column ordering can differ from the column
+ // ordering of query (partition columns are all moved to the end). Add a Project to adjust
+ // the ordering, so that relation.schema can be safely applied to the data.
+ val project = Project(
+ relation.schema.map(field => new UnresolvedAttribute(Seq(field.name))), query)
+
+ sqlContext.createDataFrame(
+ DataFrame(sqlContext, project).queryExecution.toRdd,
+ relation.schema,
+ needsConversion = false)
+ }
val partitionColumns = relation.partitionColumns.fieldNames
if (partitionColumns.isEmpty) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
index 22587f5a1c6f1..20afd60cb7767 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
@@ -25,7 +25,7 @@ import org.apache.hadoop.fs.Path
import org.apache.spark.Logging
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.sql.catalyst.AbstractSparkSQLParser
-import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation}
+import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Row}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution.RunnableCommand
@@ -322,19 +322,13 @@ private[sql] object ResolvedDataSource {
Some(partitionColumnsSchema(data.schema, partitionColumns)),
caseInsensitiveOptions)
- // For partitioned relation r, r.schema's column ordering is different with the column
- // ordering of data.logicalPlan. We need a Project to adjust the ordering.
- // So, inside InsertIntoHadoopFsRelation, we can safely apply the schema of r.schema to
- // the data.
- val project =
- Project(
- r.schema.map(field => new UnresolvedAttribute(Seq(field.name))),
- data.logicalPlan)
-
+ // For a partitioned relation r, r.schema's column ordering can differ from the column
+ // ordering of data.logicalPlan (partition columns are all moved to the end). The ordering
+ // is adjusted inside InsertIntoHadoopFsRelation.
sqlContext.executePlan(
InsertIntoHadoopFsRelation(
r,
- project,
+ data.logicalPlan,
mode)).toRdd
r
case _ =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
index c4ffa8de52640..f5bd2d2941ca0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
@@ -503,7 +503,7 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio
*/
override lazy val schema: StructType = {
val dataSchemaColumnNames = dataSchema.map(_.name.toLowerCase).toSet
- StructType(dataSchema ++ partitionSpec.partitionColumns.filterNot { column =>
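+ // Use partitionColumns here so that user-defined partition columns, when provided, take
+ // precedence over the discovered partition spec.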
+ StructType(dataSchema ++ partitionColumns.filterNot { column =>
dataSchemaColumnNames.contains(column.name.toLowerCase)
})
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala
index 56591d9dba29e..055453e688e73 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala
@@ -173,7 +173,7 @@ class InMemoryColumnarQuerySuite extends QueryTest {
new Timestamp(i),
(1 to i).toSeq,
(0 to i).map(j => s"map_key_$j" -> (Long.MaxValue - j)).toMap,
- Row((i - 0.25).toFloat, (1 to i).toSeq))
+ Row((i - 0.25).toFloat, Seq(true, false, null)))
}
createDataFrame(rdd, schema).registerTempTable("InMemoryCache_different_data_types")
// Cache the table.
diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala
index da511ebd05ad2..a93a3dee43511 100644
--- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala
+++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala
@@ -19,8 +19,6 @@ package org.apache.spark.sql.hive.thriftserver
import java.io.File
import java.net.URL
-import java.nio.charset.StandardCharsets
-import java.nio.file.{Files, Paths}
import java.sql.{Date, DriverManager, Statement}
import scala.collection.mutable.ArrayBuffer
@@ -29,6 +27,8 @@ import scala.concurrent.{Await, Promise}
import scala.sys.process.{Process, ProcessLogger}
import scala.util.{Random, Try}
+import com.google.common.base.Charsets.UTF_8
+import com.google.common.io.Files
import org.apache.hadoop.hive.conf.HiveConf.ConfVars
import org.apache.hive.jdbc.HiveDriver
import org.apache.hive.service.auth.PlainSaslHelper
@@ -441,13 +441,14 @@ abstract class HiveThriftServer2Test extends SparkFunSuite with BeforeAndAfterAl
val tempLog4jConf = Utils.createTempDir().getCanonicalPath
Files.write(
- Paths.get(s"$tempLog4jConf/log4j.properties"),
"""log4j.rootCategory=INFO, console
|log4j.appender.console=org.apache.log4j.ConsoleAppender
|log4j.appender.console.target=System.err
|log4j.appender.console.layout=org.apache.log4j.PatternLayout
|log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n
- """.stripMargin.getBytes(StandardCharsets.UTF_8))
+ """.stripMargin,
+ new File(s"$tempLog4jConf/log4j.properties"),
+ UTF_8)
tempLog4jConf + File.pathSeparator + sys.props("java.class.path")
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala
index af36fa6f1faae..74095426741e3 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala
@@ -17,6 +17,9 @@
package org.apache.spark.sql.sources
+import java.io.File
+
+import com.google.common.io.Files
import org.apache.hadoop.fs.Path
import org.apache.spark.{SparkException, SparkFunSuite}
@@ -453,6 +456,20 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils {
}
}
}
+
+ test("SPARK-7616: adjust column name order accordingly when saving partitioned table") {
+ val df = (1 to 3).map(i => (i, s"val_$i", i * 2)).toDF("a", "b", "c")
+
+ df.write
+ .format(dataSourceName)
+ .mode(SaveMode.Overwrite)
+ .partitionBy("c", "a")
+ .saveAsTable("t")
+
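+ // Partition columns "c" and "a" are moved to the end of the schema, so the table reads
+ // back with column order (b, c, a).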
+ withTable("t") {
+ checkAnswer(table("t"), df.select('b, 'c, 'a).collect())
+ }
+ }
}
class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest {
@@ -534,20 +551,6 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest {
}
}
- test("SPARK-7616: adjust column name order accordingly when saving partitioned table") {
- val df = (1 to 3).map(i => (i, s"val_$i", i * 2)).toDF("a", "b", "c")
-
- df.write
- .format("parquet")
- .mode(SaveMode.Overwrite)
- .partitionBy("c", "a")
- .saveAsTable("t")
-
- withTable("t") {
- checkAnswer(table("t"), df.select('b, 'c, 'a).collect())
- }
- }
-
test("SPARK-7868: _temporary directories should be ignored") {
withTempPath { dir =>
val df = Seq("a", "b", "c").zipWithIndex.toDF()
@@ -563,4 +566,32 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest {
checkAnswer(read.format("parquet").load(dir.getCanonicalPath), df.collect())
}
}
+
+ test("SPARK-8014: Avoid scanning output directory when SaveMode isn't SaveMode.Append") {
+ withTempDir { dir =>
+ val path = dir.getCanonicalPath
+ val df = Seq(1 -> "a").toDF()
+
+ // Creates an arbitrary file. If this directory gets scanned, ParquetRelation2 will throw,
+ // since the file is not a valid Parquet file.
+ val emptyFile = new File(path, "empty")
+ Files.createParentDirs(emptyFile)
+ Files.touch(emptyFile)
+
+ // This shouldn't throw anything.
+ df.write.format("parquet").mode(SaveMode.Ignore).save(path)
+
+ // This should only complain that the destination directory already exists, rather than that
+ // file "empty" is not a valid Parquet file.
+ assert {
+ intercept[RuntimeException] {
+ df.write.format("parquet").mode(SaveMode.ErrorIfExists).save(path)
+ }.getMessage.contains("already exists")
+ }
+
+ // This shouldn't throw anything.
+ df.write.format("parquet").mode(SaveMode.Overwrite).save(path)
+ checkAnswer(read.format("parquet").load(path), df)
+ }
+ }
}