The following example demonstrates how to load training data and parse it as an RDD of LabeledPoint.
-The example then uses LinearRegressionWithSGD to build a simple linear model to predict label
+The example then uses LinearRegressionWithSGD to build a simple linear model to predict label
values. We compute the mean squared error at the end to evaluate
[goodness of fit](http://en.wikipedia.org/wiki/Goodness_of_fit).
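
A minimal sketch of such an example, assuming a `SparkContext` named `sc` and the `lpsa.data` file that ships with Spark under `data/mllib/ridge-data/`:

{% highlight scala %}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.{LabeledPoint, LinearRegressionWithSGD}

// Load the data and parse each line into a LabeledPoint.
val data = sc.textFile("data/mllib/ridge-data/lpsa.data")
val parsedData = data.map { line =>
  val parts = line.split(',')
  LabeledPoint(parts(0).toDouble, Vectors.dense(parts(1).split(' ').map(_.toDouble)))
}.cache()

// Build a simple linear model with 100 iterations of SGD.
val model = LinearRegressionWithSGD.train(parsedData, 100)

// Evaluate goodness of fit with the mean squared error on the training set.
val valuesAndPreds = parsedData.map { point =>
  (point.label, model.predict(point.features))
}
val MSE = valuesAndPreds.map { case (v, p) => math.pow(v - p, 2) }.mean()
println("training Mean Squared Error = " + MSE)
{% endhighlight %}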
@@ -706,8 +709,8 @@ a dependency.
### Streaming linear regression
-When data arrive in a streaming fashion, it is useful to fit regression models online,
-updating the parameters of the model as new data arrives. MLlib currently supports
+When data arrive in a streaming fashion, it is useful to fit regression models online,
+updating the parameters of the model as new data arrives. MLlib currently supports
streaming linear regression using ordinary least squares. The fitting is similar
to that performed offline, except fitting occurs on each batch of data, so that
the model continually updates to reflect the data from the stream.
@@ -722,7 +725,7 @@ online to the first stream, and make predictions on the second stream.
-First, we import the necessary classes for parsing our input data and creating the model.
+First, we import the necessary classes for parsing our input data and creating the model.
{% highlight scala %}
@@ -734,7 +737,7 @@ import org.apache.spark.mllib.regression.StreamingLinearRegressionWithSGD
Then we make input streams for training and testing data. We assume a StreamingContext `ssc`
has already been created; see the [Spark Streaming Programming Guide](streaming-programming-guide.html#initializing)
-for more info. For this example, we use labeled points in training and testing streams,
+for more info. For this example, we use labeled points in training and testing streams,
but in practice you will likely want to use unlabeled vectors for test data.
{% highlight scala %}
@@ -754,7 +757,7 @@ val model = new StreamingLinearRegressionWithSGD()
{% endhighlight %}
-Now we register the streams for training and testing and start the job.
+Now we register the streams for training and testing and start the job.
Printing predictions alongside true labels lets us easily see the result.
{% highlight scala %}
@@ -764,14 +767,14 @@ model.predictOnValues(testData.map(lp => (lp.label, lp.features))).print()
ssc.start()
ssc.awaitTermination()
-
+
{% endhighlight %}
We can now save text files with data to the training or testing folders.
-Each line should be a data point formatted as `(y,[x1,x2,x3])` where `y` is the label
-and `x1,x2,x3` are the features. Anytime a text file is placed in `/training/data/dir`
-the model will update. Anytime a text file is placed in `/testing/data/dir` you will see predictions.
-As you feed more data to the training directory, the predictions
+Each line should be a data point formatted as `(y,[x1,x2,x3])` where `y` is the label
+and `x1,x2,x3` are the features. Anytime a text file is placed in `/training/data/dir`
+the model will update. Anytime a text file is placed in `/testing/data/dir` you will see predictions.
+As you feed more data to the training directory, the predictions
will get better!
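
For example, a file dropped into the training directory might contain lines such as `(1.0,[2.0,0.5,3.0])`. As a small sketch (assuming the labeled-point text format described above), each such line can be parsed with `LabeledPoint.parse`:

{% highlight scala %}
import org.apache.spark.mllib.regression.LabeledPoint

// A single line in a training (or testing) file: the label, then the feature vector.
val line = "(1.0,[2.0,0.5,3.0])"
val point = LabeledPoint.parse(line)
// point.label == 1.0, point.features == Vectors.dense(2.0, 0.5, 3.0)
{% endhighlight %}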
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
new file mode 100644
index 0000000000000..d9a36bda386b3
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
@@ -0,0 +1,151 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import scala.util.parsing.combinator.RegexParsers
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.ml.Transformer
+import org.apache.spark.ml.param.{Param, ParamMap}
+import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol}
+import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types._
+
+/**
+ * :: Experimental ::
+ * Implements the transforms required for fitting a dataset against an R model formula. Currently
+ * we support a limited subset of the R operators, including '~' and '+'. Also see the R formula
+ * docs here: http://stat.ethz.ch/R-manual/R-patched/library/stats/html/formula.html
+ */
+@Experimental
+class RFormula(override val uid: String)
+ extends Transformer with HasFeaturesCol with HasLabelCol {
+
+ def this() = this(Identifiable.randomUID("rFormula"))
+
+ /**
+ * R formula parameter. The formula is provided in string form.
+ * @group param
+ */
+ val formula: Param[String] = new Param(this, "formula", "R model formula")
+
+ private var parsedFormula: Option[ParsedRFormula] = None
+
+ /**
+ * Sets the formula to use for this transformer. Must be called before use.
+ * @group setParam
+ * @param value an R formula in string form (e.g. "y ~ x + z")
+ */
+ def setFormula(value: String): this.type = {
+ parsedFormula = Some(RFormulaParser.parse(value))
+ set(formula, value)
+ this
+ }
+
+ /** @group getParam */
+ def getFormula: String = $(formula)
+
+ /** @group setParam */
+ def setFeaturesCol(value: String): this.type = set(featuresCol, value)
+
+ /** @group setParam */
+ def setLabelCol(value: String): this.type = set(labelCol, value)
+
+ override def transformSchema(schema: StructType): StructType = {
+ checkCanTransform(schema)
+ val withFeatures = transformFeatures.transformSchema(schema)
+ if (hasLabelCol(schema)) {
+ withFeatures
+ } else {
+ val nullable = schema(parsedFormula.get.label).dataType match {
+ case _: NumericType | BooleanType => false
+ case _ => true
+ }
+ StructType(withFeatures.fields :+ StructField($(labelCol), DoubleType, nullable))
+ }
+ }
+
+ override def transform(dataset: DataFrame): DataFrame = {
+ checkCanTransform(dataset.schema)
+ transformLabel(transformFeatures.transform(dataset))
+ }
+
+ override def copy(extra: ParamMap): RFormula = defaultCopy(extra)
+
+ override def toString: String = s"RFormula(${get(formula)})"
+
+ private def transformLabel(dataset: DataFrame): DataFrame = {
+ if (hasLabelCol(dataset.schema)) {
+ dataset
+ } else {
+ val labelName = parsedFormula.get.label
+ dataset.schema(labelName).dataType match {
+ case _: NumericType | BooleanType =>
+ dataset.withColumn($(labelCol), dataset(labelName).cast(DoubleType))
+ // TODO(ekl) add support for string-type labels
+ case other =>
+ throw new IllegalArgumentException("Unsupported type for label: " + other)
+ }
+ }
+ }
+
+ private def transformFeatures: Transformer = {
+ // TODO(ekl) add support for non-numeric features and feature interactions
+ new VectorAssembler(uid)
+ .setInputCols(parsedFormula.get.terms.toArray)
+ .setOutputCol($(featuresCol))
+ }
+
+ private def checkCanTransform(schema: StructType) {
+ require(parsedFormula.isDefined, "Must call setFormula() first.")
+ val columnNames = schema.map(_.name)
+ require(!columnNames.contains($(featuresCol)), "Features column already exists.")
+ require(
+ !columnNames.contains($(labelCol)) || schema($(labelCol)).dataType == DoubleType,
+ "Label column already exists and is not of type DoubleType.")
+ }
+
+ private def hasLabelCol(schema: StructType): Boolean = {
+ schema.map(_.name).contains($(labelCol))
+ }
+}
+
+/**
+ * Represents a parsed R formula.
+ */
+private[ml] case class ParsedRFormula(label: String, terms: Seq[String])
+
+/**
+ * Limited implementation of R formula parsing. Currently supports: '~', '+'.
+ */
+private[ml] object RFormulaParser extends RegexParsers {
+ def term: Parser[String] = "([a-zA-Z]|\\.[a-zA-Z_])[a-zA-Z0-9._]*".r
+
+ def expr: Parser[List[String]] = term ~ rep("+" ~> term) ^^ { case a ~ list => a :: list }
+
+ def formula: Parser[ParsedRFormula] =
+ (term ~ "~" ~ expr) ^^ { case r ~ "~" ~ t => ParsedRFormula(r, t) }
+
+ def parse(value: String): ParsedRFormula = parseAll(formula, value) match {
+ case Success(result, _) => result
+ case failure: NoSuccess => throw new IllegalArgumentException(
+ "Could not parse formula: " + value)
+ }
+}
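
As a usage sketch for the `RFormula` transformer added above (assuming a `SQLContext` named `sqlContext` and a hypothetical DataFrame with numeric columns `y`, `a`, and `b`):

{% highlight scala %}
import org.apache.spark.ml.feature.RFormula

val df = sqlContext.createDataFrame(Seq(
  (1.0, 2.0, 3.0),
  (0.0, 5.0, 1.0)
)).toDF("y", "a", "b")

// "y ~ a + b": use y as the label and assemble a and b into a feature vector.
val formula = new RFormula().setFormula("y ~ a + b")
val transformed = formula.transform(df)
// transformed now carries a "features" vector column and a double "label" column.
transformed.show()
{% endhighlight %}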
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
index 9f83c2ee16178..086917fa680f8 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
@@ -116,7 +116,7 @@ class VectorAssembler(override val uid: String)
if (schema.fieldNames.contains(outputColName)) {
throw new IllegalArgumentException(s"Output column $outputColName already exists.")
}
- StructType(schema.fields :+ new StructField(outputColName, new VectorUDT, false))
+ StructType(schema.fields :+ new StructField(outputColName, new VectorUDT, true))
}
override def copy(extra: ParamMap): VectorAssembler = defaultCopy(extra)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/stopwatches.scala b/mllib/src/main/scala/org/apache/spark/ml/util/stopwatches.scala
new file mode 100644
index 0000000000000..5fdf878a3df72
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/stopwatches.scala
@@ -0,0 +1,151 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.util
+
+import scala.collection.mutable
+
+import org.apache.spark.{Accumulator, SparkContext}
+
+/**
+ * Abstract class for stopwatches.
+ */
+private[spark] abstract class Stopwatch extends Serializable {
+
+ @transient private var running: Boolean = false
+ private var startTime: Long = _
+
+ /**
+ * Name of the stopwatch.
+ */
+ val name: String
+
+ /**
+ * Starts the stopwatch.
+ * Throws an exception if the stopwatch is already running.
+ */
+ def start(): Unit = {
+ assume(!running, "start() called but the stopwatch is already running.")
+ running = true
+ startTime = now
+ }
+
+ /**
+ * Stops the stopwatch and returns the duration of the last session in milliseconds.
+ * Throws an exception if the stopwatch is not running.
+ */
+ def stop(): Long = {
+ assume(running, "stop() called but the stopwatch is not running.")
+ val duration = now - startTime
+ add(duration)
+ running = false
+ duration
+ }
+
+ /**
+ * Checks whether the stopwatch is running.
+ */
+ def isRunning: Boolean = running
+
+ /**
+ * Returns total elapsed time in milliseconds, not counting the current session if the stopwatch
+ * is running.
+ */
+ def elapsed(): Long
+
+ /**
+ * Gets the current time in milliseconds.
+ */
+ protected def now: Long = System.currentTimeMillis()
+
+ /**
+ * Adds input duration to total elapsed time.
+ */
+ protected def add(duration: Long): Unit
+}
+
+/**
+ * A local [[Stopwatch]].
+ */
+private[spark] class LocalStopwatch(override val name: String) extends Stopwatch {
+
+ private var elapsedTime: Long = 0L
+
+ override def elapsed(): Long = elapsedTime
+
+ override protected def add(duration: Long): Unit = {
+ elapsedTime += duration
+ }
+}
+
+/**
+ * A distributed [[Stopwatch]] backed by a Spark accumulator.
+ * @param sc SparkContext
+ */
+private[spark] class DistributedStopwatch(
+ sc: SparkContext,
+ override val name: String) extends Stopwatch {
+
+ private val elapsedTime: Accumulator[Long] = sc.accumulator(0L, s"DistributedStopwatch($name)")
+
+ override def elapsed(): Long = elapsedTime.value
+
+ override protected def add(duration: Long): Unit = {
+ elapsedTime += duration
+ }
+}
+
+/**
+ * A container for multiple stopwatches, which may be local or distributed.
+ * @param sc SparkContext
+ */
+private[spark] class MultiStopwatch(@transient private val sc: SparkContext) extends Serializable {
+
+ private val stopwatches: mutable.Map[String, Stopwatch] = mutable.Map.empty
+
+ /**
+ * Adds a local stopwatch.
+ * @param name stopwatch name
+ */
+ def addLocal(name: String): this.type = {
+ require(!stopwatches.contains(name), s"Stopwatch with name $name already exists.")
+ stopwatches(name) = new LocalStopwatch(name)
+ this
+ }
+
+ /**
+ * Adds a distributed stopwatch.
+ * @param name stopwatch name
+ */
+ def addDistributed(name: String): this.type = {
+ require(!stopwatches.contains(name), s"Stopwatch with name $name already exists.")
+ stopwatches(name) = new DistributedStopwatch(sc, name)
+ this
+ }
+
+ /**
+ * Gets a stopwatch.
+ * @param name stopwatch name
+ */
+ def apply(name: String): Stopwatch = stopwatches(name)
+
+ override def toString: String = {
+ stopwatches.values.toArray.sortBy(_.name)
+ .map(c => s" ${c.name}: ${c.elapsed()}ms")
+ .mkString("{\n", ",\n", "\n}")
+ }
+}
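
The stopwatch classes above are `private[spark]`, so the following is only an illustrative sketch of how they are meant to be exercised from MLlib internals (assuming a `SparkContext` named `sc`):

{% highlight scala %}
import org.apache.spark.ml.util.MultiStopwatch

val timers = new MultiStopwatch(sc)
  .addLocal("init")           // timed on the driver only
  .addDistributed("training") // aggregated across executors via an accumulator

timers("init").start()
// ... driver-side initialization work ...
timers("init").stop()

timers("training").start()
// ... work that may also start/stop timers("training") inside executor closures ...
timers("training").stop()

// Prints each stopwatch name with its total elapsed time in milliseconds.
println(timers.toString)
{% endhighlight %}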
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index e628059c4af8e..c58a64001d9a0 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -502,6 +502,39 @@ private[python] class PythonMLLibAPI extends Serializable {
new MatrixFactorizationModelWrapper(model)
}
+ /**
+ * Java stub for Python mllib LDA.run()
+ */
+ def trainLDAModel(
+ data: JavaRDD[java.util.List[Any]],
+ k: Int,
+ maxIterations: Int,
+ docConcentration: Double,
+ topicConcentration: Double,
+ seed: java.lang.Long,
+ checkpointInterval: Int,
+ optimizer: String): LDAModel = {
+ val algo = new LDA()
+ .setK(k)
+ .setMaxIterations(maxIterations)
+ .setDocConcentration(docConcentration)
+ .setTopicConcentration(topicConcentration)
+ .setCheckpointInterval(checkpointInterval)
+ .setOptimizer(optimizer)
+
+ if (seed != null) algo.setSeed(seed)
+
+ val documents = data.rdd.map(_.asScala.toArray).map { r =>
+ r(0) match {
+ case i: java.lang.Integer => (i.toLong, r(1).asInstanceOf[Vector])
+ case i: java.lang.Long => (i.toLong, r(1).asInstanceOf[Vector])
+ case _ => throw new IllegalArgumentException("input values contain an invalid type.")
+ }
+ }
+ algo.run(documents)
+ }
+
+
/**
* Java stub for Python mllib FPGrowth.train(). This stub returns a handle
* to the Java object instead of the content of the Java object. Extra care
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
index 0f8d6a399682d..68297130a7b03 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
@@ -156,6 +156,21 @@ class KMeans private (
this
}
+ // Initial cluster centers can be provided as a KMeansModel object rather than using the
+ // random or k-means|| initializationMode
+ private var initialModel: Option[KMeansModel] = None
+
+ /**
+ * Set the initial starting point, bypassing the random initialization or k-means||
+ * The condition model.k == this.k must be met, failure results
+ * in an IllegalArgumentException.
+ */
+ def setInitialModel(model: KMeansModel): this.type = {
+ require(model.k == k, "mismatched cluster count")
+ initialModel = Some(model)
+ this
+ }
+
/**
* Train a K-means model on the given set of points; `data` should be cached for high
* performance, because this is an iterative algorithm.
@@ -193,20 +208,34 @@ class KMeans private (
val initStartTime = System.nanoTime()
- val centers = if (initializationMode == KMeans.RANDOM) {
- initRandom(data)
+ // Only one run is allowed when initialModel is given
+ val numRuns = if (initialModel.nonEmpty) {
+ if (runs > 1) logWarning("Ignoring runs; one run is allowed when initialModel is given.")
+ 1
} else {
- initKMeansParallel(data)
+ runs
}
+ val centers = initialModel match {
+ case Some(kMeansCenters) => {
+ Array(kMeansCenters.clusterCenters.map(s => new VectorWithNorm(s)))
+ }
+ case None => {
+ if (initializationMode == KMeans.RANDOM) {
+ initRandom(data)
+ } else {
+ initKMeansParallel(data)
+ }
+ }
+ }
val initTimeInSeconds = (System.nanoTime() - initStartTime) / 1e9
logInfo(s"Initialization with $initializationMode took " + "%.3f".format(initTimeInSeconds) +
" seconds.")
- val active = Array.fill(runs)(true)
- val costs = Array.fill(runs)(0.0)
+ val active = Array.fill(numRuns)(true)
+ val costs = Array.fill(numRuns)(0.0)
- var activeRuns = new ArrayBuffer[Int] ++ (0 until runs)
+ var activeRuns = new ArrayBuffer[Int] ++ (0 until numRuns)
var iteration = 0
val iterationStartTime = System.nanoTime()
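
A brief sketch of how the new `setInitialModel` hook could be used, mirroring the test added below (assumes a `SparkContext` named `sc`):

{% highlight scala %}
import org.apache.spark.mllib.clustering.{KMeans, KMeansModel}
import org.apache.spark.mllib.linalg.Vectors

val points = sc.parallelize(Seq(
  Vectors.dense(0.0, 0.0), Vectors.dense(1.0, 1.0),
  Vectors.dense(9.0, 9.0), Vectors.dense(10.0, 10.0)))

// Seed the run with known centers instead of random or k-means|| initialization.
val initialModel = new KMeansModel(Array(Vectors.dense(0.5, 0.5), Vectors.dense(9.5, 9.5)))
val model = new KMeans()
  .setK(2)                        // must match initialModel.k
  .setInitialModel(initialModel)
  .run(points)
{% endhighlight %}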
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala
index e577bf87f885e..408847afa800d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala
@@ -53,14 +53,22 @@ class RegressionMetrics(predictionAndObservations: RDD[(Double, Double)]) extend
)
summary
}
+ private lazy val SSerr = math.pow(summary.normL2(1), 2)
+ private lazy val SStot = summary.variance(0) * (summary.count - 1)
+ private lazy val SSreg = {
+ val yMean = summary.mean(0)
+ predictionAndObservations.map {
+ case (prediction, _) => math.pow(prediction - yMean, 2)
+ }.sum()
+ }
/**
- * Returns the explained variance regression score.
- * explainedVariance = 1 - variance(y - \hat{y}) / variance(y)
- * Reference: [[http://en.wikipedia.org/wiki/Explained_variation]]
+ * Returns the variance explained by regression.
+ * explainedVariance = \sum_i (\hat{y_i} - \bar{y})^2 / n
+ * @see [[https://en.wikipedia.org/wiki/Fraction_of_variance_unexplained]]
*/
def explainedVariance: Double = {
- 1 - summary.variance(1) / summary.variance(0)
+ SSreg / summary.count
}
/**
@@ -76,8 +84,7 @@ class RegressionMetrics(predictionAndObservations: RDD[(Double, Double)]) extend
* expected value of the squared error loss or quadratic loss.
*/
def meanSquaredError: Double = {
- val rmse = summary.normL2(1) / math.sqrt(summary.count)
- rmse * rmse
+ SSerr / summary.count
}
/**
@@ -85,14 +92,14 @@ class RegressionMetrics(predictionAndObservations: RDD[(Double, Double)]) extend
* the mean squared error.
*/
def rootMeanSquaredError: Double = {
- summary.normL2(1) / math.sqrt(summary.count)
+ math.sqrt(this.meanSquaredError)
}
/**
- * Returns R^2^, the coefficient of determination.
- * Reference: [[http://en.wikipedia.org/wiki/Coefficient_of_determination]]
+ * Returns R^2^, the unadjusted coefficient of determination.
+ * @see [[http://en.wikipedia.org/wiki/Coefficient_of_determination]]
*/
def r2: Double = {
- 1 - math.pow(summary.normL2(1), 2) / (summary.variance(0) * (summary.count - 1))
+ 1 - SSerr / SStot
}
}
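
For reference, a small sketch of how these metrics are computed from user code (assuming a `SparkContext` named `sc`, and using the same (prediction, observation) pairs as the updated test suite below):

{% highlight scala %}
import org.apache.spark.mllib.evaluation.RegressionMetrics

val predictionAndObservations = sc.parallelize(
  Seq((2.5, 3.0), (0.0, -0.5), (2.0, 2.0), (8.0, 7.0)))

val metrics = new RegressionMetrics(predictionAndObservations)
println(s"MSE  = ${metrics.meanSquaredError}")                // SSerr / n
println(s"RMSE = ${metrics.rootMeanSquaredError}")            // sqrt(MSE)
println(s"R^2  = ${metrics.r2}")                              // 1 - SSerr / SStot
println(s"explained variance = ${metrics.explainedVariance}") // SSreg / n
{% endhighlight %}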
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala
index 39c48b084e550..7ead6327486cc 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala
@@ -17,58 +17,49 @@
package org.apache.spark.mllib.fpm
+import scala.collection.mutable
+
import org.apache.spark.Logging
-import org.apache.spark.annotation.Experimental
/**
- *
- * :: Experimental ::
- *
* Calculate all patterns of a projected database locally.
*/
-@Experimental
private[fpm] object LocalPrefixSpan extends Logging with Serializable {
/**
* Calculate all patterns of a projected database.
* @param minCount minimum count
* @param maxPatternLength maximum pattern length
- * @param prefix prefix
- * @param projectedDatabase the projected dabase
+ * @param prefixes prefixes in reversed order
+ * @param database the projected database
* @return a set of sequential pattern pairs,
- * the key of pair is sequential pattern (a list of items),
+ * the key of pair is sequential pattern (a list of items in reversed order),
* the value of pair is the pattern's count.
*/
def run(
minCount: Long,
maxPatternLength: Int,
- prefix: Array[Int],
- projectedDatabase: Array[Array[Int]]): Array[(Array[Int], Long)] = {
- val frequentPrefixAndCounts = getFreqItemAndCounts(minCount, projectedDatabase)
- val frequentPatternAndCounts = frequentPrefixAndCounts
- .map(x => (prefix ++ Array(x._1), x._2))
- val prefixProjectedDatabases = getPatternAndProjectedDatabase(
- prefix, frequentPrefixAndCounts.map(_._1), projectedDatabase)
-
- val continueProcess = prefixProjectedDatabases.nonEmpty && prefix.length + 1 < maxPatternLength
- if (continueProcess) {
- val nextPatterns = prefixProjectedDatabases
- .map(x => run(minCount, maxPatternLength, x._1, x._2))
- .reduce(_ ++ _)
- frequentPatternAndCounts ++ nextPatterns
- } else {
- frequentPatternAndCounts
+ prefixes: List[Int],
+ database: Array[Array[Int]]): Iterator[(List[Int], Long)] = {
+ if (prefixes.length == maxPatternLength || database.isEmpty) return Iterator.empty
+ val frequentItemAndCounts = getFreqItemAndCounts(minCount, database)
+ val filteredDatabase = database.map(x => x.filter(frequentItemAndCounts.contains))
+ frequentItemAndCounts.iterator.flatMap { case (item, count) =>
+ val newPrefixes = item :: prefixes
+ val newProjected = project(filteredDatabase, item)
+ Iterator.single((newPrefixes, count)) ++
+ run(minCount, maxPatternLength, newPrefixes, newProjected)
}
}
/**
- * calculate suffix sequence following a prefix in a sequence
- * @param prefix prefix
- * @param sequence sequence
+ * Calculate suffix sequence immediately after the first occurrence of an item.
+ * @param item item to get suffix after
+ * @param sequence sequence to extract suffix from
* @return suffix sequence
*/
- def getSuffix(prefix: Int, sequence: Array[Int]): Array[Int] = {
- val index = sequence.indexOf(prefix)
+ def getSuffix(item: Int, sequence: Array[Int]): Array[Int] = {
+ val index = sequence.indexOf(item)
if (index == -1) {
Array()
} else {
@@ -76,38 +67,28 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable {
}
}
+ def project(database: Array[Array[Int]], prefix: Int): Array[Array[Int]] = {
+ database
+ .map(getSuffix(prefix, _))
+ .filter(_.nonEmpty)
+ }
+
/**
* Generates frequent items by filtering the input data using minimal count level.
- * @param minCount the absolute minimum count
- * @param sequences sequences data
- * @return array of item and count pair
+ * @param minCount the minimum count for an item to be frequent
+ * @param database database of sequences
+ * @return freq item to count map
*/
private def getFreqItemAndCounts(
minCount: Long,
- sequences: Array[Array[Int]]): Array[(Int, Long)] = {
- sequences.flatMap(_.distinct)
- .groupBy(x => x)
- .mapValues(_.length.toLong)
- .filter(_._2 >= minCount)
- .toArray
- }
-
- /**
- * Get the frequent prefixes' projected database.
- * @param prePrefix the frequent prefixes' prefix
- * @param frequentPrefixes frequent prefixes
- * @param sequences sequences data
- * @return prefixes and projected database
- */
- private def getPatternAndProjectedDatabase(
- prePrefix: Array[Int],
- frequentPrefixes: Array[Int],
- sequences: Array[Array[Int]]): Array[(Array[Int], Array[Array[Int]])] = {
- val filteredProjectedDatabase = sequences
- .map(x => x.filter(frequentPrefixes.contains(_)))
- frequentPrefixes.map { x =>
- val sub = filteredProjectedDatabase.map(y => getSuffix(x, y)).filter(_.nonEmpty)
- (prePrefix ++ Array(x), sub)
- }.filter(x => x._2.nonEmpty)
+ database: Array[Array[Int]]): mutable.Map[Int, Long] = {
+ // TODO: use PrimitiveKeyOpenHashMap
+ val counts = mutable.Map[Int, Long]().withDefaultValue(0L)
+ database.foreach { sequence =>
+ sequence.distinct.foreach { item =>
+ counts(item) += 1L
+ }
+ }
+ counts.filter(_._2 >= minCount)
}
}
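
To make the suffix and projection semantics above concrete, here is a small illustrative sketch (LocalPrefixSpan is `private[fpm]`, so this only compiles inside the `org.apache.spark.mllib.fpm` package; the expected results follow from the doc comments above):

{% highlight scala %}
// (inside package org.apache.spark.mllib.fpm)

// Suffix immediately after the first occurrence of an item.
assert(LocalPrefixSpan.getSuffix(2, Array(1, 2, 3, 2, 4)).sameElements(Array(3, 2, 4)))
assert(LocalPrefixSpan.getSuffix(5, Array(1, 2, 3)).isEmpty) // item absent

// Projection on prefix item 2 keeps only the non-empty suffixes of each sequence.
val projected = LocalPrefixSpan.project(Array(Array(1, 2, 3), Array(2, 4), Array(1, 3)), 2)
assert(projected.map(_.toSeq).toSeq == Seq(Seq(3), Seq(4)))
{% endhighlight %}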
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
index aed7e30033b8a..139b2f6952fb8 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
@@ -203,8 +203,12 @@ class PrefixSpan private (
private def getPatternsInLocal(
minCount: Long,
data: RDD[(Array[Int], Array[Array[Int]])]): RDD[(ArrayBuffer[Int], Long)] = {
- data
- .flatMap { x => LocalPrefixSpan.run(minCount, maxPatternLength, x._1, x._2) }
- .map { case (pattern, count) => (pattern.to[ArrayBuffer], count) }
+ data.flatMap {
+ case (prefix, projDB) =>
+ LocalPrefixSpan.run(minCount, maxPatternLength, prefix.toList, projDB)
+ .map { case (pattern: List[Int], count: Long) =>
+ (pattern.toArray.reverse.to[ArrayBuffer], count)
+ }
+ }
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala
new file mode 100644
index 0000000000000..c8d065f37a605
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.apache.spark.SparkFunSuite
+
+class RFormulaParserSuite extends SparkFunSuite {
+ private def checkParse(formula: String, label: String, terms: Seq[String]) {
+ val parsed = RFormulaParser.parse(formula)
+ assert(parsed.label == label)
+ assert(parsed.terms == terms)
+ }
+
+ test("parse simple formulas") {
+ checkParse("y ~ x", "y", Seq("x"))
+ checkParse("y ~ ._foo ", "y", Seq("._foo"))
+ checkParse("resp ~ A_VAR + B + c123", "resp", Seq("A_VAR", "B", "c123"))
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
new file mode 100644
index 0000000000000..fa8611b243a9f
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+
+class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext {
+ test("params") {
+ ParamsSuite.checkParams(new RFormula())
+ }
+
+ test("transform numeric data") {
+ val formula = new RFormula().setFormula("id ~ v1 + v2")
+ val original = sqlContext.createDataFrame(
+ Seq((0, 1.0, 3.0), (2, 2.0, 5.0))).toDF("id", "v1", "v2")
+ val result = formula.transform(original)
+ val resultSchema = formula.transformSchema(original.schema)
+ val expected = sqlContext.createDataFrame(
+ Seq(
+ (0, 1.0, 3.0, Vectors.dense(Array(1.0, 3.0)), 0.0),
+ (2, 2.0, 5.0, Vectors.dense(Array(2.0, 5.0)), 2.0))
+ ).toDF("id", "v1", "v2", "features", "label")
+ // TODO(ekl) make schema comparisons ignore metadata, to avoid .toString
+ assert(result.schema.toString == resultSchema.toString)
+ assert(resultSchema == expected.schema)
+ assert(result.collect().toSeq == expected.collect().toSeq)
+ }
+
+ test("features column already exists") {
+ val formula = new RFormula().setFormula("y ~ x").setFeaturesCol("x")
+ val original = sqlContext.createDataFrame(Seq((0, 1.0), (2, 2.0))).toDF("x", "y")
+ intercept[IllegalArgumentException] {
+ formula.transformSchema(original.schema)
+ }
+ intercept[IllegalArgumentException] {
+ formula.transform(original)
+ }
+ }
+
+ test("label column already exists") {
+ val formula = new RFormula().setFormula("y ~ x").setLabelCol("y")
+ val original = sqlContext.createDataFrame(Seq((0, 1.0), (2, 2.0))).toDF("x", "y")
+ val resultSchema = formula.transformSchema(original.schema)
+ assert(resultSchema.length == 3)
+ assert(resultSchema.toString == formula.transform(original).schema.toString)
+ }
+
+ test("label column already exists but is not double type") {
+ val formula = new RFormula().setFormula("y ~ x").setLabelCol("y")
+ val original = sqlContext.createDataFrame(Seq((0, 1), (2, 2))).toDF("x", "y")
+ intercept[IllegalArgumentException] {
+ formula.transformSchema(original.schema)
+ }
+ intercept[IllegalArgumentException] {
+ formula.transform(original)
+ }
+ }
+
+// TODO(ekl) enable after we implement string label support
+// test("transform string label") {
+// val formula = new RFormula().setFormula("name ~ id")
+// val original = sqlContext.createDataFrame(
+// Seq((1, "foo"), (2, "bar"), (3, "bar"))).toDF("id", "name")
+// val result = formula.transform(original)
+// val resultSchema = formula.transformSchema(original.schema)
+// val expected = sqlContext.createDataFrame(
+// Seq(
+// (1, "foo", Vectors.dense(Array(1.0)), 1.0),
+// (2, "bar", Vectors.dense(Array(2.0)), 0.0),
+// (3, "bar", Vectors.dense(Array(3.0)), 0.0))
+// ).toDF("id", "name", "features", "label")
+// assert(result.schema.toString == resultSchema.toString)
+// assert(result.collect().toSeq == expected.collect().toSeq)
+// }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/StopwatchSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/util/StopwatchSuite.scala
new file mode 100644
index 0000000000000..8df6617fe0228
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/StopwatchSuite.scala
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.util
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+
+class StopwatchSuite extends SparkFunSuite with MLlibTestSparkContext {
+
+ private def testStopwatchOnDriver(sw: Stopwatch): Unit = {
+ assert(sw.name === "sw")
+ assert(sw.elapsed() === 0L)
+ assert(!sw.isRunning)
+ intercept[AssertionError] {
+ sw.stop()
+ }
+ sw.start()
+ Thread.sleep(50)
+ val duration = sw.stop()
+ assert(duration >= 50 && duration < 100) // using a loose upper bound
+ val elapsed = sw.elapsed()
+ assert(elapsed === duration)
+ sw.start()
+ Thread.sleep(50)
+ val duration2 = sw.stop()
+ assert(duration2 >= 50 && duration2 < 100)
+ val elapsed2 = sw.elapsed()
+ assert(elapsed2 === duration + duration2)
+ sw.start()
+ assert(sw.isRunning)
+ intercept[AssertionError] {
+ sw.start()
+ }
+ }
+
+ test("LocalStopwatch") {
+ val sw = new LocalStopwatch("sw")
+ testStopwatchOnDriver(sw)
+ }
+
+ test("DistributedStopwatch on driver") {
+ val sw = new DistributedStopwatch(sc, "sw")
+ testStopwatchOnDriver(sw)
+ }
+
+ test("DistributedStopwatch on executors") {
+ val sw = new DistributedStopwatch(sc, "sw")
+ val rdd = sc.parallelize(0 until 4, 4)
+ rdd.foreach { i =>
+ sw.start()
+ Thread.sleep(50)
+ sw.stop()
+ }
+ assert(!sw.isRunning)
+ val elapsed = sw.elapsed()
+ assert(elapsed >= 200 && elapsed < 400) // using a loose upper bound
+ }
+
+ test("MultiStopwatch") {
+ val sw = new MultiStopwatch(sc)
+ .addLocal("local")
+ .addDistributed("spark")
+ assert(sw("local").name === "local")
+ assert(sw("spark").name === "spark")
+ intercept[NoSuchElementException] {
+ sw("some")
+ }
+ assert(sw.toString === "{\n local: 0ms,\n spark: 0ms\n}")
+ sw("local").start()
+ sw("spark").start()
+ Thread.sleep(50)
+ sw("local").stop()
+ Thread.sleep(50)
+ sw("spark").stop()
+ val localElapsed = sw("local").elapsed()
+ val sparkElapsed = sw("spark").elapsed()
+ assert(localElapsed >= 50 && localElapsed < 100)
+ assert(sparkElapsed >= 100 && sparkElapsed < 200)
+ assert(sw.toString ===
+ s"{\n local: ${localElapsed}ms,\n spark: ${sparkElapsed}ms\n}")
+ val rdd = sc.parallelize(0 until 4, 4)
+ rdd.foreach { i =>
+ sw("local").start()
+ sw("spark").start()
+ Thread.sleep(50)
+ sw("spark").stop()
+ sw("local").stop()
+ }
+ val localElapsed2 = sw("local").elapsed()
+ assert(localElapsed2 === localElapsed)
+ val sparkElapsed2 = sw("spark").elapsed()
+ assert(sparkElapsed2 >= 300 && sparkElapsed2 < 600)
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
index 0dbbd7127444f..3003c62d9876c 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
@@ -278,6 +278,28 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext {
}
}
}
+
+ test("Initialize using given cluster centers") {
+ val points = Seq(
+ Vectors.dense(0.0, 0.0),
+ Vectors.dense(1.0, 0.0),
+ Vectors.dense(0.0, 1.0),
+ Vectors.dense(1.0, 1.0)
+ )
+ val rdd = sc.parallelize(points, 3)
+ // creating an initial model
+ val initialModel = new KMeansModel(Array(points(0), points(2)))
+
+ val returnModel = new KMeans()
+ .setK(2)
+ .setMaxIterations(0)
+ .setInitialModel(initialModel)
+ .run(rdd)
+ // comparing the returned model and the initial model
+ assert(returnModel.clusterCenters(0) === initialModel.clusterCenters(0))
+ assert(returnModel.clusterCenters(1) === initialModel.clusterCenters(1))
+ }
+
}
object KMeansSuite extends SparkFunSuite {
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala
index 9de2bdb6d7246..4b7f1be58f99b 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala
@@ -23,24 +23,85 @@ import org.apache.spark.mllib.util.TestingUtils._
class RegressionMetricsSuite extends SparkFunSuite with MLlibTestSparkContext {
- test("regression metrics") {
+ test("regression metrics for unbiased (includes intercept term) predictor") {
+ /* Verify results in R:
+ preds = c(2.25, -0.25, 1.75, 7.75)
+ obs = c(3.0, -0.5, 2.0, 7.0)
+
+ SStot = sum((obs - mean(obs))^2)
+ SSreg = sum((preds - mean(obs))^2)
+ SSerr = sum((obs - preds)^2)
+
+ explainedVariance = SSreg / length(obs)
+ explainedVariance
+ > [1] 8.796875
+ meanAbsoluteError = mean(abs(preds - obs))
+ meanAbsoluteError
+ > [1] 0.5
+ meanSquaredError = mean((preds - obs)^2)
+ meanSquaredError
+ > [1] 0.3125
+ rmse = sqrt(meanSquaredError)
+ rmse
+ > [1] 0.559017
+ r2 = 1 - SSerr / SStot
+ r2
+ > [1] 0.9571734
+ */
+ val predictionAndObservations = sc.parallelize(
+ Seq((2.25, 3.0), (-0.25, -0.5), (1.75, 2.0), (7.75, 7.0)), 2)
+ val metrics = new RegressionMetrics(predictionAndObservations)
+ assert(metrics.explainedVariance ~== 8.79687 absTol 1E-5,
+ "explained variance regression score mismatch")
+ assert(metrics.meanAbsoluteError ~== 0.5 absTol 1E-5, "mean absolute error mismatch")
+ assert(metrics.meanSquaredError ~== 0.3125 absTol 1E-5, "mean squared error mismatch")
+ assert(metrics.rootMeanSquaredError ~== 0.55901 absTol 1E-5,
+ "root mean squared error mismatch")
+ assert(metrics.r2 ~== 0.95717 absTol 1E-5, "r2 score mismatch")
+ }
+
+ test("regression metrics for biased (no intercept term) predictor") {
+ /* Verify results in R:
+ preds = c(2.5, 0.0, 2.0, 8.0)
+ obs = c(3.0, -0.5, 2.0, 7.0)
+
+ SStot = sum((obs - mean(obs))^2)
+ SSreg = sum((preds - mean(obs))^2)
+ SSerr = sum((obs - preds)^2)
+
+ explainedVariance = SSreg / length(obs)
+ explainedVariance
+ > [1] 8.859375
+ meanAbsoluteError = mean(abs(preds - obs))
+ meanAbsoluteError
+ > [1] 0.5
+ meanSquaredError = mean((preds - obs)^2)
+ meanSquaredError
+ > [1] 0.375
+ rmse = sqrt(meanSquaredError)
+ rmse
+ > [1] 0.6123724
+ r2 = 1 - SSerr / SStot
+ r2
+ > [1] 0.9486081
+ */
val predictionAndObservations = sc.parallelize(
Seq((2.5, 3.0), (0.0, -0.5), (2.0, 2.0), (8.0, 7.0)), 2)
val metrics = new RegressionMetrics(predictionAndObservations)
- assert(metrics.explainedVariance ~== 0.95717 absTol 1E-5,
+ assert(metrics.explainedVariance ~== 8.85937 absTol 1E-5,
"explained variance regression score mismatch")
assert(metrics.meanAbsoluteError ~== 0.5 absTol 1E-5, "mean absolute error mismatch")
assert(metrics.meanSquaredError ~== 0.375 absTol 1E-5, "mean squared error mismatch")
assert(metrics.rootMeanSquaredError ~== 0.61237 absTol 1E-5,
"root mean squared error mismatch")
- assert(metrics.r2 ~== 0.94861 absTol 1E-5, "r2 score mismatch")
+ assert(metrics.r2 ~== 0.94860 absTol 1E-5, "r2 score mismatch")
}
test("regression metrics with complete fitting") {
val predictionAndObservations = sc.parallelize(
Seq((3.0, 3.0), (0.0, 0.0), (2.0, 2.0), (8.0, 8.0)), 2)
val metrics = new RegressionMetrics(predictionAndObservations)
- assert(metrics.explainedVariance ~== 1.0 absTol 1E-5,
+ assert(metrics.explainedVariance ~== 8.6875 absTol 1E-5,
"explained variance regression score mismatch")
assert(metrics.meanAbsoluteError ~== 0.0 absTol 1E-5, "mean absolute error mismatch")
assert(metrics.meanSquaredError ~== 0.0 absTol 1E-5, "mean squared error mismatch")
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala
index 413436d3db85f..9f107c89f6d80 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala
@@ -18,9 +18,8 @@ package org.apache.spark.mllib.fpm
import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.rdd.RDD
-class PrefixspanSuite extends SparkFunSuite with MLlibTestSparkContext {
+class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext {
test("PrefixSpan using Integer type") {
@@ -48,15 +47,8 @@ class PrefixspanSuite extends SparkFunSuite with MLlibTestSparkContext {
def compareResult(
expectedValue: Array[(Array[Int], Long)],
actualValue: Array[(Array[Int], Long)]): Boolean = {
- val sortedExpectedValue = expectedValue.sortWith{ (x, y) =>
- x._1.mkString(",") + ":" + x._2 < y._1.mkString(",") + ":" + y._2
- }
- val sortedActualValue = actualValue.sortWith{ (x, y) =>
- x._1.mkString(",") + ":" + x._2 < y._1.mkString(",") + ":" + y._2
- }
- sortedExpectedValue.zip(sortedActualValue)
- .map(x => x._1._1.mkString(",") == x._2._1.mkString(",") && x._1._2 == x._2._2)
- .reduce(_&&_)
+ expectedValue.map(x => (x._1.toSeq, x._2)).toSet ==
+ actualValue.map(x => (x._1.toSeq, x._2)).toSet
}
val prefixspan = new PrefixSpan()
diff --git a/pom.xml b/pom.xml
index 370c95dd03632..aa49e2ab7294b 100644
--- a/pom.xml
+++ b/pom.xml
@@ -748,6 +748,12 @@
        <artifactId>curator-framework</artifactId>
        <version>${curator.version}</version>
      </dependency>
+      <dependency>
+        <groupId>org.apache.curator</groupId>
+        <artifactId>curator-test</artifactId>
+        <version>${curator.version}</version>
+        <scope>test</scope>
+      </dependency>
      <dependency>
        <groupId>org.apache.hadoop</groupId>
        <artifactId>hadoop-client</artifactId>
diff --git a/pylintrc b/pylintrc
new file mode 100644
index 0000000000000..061775960393b
--- /dev/null
+++ b/pylintrc
@@ -0,0 +1,404 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+[MASTER]
+
+# Specify a configuration file.
+#rcfile=
+
+# Python code to execute, usually for sys.path manipulation such as
+# pygtk.require().
+#init-hook=
+
+# Profiled execution.
+profile=no
+
+# Add files or directories to the blacklist. They should be base names, not
+# paths.
+ignore=pyspark.heapq3
+
+# Pickle collected data for later comparisons.
+persistent=yes
+
+# List of plugins (as comma separated values of python modules names) to load,
+# usually to register additional checkers.
+load-plugins=
+
+# Use multiple processes to speed up Pylint.
+jobs=1
+
+# Allow loading of arbitrary C extensions. Extensions are imported into the
+# active Python interpreter and may run arbitrary code.
+unsafe-load-any-extension=no
+
+# A comma-separated list of package or module names from where C extensions may
+# be loaded. Extensions are loading into the active Python interpreter and may
+# run arbitrary code
+extension-pkg-whitelist=
+
+# Allow optimization of some AST trees. This will activate a peephole AST
+# optimizer, which will apply various small optimizations. For instance, it can
+# be used to obtain the result of joining multiple strings with the addition
+# operator. Joining a lot of strings can lead to a maximum recursion error in
+# Pylint and this flag can prevent that. It has one side effect, the resulting
+# AST will be different than the one from reality.
+optimize-ast=no
+
+
+[MESSAGES CONTROL]
+
+# Only show warnings with the listed confidence levels. Leave empty to show
+# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED
+confidence=
+
+# Enable the message, report, category or checker with the given id(s). You can
+# either give multiple identifier separated by comma (,) or put this option
+# multiple time. See also the "--disable" option for examples.
+enable=
+
+# Disable the message, report, category or checker with the given id(s). You
+# can either give multiple identifiers separated by comma (,) or put this
+# option multiple times (only on the command line, not in the configuration
+# file where it should appear only once).You can also use "--disable=all" to
+# disable everything first and then reenable specific checks. For example, if
+# you want to run only the similarities checker, you can use "--disable=all
+# --enable=similarities". If you want to run only the classes checker, but have
+# no Warning level messages displayed, use"--disable=all --enable=classes
+# --disable=W"
+
+# These errors are arranged in order of number of warning given in pylint.
+# If you would like to improve the code quality of pyspark, remove any of these disabled errors
+# run ./dev/lint-python and see if the errors raised by pylint can be fixed.
+
+disable=invalid-name,missing-docstring,protected-access,unused-argument,no-member,unused-wildcard-import,redefined-builtin,too-many-arguments,unused-variable,too-few-public-methods,bad-continuation,duplicate-code,redefined-outer-name,too-many-ancestors,import-error,superfluous-parens,unused-import,line-too-long,no-name-in-module,unnecessary-lambda,import-self,no-self-use,unidiomatic-typecheck,fixme,too-many-locals,cyclic-import,too-many-branches,bare-except,wildcard-import,dangerous-default-value,broad-except,too-many-public-methods,deprecated-lambda,anomalous-backslash-in-string,too-many-lines,reimported,too-many-statements,bad-whitespace,unpacking-non-sequence,too-many-instance-attributes,abstract-method,old-style-class,global-statement,attribute-defined-outside-init,arguments-differ,undefined-all-variable,no-init,useless-else-on-loop,super-init-not-called,notimplemented-raised,too-many-return-statements,pointless-string-statement,global-variable-undefined,bad-classmethod-argument,too-many-format-args,parse-error,no-self-argument,pointless-statement,undefined-variable
+
+
+[REPORTS]
+
+# Set the output format. Available formats are text, parseable, colorized, msvs
+# (visual studio) and html. You can also give a reporter class, eg
+# mypackage.mymodule.MyReporterClass.
+output-format=text
+
+# Put messages in a separate file for each module / package specified on the
+# command line instead of printing them on stdout. Reports (if any) will be
+# written in a file name "pylint_global.[txt|html]".
+files-output=no
+
+# Tells whether to display a full report or only the messages
+reports=no
+
+# Python expression which should return a note less than 10 (10 is the highest
+# note). You have access to the variables errors warning, statement which
+# respectively contain the number of errors / warnings messages and the total
+# number of statements analyzed. This is used by the global evaluation report
+# (RP0004).
+evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)
+
+# Add a comment according to your evaluation note. This is used by the global
+# evaluation report (RP0004).
+comment=no
+
+# Template used to display messages. This is a python new-style format string
+# used to format the message information. See doc for all details
+#msg-template=
+
+
+[MISCELLANEOUS]
+
+# List of note tags to take in consideration, separated by a comma.
+notes=FIXME,XXX,TODO
+
+
+[BASIC]
+
+# Required attributes for module, separated by a comma
+required-attributes=
+
+# List of builtins function names that should not be used, separated by a comma
+bad-functions=
+
+# Good variable names which should always be accepted, separated by a comma
+good-names=i,j,k,ex,Run,_
+
+# Bad variable names which should always be refused, separated by a comma
+bad-names=baz,toto,tutu,tata
+
+# Colon-delimited sets of names that determine each other's naming style when
+# the name regexes allow several styles.
+name-group=
+
+# Include a hint for the correct naming format with invalid-name
+include-naming-hint=no
+
+# Regular expression matching correct function names
+function-rgx=[a-z_][a-z0-9_]{2,30}$
+
+# Naming hint for function names
+function-name-hint=[a-z_][a-z0-9_]{2,30}$
+
+# Regular expression matching correct variable names
+variable-rgx=[a-z_][a-z0-9_]{2,30}$
+
+# Naming hint for variable names
+variable-name-hint=[a-z_][a-z0-9_]{2,30}$
+
+# Regular expression matching correct constant names
+const-rgx=(([A-Z_][A-Z0-9_]*)|(__.*__))$
+
+# Naming hint for constant names
+const-name-hint=(([A-Z_][A-Z0-9_]*)|(__.*__))$
+
+# Regular expression matching correct attribute names
+attr-rgx=[a-z_][a-z0-9_]{2,30}$
+
+# Naming hint for attribute names
+attr-name-hint=[a-z_][a-z0-9_]{2,30}$
+
+# Regular expression matching correct argument names
+argument-rgx=[a-z_][a-z0-9_]{2,30}$
+
+# Naming hint for argument names
+argument-name-hint=[a-z_][a-z0-9_]{2,30}$
+
+# Regular expression matching correct class attribute names
+class-attribute-rgx=([A-Za-z_][A-Za-z0-9_]{2,30}|(__.*__))$
+
+# Naming hint for class attribute names
+class-attribute-name-hint=([A-Za-z_][A-Za-z0-9_]{2,30}|(__.*__))$
+
+# Regular expression matching correct inline iteration names
+inlinevar-rgx=[A-Za-z_][A-Za-z0-9_]*$
+
+# Naming hint for inline iteration names
+inlinevar-name-hint=[A-Za-z_][A-Za-z0-9_]*$
+
+# Regular expression matching correct class names
+class-rgx=[A-Z_][a-zA-Z0-9]+$
+
+# Naming hint for class names
+class-name-hint=[A-Z_][a-zA-Z0-9]+$
+
+# Regular expression matching correct module names
+module-rgx=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$
+
+# Naming hint for module names
+module-name-hint=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$
+
+# Regular expression matching correct method names
+method-rgx=[a-z_][a-z0-9_]{2,30}$
+
+# Naming hint for method names
+method-name-hint=[a-z_][a-z0-9_]{2,30}$
+
+# Regular expression which should only match function or class names that do
+# not require a docstring.
+no-docstring-rgx=__.*__
+
+# Minimum line length for functions/classes that require docstrings, shorter
+# ones are exempt.
+docstring-min-length=-1
+
+
+[FORMAT]
+
+# Maximum number of characters on a single line.
+max-line-length=100
+
+# Regexp for a line that is allowed to be longer than the limit.
+ignore-long-lines=^\s*(# )?<?https?://\S+>?$
+
+# Allow the body of an if to be on the same line as the test if there is no
+# else.
+single-line-if-stmt=no
+
+# List of optional constructs for which whitespace checking is disabled
+no-space-check=trailing-comma,dict-separator
+
+# Maximum number of lines in a module
+max-module-lines=1000
+
+# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1
+# tab).
+indent-string=' '
+
+# Number of spaces of indent required inside a hanging or continued line.
+indent-after-paren=4
+
+# Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
+expected-line-ending-format=
+
+
+[SIMILARITIES]
+
+# Minimum lines number of a similarity.
+min-similarity-lines=4
+
+# Ignore comments when computing similarities.
+ignore-comments=yes
+
+# Ignore docstrings when computing similarities.
+ignore-docstrings=yes
+
+# Ignore imports when computing similarities.
+ignore-imports=no
+
+
+[VARIABLES]
+
+# Tells whether we should check for unused import in __init__ files.
+init-import=no
+
+# A regular expression matching the name of dummy variables (i.e. expectedly
+# not used).
+dummy-variables-rgx=_$|dummy
+
+# List of additional names supposed to be defined in builtins. Remember that
+# you should avoid to define new builtins when possible.
+additional-builtins=
+
+# List of strings which can identify a callback function by name. A callback
+# name must start or end with one of those strings.
+callbacks=cb_,_cb
+
+
+[SPELLING]
+
+# Spelling dictionary name. Available dictionaries: none. To make it working
+# install python-enchant package.
+spelling-dict=
+
+# List of comma separated words that should not be checked.
+spelling-ignore-words=
+
+# A path to a file that contains private dictionary; one word per line.
+spelling-private-dict-file=
+
+# Tells whether to store unknown words to indicated private dictionary in
+# --spelling-private-dict-file option instead of raising a message.
+spelling-store-unknown-words=no
+
+
+[LOGGING]
+
+# Logging modules to check that the string format arguments are in logging
+# function parameter format
+logging-modules=logging
+
+
+[TYPECHECK]
+
+# Tells whether missing members accessed in mixin class should be ignored. A
+# mixin class is detected if its name ends with "mixin" (case insensitive).
+ignore-mixin-members=yes
+
+# List of module names for which member attributes should not be checked
+# (useful for modules/projects where namespaces are manipulated during runtime
+# and thus existing member attributes cannot be deduced by static analysis
+ignored-modules=
+
+# List of classes names for which member attributes should not be checked
+# (useful for classes with attributes dynamically set).
+ignored-classes=SQLObject
+
+# When zope mode is activated, add a predefined set of Zope acquired attributes
+# to generated-members.
+zope=no
+
+# List of members which are set dynamically and missed by pylint inference
+# system, and so shouldn't trigger E0201 when accessed. Python regular
+# expressions are accepted.
+generated-members=REQUEST,acl_users,aq_parent
+
+
+[CLASSES]
+
+# List of interface methods to ignore, separated by a comma. This is used for
+# instance to not check methods defines in Zope's Interface base class.
+ignore-iface-methods=isImplementedBy,deferred,extends,names,namesAndDescriptions,queryDescriptionFor,getBases,getDescriptionFor,getDoc,getName,getTaggedValue,getTaggedValueTags,isEqualOrExtendedBy,setTaggedValue,isImplementedByInstancesOf,adaptWith,is_implemented_by
+
+# List of method names used to declare (i.e. assign) instance attributes.
+defining-attr-methods=__init__,__new__,setUp
+
+# List of valid names for the first argument in a class method.
+valid-classmethod-first-arg=cls
+
+# List of valid names for the first argument in a metaclass class method.
+valid-metaclass-classmethod-first-arg=mcs
+
+# List of member names, which should be excluded from the protected access
+# warning.
+exclude-protected=_asdict,_fields,_replace,_source,_make
+
+
+[IMPORTS]
+
+# Deprecated modules which should not be used, separated by a comma
+deprecated-modules=regsub,TERMIOS,Bastion,rexec
+
+# Create a graph of every (i.e. internal and external) dependencies in the
+# given file (report RP0402 must not be disabled)
+import-graph=
+
+# Create a graph of external dependencies in the given file (report RP0402 must
+# not be disabled)
+ext-import-graph=
+
+# Create a graph of internal dependencies in the given file (report RP0402 must
+# not be disabled)
+int-import-graph=
+
+
+[DESIGN]
+
+# Maximum number of arguments for function / method
+max-args=5
+
+# Argument names that match this expression will be ignored. Default to name
+# with leading underscore
+ignored-argument-names=_.*
+
+# Maximum number of locals for function / method body
+max-locals=15
+
+# Maximum number of return / yield for function / method body
+max-returns=6
+
+# Maximum number of branch for function / method body
+max-branches=12
+
+# Maximum number of statements in function / method body
+max-statements=50
+
+# Maximum number of parents for a class (see R0901).
+max-parents=7
+
+# Maximum number of attributes for a class (see R0902).
+max-attributes=7
+
+# Minimum number of public methods for a class (see R0903).
+min-public-methods=2
+
+# Maximum number of public methods for a class (see R0904).
+max-public-methods=20
+
+
+[EXCEPTIONS]
+
+# Exceptions that will emit a warning when being caught. Defaults to
+# "Exception"
+overgeneral-exceptions=Exception
diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py
index bc088e4c29e26..595124726366d 100644
--- a/python/pyspark/ml/param/shared.py
+++ b/python/pyspark/ml/param/shared.py
@@ -444,7 +444,7 @@ class DecisionTreeParams(Params):
minInfoGain = Param(Params._dummy(), "minInfoGain", "Minimum information gain for a split to be considered at a tree node.")
maxMemoryInMB = Param(Params._dummy(), "maxMemoryInMB", "Maximum memory in MB allocated to histogram aggregation.")
cacheNodeIds = Param(Params._dummy(), "cacheNodeIds", "If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees.")
-
+
def __init__(self):
super(DecisionTreeParams, self).__init__()
@@ -460,7 +460,7 @@ def __init__(self):
self.maxMemoryInMB = Param(self, "maxMemoryInMB", "Maximum memory in MB allocated to histogram aggregation.")
#: param for If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees.
self.cacheNodeIds = Param(self, "cacheNodeIds", "If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees.")
-
+
def setMaxDepth(self, value):
"""
Sets the value of :py:attr:`maxDepth`.
diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py
index ed4d78a2c6788..8a92f6911c24b 100644
--- a/python/pyspark/mllib/clustering.py
+++ b/python/pyspark/mllib/clustering.py
@@ -31,13 +31,15 @@
from pyspark.rdd import RDD, ignore_unicode_prefix
from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc, callJavaFunc, _py2java, _java2py
from pyspark.mllib.linalg import SparseVector, _convert_to_vector, DenseVector
+from pyspark.mllib.regression import LabeledPoint
from pyspark.mllib.stat.distribution import MultivariateGaussian
from pyspark.mllib.util import Saveable, Loader, inherit_doc, JavaLoader, JavaSaveable
from pyspark.streaming import DStream
__all__ = ['KMeansModel', 'KMeans', 'GaussianMixtureModel', 'GaussianMixture',
'PowerIterationClusteringModel', 'PowerIterationClustering',
- 'StreamingKMeans', 'StreamingKMeansModel']
+ 'StreamingKMeans', 'StreamingKMeansModel',
+ 'LDA', 'LDAModel']
@inherit_doc
@@ -563,6 +565,68 @@ def predictOnValues(self, dstream):
return dstream.mapValues(lambda x: self._model.predict(x))
+class LDAModel(JavaModelWrapper):
+
+ """ A clustering model derived from the LDA method.
+
+ Latent Dirichlet Allocation (LDA), a topic model designed for text documents.
+ Terminology
+ - "word" = "term": an element of the vocabulary
+ - "token": instance of a term appearing in a document
+ - "topic": multinomial distribution over words representing some concept
+ References:
+ - Original LDA paper (journal version):
+ Blei, Ng, and Jordan. "Latent Dirichlet Allocation." JMLR, 2003.
+
+ >>> from pyspark.mllib.linalg import Vectors
+ >>> from numpy.testing import assert_almost_equal
+ >>> data = [
+ ... [1, Vectors.dense([0.0, 1.0])],
+ ... [2, SparseVector(2, {0: 1.0})],
+ ... ]
+ >>> rdd = sc.parallelize(data)
+ >>> model = LDA.train(rdd, k=2)
+ >>> model.vocabSize()
+ 2
+ >>> topics = model.topicsMatrix()
+ >>> topics_expect = array([[0.5, 0.5], [0.5, 0.5]])
+ >>> assert_almost_equal(topics, topics_expect, 1)
+ """
+
+ def topicsMatrix(self):
+ """Inferred topics, where each topic is represented by a distribution over terms."""
+ return self.call("topicsMatrix").toArray()
+
+ def vocabSize(self):
+        """Vocabulary size (number of terms in the vocabulary)"""
+ return self.call("vocabSize")
+
+
+class LDA(object):
+
+ @classmethod
+ def train(cls, rdd, k=10, maxIterations=20, docConcentration=-1.0,
+ topicConcentration=-1.0, seed=None, checkpointInterval=10, optimizer="em"):
+        """Train an LDA model.
+
+        :param rdd: RDD of documents, as (document id, term count vector) pairs
+        :param k: Number of topics to infer, i.e. the number of soft cluster centers
+ :param maxIterations: Number of iterations. Default to 20
+ :param docConcentration: Concentration parameter (commonly named "alpha")
+ for the prior placed on documents' distributions over topics ("theta").
+ :param topicConcentration: Concentration parameter (commonly named "beta" or "eta")
+ for the prior placed on topics' distributions over terms.
+ :param seed: Random Seed
+ :param checkpointInterval: Period (in iterations) between checkpoints.
+ :param optimizer: LDAOptimizer used to perform the actual calculation.
+ Currently "em", "online" are supported. Default to "em".
+ """
+ model = callMLlibFunc("trainLDAModel", rdd, k, maxIterations,
+ docConcentration, topicConcentration, seed,
+ checkpointInterval, optimizer)
+ return LDAModel(model)
+
+
def _test():
import doctest
import pyspark.mllib.clustering
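
The Python wrapper above delegates to MLlib's Scala LDA implementation through `trainLDAModel`. As a minimal sketch, the equivalent call on the Scala side looks roughly like the following (assuming a live `SparkContext` `sc`; the tiny two-document corpus mirrors the doctest above):

{% highlight scala %}
import org.apache.spark.mllib.clustering.LDA
import org.apache.spark.mllib.linalg.Vectors

// Corpus: RDD of (document id, term count vector) pairs.
val corpus = sc.parallelize(Seq(
  (1L, Vectors.dense(0.0, 1.0)),
  (2L, Vectors.sparse(2, Seq((0, 1.0))))
))

// Fit a 2-topic model, mirroring LDA.train(rdd, k=2) in the Python doctest.
val model = new LDA().setK(2).setMaxIterations(20).run(corpus)
model.vocabSize     // 2
model.topicsMatrix  // a vocabSize x k matrix of term weights
{% endhighlight %}
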
diff --git a/python/pyspark/mllib/evaluation.py b/python/pyspark/mllib/evaluation.py
index f21403707e12a..4398ca86f2ec2 100644
--- a/python/pyspark/mllib/evaluation.py
+++ b/python/pyspark/mllib/evaluation.py
@@ -82,7 +82,7 @@ class RegressionMetrics(JavaModelWrapper):
... (2.5, 3.0), (0.0, -0.5), (2.0, 2.0), (8.0, 7.0)])
>>> metrics = RegressionMetrics(predictionAndObservations)
>>> metrics.explainedVariance
- 0.95...
+ 8.859...
>>> metrics.meanAbsoluteError
0.5...
>>> metrics.meanSquaredError
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index dca39fa833435..e0816b3e654bc 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -39,6 +39,8 @@
'coalesce',
'countDistinct',
'explode',
+ 'format_number',
+ 'length',
'log2',
'md5',
'monotonicallyIncreasingId',
@@ -47,7 +49,6 @@
'sha1',
'sha2',
'sparkPartitionId',
- 'strlen',
'struct',
'udf',
'when']
@@ -506,14 +507,28 @@ def sparkPartitionId():
@ignore_unicode_prefix
@since(1.5)
-def strlen(col):
- """Calculates the length of a string expression.
+def length(col):
+ """Calculates the length of a string or binary expression.
- >>> sqlContext.createDataFrame([('ABC',)], ['a']).select(strlen('a').alias('length')).collect()
+ >>> sqlContext.createDataFrame([('ABC',)], ['a']).select(length('a').alias('length')).collect()
[Row(length=3)]
"""
sc = SparkContext._active_spark_context
- return Column(sc._jvm.functions.strlen(_to_java_column(col)))
+ return Column(sc._jvm.functions.length(_to_java_column(col)))
+
+
+@ignore_unicode_prefix
+@since(1.5)
+def format_number(col, d):
+ """Formats the number X to a format like '#,###,###.##', rounded to d decimal places,
+ and returns the result as a string.
+ :param col: the column name of the numeric value to be formatted
+    :param d: the number of decimal places
+ >>> sqlContext.createDataFrame([(5,)], ['a']).select(format_number('a', 4).alias('v')).collect()
+ [Row(v=u'5.0000')]
+ """
+ sc = SparkContext._active_spark_context
+ return Column(sc._jvm.functions.format_number(_to_java_column(col), d))
@ignore_unicode_prefix
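
Both Python wrappers call straight into the Scala `org.apache.spark.sql.functions` object. A minimal DataFrame sketch of the same two functions (assuming a `sqlContext` whose implicits are imported; the sample data is illustrative):

{% highlight scala %}
import org.apache.spark.sql.functions.{format_number, length}
import sqlContext.implicits._

val df = Seq(("ABC", 12345.678)).toDF("s", "x")
// length counts characters (or bytes for binary); format_number renders '#,###,###.##'.
df.select(length($"s"), format_number($"x", 2)).show()
// length(s) = 3, format_number(x, 2) = "12,345.68"
{% endhighlight %}
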
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index c5c0add49d02c..21225016805bc 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -893,7 +893,8 @@ def test_pipe_functions(self):
self.assertRaises(Py4JJavaError, rdd.pipe('cc', checkCode=True).collect)
result = rdd.pipe('cat').collect()
result.sort()
- [self.assertEqual(x, y) for x, y in zip(data, result)]
+ for x, y in zip(data, result):
+ self.assertEqual(x, y)
self.assertRaises(Py4JJavaError, rdd.pipe('grep 4', checkCode=True).collect)
self.assertEqual([], rdd.pipe('grep 4').collect())
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index ed69c42dcb825..e0beafe710079 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql.catalyst.analysis
+import scala.language.existentials
import scala.reflect.ClassTag
import scala.util.{Failure, Success, Try}
@@ -114,8 +115,10 @@ object FunctionRegistry {
expression[Log2]("log2"),
expression[Pow]("pow"),
expression[Pow]("power"),
+ expression[Pmod]("pmod"),
expression[UnaryPositive]("positive"),
expression[Rint]("rint"),
+ expression[Round]("round"),
expression[ShiftLeft]("shiftleft"),
expression[ShiftRight]("shiftright"),
expression[ShiftRightUnsigned]("shiftrightunsigned"),
@@ -149,11 +152,12 @@ object FunctionRegistry {
expression[Base64]("base64"),
expression[Encode]("encode"),
expression[Decode]("decode"),
- expression[StringInstr]("instr"),
+ expression[FormatNumber]("format_number"),
expression[Lower]("lcase"),
expression[Lower]("lower"),
- expression[StringLength]("length"),
+ expression[Length]("length"),
expression[Levenshtein]("levenshtein"),
+ expression[StringInstr]("instr"),
expression[StringLocate]("locate"),
expression[StringLPad]("lpad"),
expression[StringTrimLeft]("ltrim"),
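
Registering the new expressions here is what makes them callable by name from SQL and `DataFrame.selectExpr`. A minimal sketch (assuming a live `sqlContext`):

{% highlight scala %}
// pmod is always non-negative for a positive modulus; length and format_number
// follow the behavior described in the Python wrappers above.
sqlContext.sql("SELECT pmod(-7, 3), length('ABC'), format_number(12345.678, 2)").show()
// 2, 3, "12,345.68"
{% endhighlight %}
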
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 8cb71995eb818..50db7d21f01ca 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -214,19 +214,6 @@ object HiveTypeCoercion {
}
Union(newLeft, newRight)
-
- // Also widen types for BinaryOperator.
- case q: LogicalPlan => q transformExpressions {
- // Skip nodes who's children have not been resolved yet.
- case e if !e.childrenResolved => e
-
- case b @ BinaryOperator(left, right) if left.dataType != right.dataType =>
- findTightestCommonTypeOfTwo(left.dataType, right.dataType).map { widestType =>
- val newLeft = if (left.dataType == widestType) left else Cast(left, widestType)
- val newRight = if (right.dataType == widestType) right else Cast(right, widestType)
- b.makeCopy(Array(newLeft, newRight))
- }.getOrElse(b) // If there is no applicable conversion, leave expression unchanged.
- }
}
}
@@ -439,6 +426,12 @@ object HiveTypeCoercion {
DecimalType(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2))
)
+ case Pmod(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
+ Cast(
+ Pmod(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)),
+ DecimalType(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2))
+ )
+
// When we compare 2 decimal types with different precisions, cast them to the smallest
// common precision.
case b @ BinaryComparison(e1 @ DecimalType.Expression(p1, s1),
@@ -672,20 +665,44 @@ object HiveTypeCoercion {
}
/**
- * Casts types according to the expected input types for Expressions that have the trait
- * [[ExpectsInputTypes]].
+ * Casts types according to the expected input types for [[Expression]]s.
*/
object ImplicitTypeCasts extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
// Skip nodes who's children have not been resolved yet.
case e if !e.childrenResolved => e
- case e: ExpectsInputTypes if (e.inputTypes.nonEmpty) =>
+ case b @ BinaryOperator(left, right) if left.dataType != right.dataType =>
+ findTightestCommonTypeOfTwo(left.dataType, right.dataType).map { commonType =>
+ if (b.inputType.acceptsType(commonType)) {
+ // If the expression accepts the tightest common type, cast to that.
+ val newLeft = if (left.dataType == commonType) left else Cast(left, commonType)
+ val newRight = if (right.dataType == commonType) right else Cast(right, commonType)
+ b.withNewChildren(Seq(newLeft, newRight))
+ } else {
+ // Otherwise, don't do anything with the expression.
+ b
+ }
+ }.getOrElse(b) // If there is no applicable conversion, leave expression unchanged.
+
+ case e: ImplicitCastInputTypes if e.inputTypes.nonEmpty =>
val children: Seq[Expression] = e.children.zip(e.inputTypes).map { case (in, expected) =>
// If we cannot do the implicit cast, just use the original input.
implicitCast(in, expected).getOrElse(in)
}
e.withNewChildren(children)
+
+ case e: ExpectsInputTypes if e.inputTypes.nonEmpty =>
+ // Convert NullType into some specific target type for ExpectsInputTypes that don't do
+ // general implicit casting.
+ val children: Seq[Expression] = e.children.zip(e.inputTypes).map { case (in, expected) =>
+ if (in.dataType == NullType && !expected.acceptsType(NullType)) {
+ Literal.create(null, expected.defaultConcreteType)
+ } else {
+ in
+ }
+ }
+ e.withNewChildren(children)
}
/**
@@ -702,27 +719,22 @@ object HiveTypeCoercion {
@Nullable val ret: Expression = (inType, expectedType) match {
// If the expected type is already a parent of the input type, no need to cast.
- case _ if expectedType.isSameType(inType) => e
+ case _ if expectedType.acceptsType(inType) => e
// Cast null type (usually from null literals) into target types
case (NullType, target) => Cast(e, target.defaultConcreteType)
- // If the function accepts any numeric type (i.e. the ADT `NumericType`) and the input is
- // already a number, leave it as is.
- case (_: NumericType, NumericType) => e
-
// If the function accepts any numeric type and the input is a string, we follow the hive
// convention and cast that input into a double
case (StringType, NumericType) => Cast(e, NumericType.defaultConcreteType)
- // Implicit cast among numeric types
+ // Implicit cast among numeric types. When we reach here, input type is not acceptable.
+
// If input is a numeric type but not decimal, and we expect a decimal type,
// cast the input to unlimited precision decimal.
- case (_: NumericType, DecimalType) if !inType.isInstanceOf[DecimalType] =>
- Cast(e, DecimalType.Unlimited)
+ case (_: NumericType, DecimalType) => Cast(e, DecimalType.Unlimited)
// For any other numeric types, implicitly cast to each other, e.g. long -> int, int -> long
- case (_: NumericType, target: NumericType) if e.dataType != target => Cast(e, target)
- case (_: NumericType, target: NumericType) => e
+ case (_: NumericType, target: NumericType) => Cast(e, target)
// Implicit cast between date time types
case (DateType, TimestampType) => Cast(e, TimestampType)
@@ -736,15 +748,9 @@ object HiveTypeCoercion {
case (StringType, BinaryType) => Cast(e, BinaryType)
case (any, StringType) if any != StringType => Cast(e, StringType)
- // Type collection.
- // First see if we can find our input type in the type collection. If we can, then just
- // use the current expression; otherwise, find the first one we can implicitly cast.
- case (_, TypeCollection(types)) =>
- if (types.exists(_.isSameType(inType))) {
- e
- } else {
- types.flatMap(implicitCast(e, _)).headOption.orNull
- }
+ // When we reach here, input type is not acceptable for any types in this type collection,
+ // try to find the first one we can implicitly cast.
+ case (_, TypeCollection(types)) => types.flatMap(implicitCast(e, _)).headOption.orNull
// Else, just return the same input expression
case _ => null
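
The binary-operator widening removed from the `Union` rule above now lives inside `ImplicitTypeCasts`: operands of differing types are still cast to the tightest common type whenever the operator's `inputType` accepts it. A minimal sketch of the observable behavior (assuming a live `sqlContext`):

{% highlight scala %}
// int + bigint: the analyzer casts the int operand up, so the result column is bigint.
val widened = sqlContext.sql("SELECT 1 + CAST(2 AS BIGINT) AS s")
widened.schema("s").dataType  // org.apache.spark.sql.types.LongType
{% endhighlight %}
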
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala
index 3eb0eb195c80d..ded89e85dea79 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala
@@ -19,10 +19,15 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.types.AbstractDataType
-
+import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion.ImplicitTypeCasts
/**
* An trait that gets mixin to define the expected input types of an expression.
+ *
+ * This trait is typically used by operator expressions (e.g. [[Add]], [[Subtract]]) to define
+ * expected input types without any implicit casting.
+ *
+ * Most function expressions (e.g. [[Substring]]) should extend [[ImplicitCastInputTypes]] instead.
*/
trait ExpectsInputTypes { self: Expression =>
@@ -40,7 +45,7 @@ trait ExpectsInputTypes { self: Expression =>
val mismatches = children.zip(inputTypes).zipWithIndex.collect {
case ((child, expected), idx) if !expected.acceptsType(child.dataType) =>
s"argument ${idx + 1} is expected to be of type ${expected.simpleString}, " +
- s"however, '${child.prettyString}' is of type ${child.dataType.simpleString}."
+ s"however, '${child.prettyString}' is of type ${child.dataType.simpleString}."
}
if (mismatches.isEmpty) {
@@ -50,3 +55,11 @@ trait ExpectsInputTypes { self: Expression =>
}
}
}
+
+
+/**
+ * A mixin for the analyzer to perform implicit type casting using [[ImplicitTypeCasts]].
+ */
+trait ImplicitCastInputTypes extends ExpectsInputTypes { self: Expression =>
+ // No other methods
+}
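
To make the split concrete, here is a hypothetical unary function (the name `DoubleIt` and its behavior are invented purely for illustration) that mixes in `ImplicitCastInputTypes`, so the analyzer will insert implicit casts such as int -> double on its behalf; operators keep the plain `ExpectsInputTypes` contract and only get type checking:

{% highlight scala %}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode}
import org.apache.spark.sql.types._

// Hypothetical function expression (illustrative only): doubles a numeric input and relies on
// ImplicitCastInputTypes so the analyzer casts non-double numeric children to DoubleType.
case class DoubleIt(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {

  override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType)
  override def dataType: DataType = DoubleType

  protected override def nullSafeEval(input: Any): Any = input.asInstanceOf[Double] * 2.0

  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String =
    defineCodeGen(ctx, ev, c => s"$c * 2.0")
}
{% endhighlight %}
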
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index 54ec10444c4f3..a655cc8e48ae1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -24,8 +24,20 @@ import org.apache.spark.sql.catalyst.trees
import org.apache.spark.sql.catalyst.trees.TreeNode
import org.apache.spark.sql.types._
+////////////////////////////////////////////////////////////////////////////////////////////////////
+// This file defines the basic expression abstract classes in Catalyst, including:
+// Expression: the base expression abstract class
+// LeafExpression
+// UnaryExpression
+// BinaryExpression
+// BinaryOperator
+//
+// For details, see their classdocs.
+////////////////////////////////////////////////////////////////////////////////////////////////////
/**
+ * An expression in Catalyst.
+ *
* If an expression wants to be exposed in the function registry (so users can call it with
* "name(arguments...)", the concrete implementation must be a case class whose constructor
* arguments are all Expressions types.
@@ -49,9 +61,15 @@ abstract class Expression extends TreeNode[Expression] {
def foldable: Boolean = false
/**
- * Returns true when the current expression always return the same result for fixed input values.
+   * Returns true when the current expression always returns the same result for fixed inputs from
+ * children.
+ *
+ * Note that this means that an expression should be considered as non-deterministic if:
+   * - it relies on some mutable internal state, or
+   * - it relies on some implicit input that is not part of the children expression list.
+ *
+ * An example would be `SparkPartitionID` that relies on the partition id returned by TaskContext.
*/
- // TODO: Need to define explicit input values vs implicit input values.
def deterministic: Boolean = true
def nullable: Boolean
@@ -169,8 +187,10 @@ abstract class Expression extends TreeNode[Expression] {
/**
* A leaf expression, i.e. one without any child expressions.
*/
-abstract class LeafExpression extends Expression with trees.LeafNode[Expression] {
+abstract class LeafExpression extends Expression {
self: Product =>
+
+ def children: Seq[Expression] = Nil
}
@@ -178,9 +198,13 @@ abstract class LeafExpression extends Expression with trees.LeafNode[Expression]
* An expression with one input and one output. The output is by default evaluated to null
* if the input is evaluated to null.
*/
-abstract class UnaryExpression extends Expression with trees.UnaryNode[Expression] {
+abstract class UnaryExpression extends Expression {
self: Product =>
+ def child: Expression
+
+ override def children: Seq[Expression] = child :: Nil
+
override def foldable: Boolean = child.foldable
override def nullable: Boolean = child.nullable
@@ -253,9 +277,14 @@ abstract class UnaryExpression extends Expression with trees.UnaryNode[Expressio
* An expression with two inputs and one output. The output is by default evaluated to null
* if any input is evaluated to null.
*/
-abstract class BinaryExpression extends Expression with trees.BinaryNode[Expression] {
+abstract class BinaryExpression extends Expression {
self: Product =>
+ def left: Expression
+ def right: Expression
+
+ override def children: Seq[Expression] = Seq(left, right)
+
override def foldable: Boolean = left.foldable && right.foldable
override def nullable: Boolean = left.nullable || right.nullable
@@ -335,15 +364,39 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express
/**
- * An expression that has two inputs that are expected to the be same type. If the two inputs have
- * different types, the analyzer will find the tightest common type and do the proper type casting.
+ * A [[BinaryExpression]] that is an operator, with two properties:
+ *
+ * 1. The string representation is "x symbol y", rather than "funcName(x, y)".
+ * 2. Two inputs are expected to be of the same type. If the two inputs have different types,
+ * the analyzer will find the tightest common type and do the proper type casting.
*/
-abstract class BinaryOperator extends BinaryExpression {
+abstract class BinaryOperator extends BinaryExpression with ExpectsInputTypes {
self: Product =>
+ /**
+ * Expected input type from both left/right child expressions, similar to the
+ * [[ImplicitCastInputTypes]] trait.
+ */
+ def inputType: AbstractDataType
+
def symbol: String
override def toString: String = s"($left $symbol $right)"
+
+ override def inputTypes: Seq[AbstractDataType] = Seq(inputType, inputType)
+
+ override def checkInputDataTypes(): TypeCheckResult = {
+ // First check whether left and right have the same type, then check if the type is acceptable.
+ if (left.dataType != right.dataType) {
+ TypeCheckResult.TypeCheckFailure(s"differing types in '$prettyString' " +
+ s"(${left.dataType.simpleString} and ${right.dataType.simpleString}).")
+ } else if (!inputType.acceptsType(left.dataType)) {
+ TypeCheckResult.TypeCheckFailure(s"'$prettyString' accepts ${inputType.simpleString} type," +
+ s" not ${left.dataType.simpleString}")
+ } else {
+ TypeCheckResult.TypeCheckSuccess
+ }
+ }
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
index 6fb3343bb63f2..22687acd68a97 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
@@ -29,7 +29,7 @@ case class ScalaUDF(
function: AnyRef,
dataType: DataType,
children: Seq[Expression],
- inputTypes: Seq[DataType] = Nil) extends Expression with ExpectsInputTypes {
+ inputTypes: Seq[DataType] = Nil) extends Expression with ImplicitCastInputTypes {
override def nullable: Boolean = true
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index 8476af4a5d8d6..382cbe3b84a07 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -18,23 +18,19 @@
package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode}
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._
-abstract class UnaryArithmetic extends UnaryExpression {
- self: Product =>
+
+case class UnaryMinus(child: Expression) extends UnaryExpression with ExpectsInputTypes {
+
+ override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)
override def dataType: DataType = child.dataType
-}
-case class UnaryMinus(child: Expression) extends UnaryArithmetic {
override def toString: String = s"-$child"
- override def checkInputDataTypes(): TypeCheckResult =
- TypeUtils.checkForNumericExpr(child.dataType, "operator -")
-
private lazy val numeric = TypeUtils.getNumeric(dataType)
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = dataType match {
@@ -45,9 +41,13 @@ case class UnaryMinus(child: Expression) extends UnaryArithmetic {
protected override def nullSafeEval(input: Any): Any = numeric.negate(input)
}
-case class UnaryPositive(child: Expression) extends UnaryArithmetic {
+case class UnaryPositive(child: Expression) extends UnaryExpression with ExpectsInputTypes {
override def prettyName: String = "positive"
+ override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)
+
+ override def dataType: DataType = child.dataType
+
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String =
defineCodeGen(ctx, ev, c => c)
@@ -57,9 +57,11 @@ case class UnaryPositive(child: Expression) extends UnaryArithmetic {
/**
* A function that get the absolute value of the numeric value.
*/
-case class Abs(child: Expression) extends UnaryArithmetic {
- override def checkInputDataTypes(): TypeCheckResult =
- TypeUtils.checkForNumericExpr(child.dataType, "function abs")
+case class Abs(child: Expression) extends UnaryExpression with ExpectsInputTypes {
+
+ override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)
+
+ override def dataType: DataType = child.dataType
private lazy val numeric = TypeUtils.getNumeric(dataType)
@@ -71,18 +73,6 @@ abstract class BinaryArithmetic extends BinaryOperator {
override def dataType: DataType = left.dataType
- override def checkInputDataTypes(): TypeCheckResult = {
- if (left.dataType != right.dataType) {
- TypeCheckResult.TypeCheckFailure(
- s"differing types in ${this.getClass.getSimpleName} " +
- s"(${left.dataType} and ${right.dataType}).")
- } else {
- checkTypesInternal(dataType)
- }
- }
-
- protected def checkTypesInternal(t: DataType): TypeCheckResult
-
/** Name of the function for this expression on a [[Decimal]] type. */
def decimalMethod: String =
sys.error("BinaryArithmetics must override either decimalMethod or genCode")
@@ -104,62 +94,61 @@ private[sql] object BinaryArithmetic {
}
case class Add(left: Expression, right: Expression) extends BinaryArithmetic {
+
+ override def inputType: AbstractDataType = NumericType
+
override def symbol: String = "+"
override def decimalMethod: String = "$plus"
override lazy val resolved =
childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType)
- protected def checkTypesInternal(t: DataType) =
- TypeUtils.checkForNumericExpr(t, "operator " + symbol)
-
private lazy val numeric = TypeUtils.getNumeric(dataType)
protected override def nullSafeEval(input1: Any, input2: Any): Any = numeric.plus(input1, input2)
}
case class Subtract(left: Expression, right: Expression) extends BinaryArithmetic {
+
+ override def inputType: AbstractDataType = NumericType
+
override def symbol: String = "-"
override def decimalMethod: String = "$minus"
override lazy val resolved =
childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType)
- protected def checkTypesInternal(t: DataType) =
- TypeUtils.checkForNumericExpr(t, "operator " + symbol)
-
private lazy val numeric = TypeUtils.getNumeric(dataType)
protected override def nullSafeEval(input1: Any, input2: Any): Any = numeric.minus(input1, input2)
}
case class Multiply(left: Expression, right: Expression) extends BinaryArithmetic {
+
+ override def inputType: AbstractDataType = NumericType
+
override def symbol: String = "*"
override def decimalMethod: String = "$times"
override lazy val resolved =
childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType)
- protected def checkTypesInternal(t: DataType) =
- TypeUtils.checkForNumericExpr(t, "operator " + symbol)
-
private lazy val numeric = TypeUtils.getNumeric(dataType)
protected override def nullSafeEval(input1: Any, input2: Any): Any = numeric.times(input1, input2)
}
case class Divide(left: Expression, right: Expression) extends BinaryArithmetic {
+
+ override def inputType: AbstractDataType = NumericType
+
override def symbol: String = "/"
override def decimalMethod: String = "$div"
-
override def nullable: Boolean = true
override lazy val resolved =
childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType)
- protected def checkTypesInternal(t: DataType) =
- TypeUtils.checkForNumericExpr(t, "operator " + symbol)
-
private lazy val div: (Any, Any) => Any = dataType match {
case ft: FractionalType => ft.fractional.asInstanceOf[Fractional[Any]].div
case it: IntegralType => it.integral.asInstanceOf[Integral[Any]].quot
@@ -215,17 +204,16 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic
}
case class Remainder(left: Expression, right: Expression) extends BinaryArithmetic {
+
+ override def inputType: AbstractDataType = NumericType
+
override def symbol: String = "%"
override def decimalMethod: String = "remainder"
-
override def nullable: Boolean = true
override lazy val resolved =
childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType)
- protected def checkTypesInternal(t: DataType) =
- TypeUtils.checkForNumericExpr(t, "operator " + symbol)
-
private lazy val integral = dataType match {
case i: IntegralType => i.integral.asInstanceOf[Integral[Any]]
case i: FractionalType => i.asIntegral.asInstanceOf[Integral[Any]]
@@ -281,10 +269,11 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet
}
case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic {
- override def nullable: Boolean = left.nullable && right.nullable
+  // TODO: Remove MaxOf and MinOf, and replace their usage with Greatest and Least.
- protected def checkTypesInternal(t: DataType) =
- TypeUtils.checkForOrderingExpr(t, "function maxOf")
+ override def inputType: AbstractDataType = TypeCollection.Ordered
+
+ override def nullable: Boolean = left.nullable && right.nullable
private lazy val ordering = TypeUtils.getOrdering(dataType)
@@ -331,14 +320,14 @@ case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic {
}
override def symbol: String = "max"
- override def prettyName: String = symbol
}
case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic {
- override def nullable: Boolean = left.nullable && right.nullable
+  // TODO: Remove MaxOf and MinOf, and replace their usage with Greatest and Least.
- protected def checkTypesInternal(t: DataType) =
- TypeUtils.checkForOrderingExpr(t, "function minOf")
+ override def inputType: AbstractDataType = TypeCollection.Ordered
+
+ override def nullable: Boolean = left.nullable && right.nullable
private lazy val ordering = TypeUtils.getOrdering(dataType)
@@ -385,5 +374,98 @@ case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic {
}
override def symbol: String = "min"
- override def prettyName: String = symbol
+}
+
+case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic {
+
+ override def toString: String = s"pmod($left, $right)"
+
+ override def symbol: String = "pmod"
+
+ protected def checkTypesInternal(t: DataType) =
+ TypeUtils.checkForNumericExpr(t, "pmod")
+
+ override def inputType: AbstractDataType = NumericType
+
+ protected override def nullSafeEval(left: Any, right: Any) =
+ dataType match {
+ case IntegerType => pmod(left.asInstanceOf[Int], right.asInstanceOf[Int])
+ case LongType => pmod(left.asInstanceOf[Long], right.asInstanceOf[Long])
+ case ShortType => pmod(left.asInstanceOf[Short], right.asInstanceOf[Short])
+ case ByteType => pmod(left.asInstanceOf[Byte], right.asInstanceOf[Byte])
+ case FloatType => pmod(left.asInstanceOf[Float], right.asInstanceOf[Float])
+ case DoubleType => pmod(left.asInstanceOf[Double], right.asInstanceOf[Double])
+ case _: DecimalType => pmod(left.asInstanceOf[Decimal], right.asInstanceOf[Decimal])
+ }
+
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+ nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
+ dataType match {
+ case dt: DecimalType =>
+ val decimalAdd = "$plus"
+ s"""
+ ${ctx.javaType(dataType)} r = $eval1.remainder($eval2);
+ if (r.compare(new org.apache.spark.sql.types.Decimal().set(0)) < 0) {
+ ${ev.primitive} = (r.$decimalAdd($eval2)).remainder($eval2);
+ } else {
+ ${ev.primitive} = r;
+ }
+ """
+ // byte and short are casted into int when add, minus, times or divide
+ case ByteType | ShortType =>
+ s"""
+ ${ctx.javaType(dataType)} r = (${ctx.javaType(dataType)})($eval1 % $eval2);
+ if (r < 0) {
+ ${ev.primitive} = (${ctx.javaType(dataType)})((r + $eval2) % $eval2);
+ } else {
+ ${ev.primitive} = r;
+ }
+ """
+ case _ =>
+ s"""
+ ${ctx.javaType(dataType)} r = $eval1 % $eval2;
+ if (r < 0) {
+ ${ev.primitive} = (r + $eval2) % $eval2;
+ } else {
+ ${ev.primitive} = r;
+ }
+ """
+ }
+ })
+ }
+
+ private def pmod(a: Int, n: Int): Int = {
+ val r = a % n
+ if (r < 0) {(r + n) % n} else r
+ }
+
+ private def pmod(a: Long, n: Long): Long = {
+ val r = a % n
+ if (r < 0) {(r + n) % n} else r
+ }
+
+ private def pmod(a: Byte, n: Byte): Byte = {
+ val r = a % n
+ if (r < 0) {((r + n) % n).toByte} else r.toByte
+ }
+
+ private def pmod(a: Double, n: Double): Double = {
+ val r = a % n
+ if (r < 0) {(r + n) % n} else r
+ }
+
+ private def pmod(a: Short, n: Short): Short = {
+ val r = a % n
+ if (r < 0) {((r + n) % n).toShort} else r.toShort
+ }
+
+ private def pmod(a: Float, n: Float): Float = {
+ val r = a % n
+ if (r < 0) {(r + n) % n} else r
+ }
+
+ private def pmod(a: Decimal, n: Decimal): Decimal = {
+ val r = a % n
+ if (r.compare(Decimal(0)) < 0) {(r + n) % n} else r
+ }
}
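
The scalar helpers above give `pmod` its defining property: unlike the `%` remainder, the result is always non-negative for a positive modulus. A plain-Scala sketch of the same arithmetic:

{% highlight scala %}
// Mirrors the Int helper above: fold a possibly negative remainder back into [0, n).
def pmod(a: Int, n: Int): Int = {
  val r = a % n
  if (r < 0) (r + n) % n else r
}

pmod(-7, 3)  // 2
-7 % 3       // -1: the built-in remainder keeps the sign of the dividend
{% endhighlight %}
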
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwise.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwise.scala
index 2d47124d247e7..a1e48c4210877 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwise.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwise.scala
@@ -17,9 +17,7 @@
package org.apache.spark.sql.catalyst.expressions
-import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen._
-import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._
@@ -29,10 +27,10 @@ import org.apache.spark.sql.types._
* Code generation inherited from BinaryArithmetic.
*/
case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithmetic {
- override def symbol: String = "&"
- protected def checkTypesInternal(t: DataType) =
- TypeUtils.checkForBitwiseExpr(t, "operator " + symbol)
+ override def inputType: AbstractDataType = IntegralType
+
+ override def symbol: String = "&"
private lazy val and: (Any, Any) => Any = dataType match {
case ByteType =>
@@ -54,10 +52,10 @@ case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithme
* Code generation inherited from BinaryArithmetic.
*/
case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmetic {
- override def symbol: String = "|"
- protected def checkTypesInternal(t: DataType) =
- TypeUtils.checkForBitwiseExpr(t, "operator " + symbol)
+ override def inputType: AbstractDataType = IntegralType
+
+ override def symbol: String = "|"
private lazy val or: (Any, Any) => Any = dataType match {
case ByteType =>
@@ -79,10 +77,10 @@ case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmet
* Code generation inherited from BinaryArithmetic.
*/
case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithmetic {
- override def symbol: String = "^"
- protected def checkTypesInternal(t: DataType) =
- TypeUtils.checkForBitwiseExpr(t, "operator " + symbol)
+ override def inputType: AbstractDataType = IntegralType
+
+ override def symbol: String = "^"
private lazy val xor: (Any, Any) => Any = dataType match {
case ByteType =>
@@ -101,11 +99,13 @@ case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithme
/**
* A function that calculates bitwise not(~) of a number.
*/
-case class BitwiseNot(child: Expression) extends UnaryArithmetic {
- override def toString: String = s"~$child"
+case class BitwiseNot(child: Expression) extends UnaryExpression with ExpectsInputTypes {
- override def checkInputDataTypes(): TypeCheckResult =
- TypeUtils.checkForBitwiseExpr(child.dataType, "operator ~")
+ override def inputTypes: Seq[AbstractDataType] = Seq(IntegralType)
+
+ override def dataType: DataType = child.dataType
+
+ override def toString: String = s"~$child"
private lazy val not: (Any) => Any = dataType match {
case ByteType =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 9f6329bbda4ec..328d635de8743 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -56,6 +56,18 @@ class CodeGenContext {
*/
val references: mutable.ArrayBuffer[Expression] = new mutable.ArrayBuffer[Expression]()
+ /**
+   * Holds expressions' mutable states, e.g. `MonotonicallyIncreasingID.count`, as a
+   * 3-tuple: (java type, variable name, code to initialize it).
+ * They will be kept as member variables in generated classes like `SpecificProjection`.
+ */
+ val mutableStates: mutable.ArrayBuffer[(String, String, String)] =
+ mutable.ArrayBuffer.empty[(String, String, String)]
+
+ def addMutableState(javaType: String, variableName: String, initialValue: String): Unit = {
+ mutableStates += ((javaType, variableName, initialValue))
+ }
+
val stringType: String = classOf[UTF8String].getName
val decimalType: String = classOf[Decimal].getName
@@ -203,7 +215,10 @@ class CodeGenContext {
def isPrimitiveType(dt: DataType): Boolean = isPrimitiveType(javaType(dt))
}
-
+/**
+ * A wrapper for the generated class, defining a `generate` method so that we can pass extra objects
+ * into the generated class.
+ */
abstract class GeneratedClass {
def generate(expressions: Array[Expression]): Any
}
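
As a hedged sketch of how an expression might use `addMutableState`: the hypothetical `RowCounter` below (name and semantics invented for illustration, based only on the signature shown above) registers a `long` member in the generated class and bumps it on every evaluation, which is also why such expressions should be marked non-deterministic.

{% highlight scala %}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.LeafExpression
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode}
import org.apache.spark.sql.types.{DataType, LongType}

// Hypothetical stateful expression (illustrative only).
case class RowCounter() extends LeafExpression {

  override def nullable: Boolean = false
  override def dataType: DataType = LongType
  override def deterministic: Boolean = false

  // Interpreted path: the state lives on the expression instance itself.
  private[this] var count = 0L
  override def eval(input: InternalRow): Any = { val c = count; count += 1; c }

  // Codegen path: the state becomes a member variable of the generated class.
  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
    val countTerm = ctx.freshName("count")
    ctx.addMutableState("long", countTerm, "0L")
    s"""
      final boolean ${ev.isNull} = false;
      final long ${ev.primitive} = $countTerm;
      $countTerm++;
    """
  }
}
{% endhighlight %}
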
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
index addb8023d9c0b..71e47d4f9b620 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
@@ -46,6 +46,9 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu
${ctx.setColumn("mutableRow", e.dataType, i, evaluationCode.primitive)};
"""
}.mkString("\n")
+ val mutableStates = ctx.mutableStates.map { case (javaType, variableName, initialValue) =>
+ s"private $javaType $variableName = $initialValue;"
+ }.mkString("\n ")
val code = s"""
public Object generate($exprType[] expr) {
return new SpecificProjection(expr);
@@ -55,6 +58,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu
private $exprType[] expressions = null;
private $mutableRowType mutableRow = null;
+ $mutableStates
public SpecificProjection($exprType[] expr) {
expressions = expr;
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
index d05dfc108e63a..856ff9f1f96f8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
@@ -46,30 +46,47 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR
protected def create(ordering: Seq[SortOrder]): Ordering[InternalRow] = {
val ctx = newCodeGenContext()
- val comparisons = ordering.zipWithIndex.map { case (order, i) =>
- val evalA = order.child.gen(ctx)
- val evalB = order.child.gen(ctx)
+ val comparisons = ordering.map { order =>
+ val eval = order.child.gen(ctx)
val asc = order.direction == Ascending
+ val isNullA = ctx.freshName("isNullA")
+ val primitiveA = ctx.freshName("primitiveA")
+ val isNullB = ctx.freshName("isNullB")
+ val primitiveB = ctx.freshName("primitiveB")
s"""
i = a;
- ${evalA.code}
+ boolean $isNullA;
+ ${ctx.javaType(order.child.dataType)} $primitiveA;
+ {
+ ${eval.code}
+ $isNullA = ${eval.isNull};
+ $primitiveA = ${eval.primitive};
+ }
i = b;
- ${evalB.code}
- if (${evalA.isNull} && ${evalB.isNull}) {
+ boolean $isNullB;
+ ${ctx.javaType(order.child.dataType)} $primitiveB;
+ {
+ ${eval.code}
+ $isNullB = ${eval.isNull};
+ $primitiveB = ${eval.primitive};
+ }
+ if ($isNullA && $isNullB) {
// Nothing
- } else if (${evalA.isNull}) {
+ } else if ($isNullA) {
return ${if (order.direction == Ascending) "-1" else "1"};
- } else if (${evalB.isNull}) {
+ } else if ($isNullB) {
return ${if (order.direction == Ascending) "1" else "-1"};
} else {
- int comp = ${ctx.genComp(order.child.dataType, evalA.primitive, evalB.primitive)};
+ int comp = ${ctx.genComp(order.child.dataType, primitiveA, primitiveB)};
if (comp != 0) {
return ${if (asc) "comp" else "-comp"};
}
}
"""
}.mkString("\n")
-
+ val mutableStates = ctx.mutableStates.map { case (javaType, variableName, initialValue) =>
+ s"private $javaType $variableName = $initialValue;"
+ }.mkString("\n ")
val code = s"""
public SpecificOrdering generate($exprType[] expr) {
return new SpecificOrdering(expr);
@@ -78,6 +95,7 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR
class SpecificOrdering extends ${classOf[BaseOrdering].getName} {
private $exprType[] expressions = null;
+ $mutableStates
public SpecificOrdering($exprType[] expr) {
expressions = expr;
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala
index 274a42cb69087..9e5a745d512e9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala
@@ -40,6 +40,9 @@ object GeneratePredicate extends CodeGenerator[Expression, (InternalRow) => Bool
protected def create(predicate: Expression): ((InternalRow) => Boolean) = {
val ctx = newCodeGenContext()
val eval = predicate.gen(ctx)
+ val mutableStates = ctx.mutableStates.map { case (javaType, variableName, initialValue) =>
+ s"private $javaType $variableName = $initialValue;"
+ }.mkString("\n ")
val code = s"""
public SpecificPredicate generate($exprType[] expr) {
return new SpecificPredicate(expr);
@@ -47,6 +50,7 @@ object GeneratePredicate extends CodeGenerator[Expression, (InternalRow) => Bool
class SpecificPredicate extends ${classOf[Predicate].getName} {
private final $exprType[] expressions;
+ $mutableStates
public SpecificPredicate($exprType[] expr) {
expressions = expr;
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
index 3c7ee9cc16599..3e5ca308dc31d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
@@ -151,6 +151,10 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
s"""if (!nullBits[$i]) arr[$i] = c$i;"""
}.mkString("\n ")
+ val mutableStates = ctx.mutableStates.map { case (javaType, variableName, initialValue) =>
+ s"private $javaType $variableName = $initialValue;"
+ }.mkString("\n ")
+
val code = s"""
public SpecificProjection generate($exprType[] expr) {
return new SpecificProjection(expr);
@@ -158,6 +162,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
class SpecificProjection extends ${classOf[BaseProject].getName} {
private $exprType[] expressions = null;
+ $mutableStates
public SpecificProjection($exprType[] expr) {
expressions = expr;
@@ -165,65 +170,65 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
@Override
public Object apply(Object r) {
- return new SpecificRow(expressions, (InternalRow) r);
+ return new SpecificRow((InternalRow) r);
}
- }
- final class SpecificRow extends ${classOf[MutableRow].getName} {
+ final class SpecificRow extends ${classOf[MutableRow].getName} {
- $columns
+ $columns
- public SpecificRow($exprType[] expressions, InternalRow i) {
- $initColumns
- }
+ public SpecificRow(InternalRow i) {
+ $initColumns
+ }
- public int length() { return ${expressions.length};}
- protected boolean[] nullBits = new boolean[${expressions.length}];
- public void setNullAt(int i) { nullBits[i] = true; }
- public boolean isNullAt(int i) { return nullBits[i]; }
+ public int length() { return ${expressions.length};}
+ protected boolean[] nullBits = new boolean[${expressions.length}];
+ public void setNullAt(int i) { nullBits[i] = true; }
+ public boolean isNullAt(int i) { return nullBits[i]; }
- public Object get(int i) {
- if (isNullAt(i)) return null;
- switch (i) {
- $getCases
+ public Object get(int i) {
+ if (isNullAt(i)) return null;
+ switch (i) {
+ $getCases
+ }
+ return null;
}
- return null;
- }
- public void update(int i, Object value) {
- if (value == null) {
- setNullAt(i);
- return;
+ public void update(int i, Object value) {
+ if (value == null) {
+ setNullAt(i);
+ return;
+ }
+ nullBits[i] = false;
+ switch (i) {
+ $updateCases
+ }
}
- nullBits[i] = false;
- switch (i) {
- $updateCases
+ $specificAccessorFunctions
+ $specificMutatorFunctions
+
+ @Override
+ public int hashCode() {
+ int result = 37;
+ $hashUpdates
+ return result;
}
- }
- $specificAccessorFunctions
- $specificMutatorFunctions
-
- @Override
- public int hashCode() {
- int result = 37;
- $hashUpdates
- return result;
- }
- @Override
- public boolean equals(Object other) {
- if (other instanceof SpecificRow) {
- SpecificRow row = (SpecificRow) other;
- $columnChecks
- return true;
+ @Override
+ public boolean equals(Object other) {
+ if (other instanceof SpecificRow) {
+ SpecificRow row = (SpecificRow) other;
+ $columnChecks
+ return true;
+ }
+ return super.equals(other);
}
- return super.equals(other);
- }
- @Override
- public InternalRow copy() {
- Object[] arr = new Object[${expressions.length}];
- ${copyColumns}
- return new ${classOf[GenericInternalRow].getName}(arr);
+ @Override
+ public InternalRow copy() {
+ Object[] arr = new Object[${expressions.length}];
+ ${copyColumns}
+ return new ${classOf[GenericInternalRow].getName}(arr);
+ }
}
}
"""
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala
index c7f039ede26b3..9162b73fe56eb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala
@@ -35,8 +35,8 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
TypeCheckResult.TypeCheckFailure(
s"type of predicate expression in If should be boolean, not ${predicate.dataType}")
} else if (trueValue.dataType != falseValue.dataType) {
- TypeCheckResult.TypeCheckFailure(
- s"differing types in If (${trueValue.dataType} and ${falseValue.dataType}).")
+ TypeCheckResult.TypeCheckFailure(s"differing types in '$prettyString' " +
+ s"(${trueValue.dataType.simpleString} and ${falseValue.dataType.simpleString}).")
} else {
TypeCheckResult.TypeCheckSuccess
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
index c31890e27fb54..a7ad452ef4943 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
@@ -19,8 +19,10 @@ package org.apache.spark.sql.catalyst.expressions
import java.{lang => jl}
-import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckSuccess, TypeCheckFailure}
import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@@ -55,7 +57,7 @@ abstract class LeafMathExpression(c: Double, name: String)
* @param name The short name of the function
*/
abstract class UnaryMathExpression(f: Double => Double, name: String)
- extends UnaryExpression with Serializable with ExpectsInputTypes { self: Product =>
+ extends UnaryExpression with Serializable with ImplicitCastInputTypes { self: Product =>
override def inputTypes: Seq[DataType] = Seq(DoubleType)
override def dataType: DataType = DoubleType
@@ -89,7 +91,7 @@ abstract class UnaryMathExpression(f: Double => Double, name: String)
* @param name The short name of the function
*/
abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String)
- extends BinaryExpression with Serializable with ExpectsInputTypes { self: Product =>
+ extends BinaryExpression with Serializable with ImplicitCastInputTypes { self: Product =>
override def inputTypes: Seq[DataType] = Seq(DoubleType, DoubleType)
@@ -174,7 +176,7 @@ object Factorial {
)
}
-case class Factorial(child: Expression) extends UnaryExpression with ExpectsInputTypes {
+case class Factorial(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
override def inputTypes: Seq[DataType] = Seq(IntegerType)
@@ -251,7 +253,7 @@ case class ToRadians(child: Expression) extends UnaryMathExpression(math.toRadia
}
case class Bin(child: Expression)
- extends UnaryExpression with Serializable with ExpectsInputTypes {
+ extends UnaryExpression with Serializable with ImplicitCastInputTypes {
override def inputTypes: Seq[DataType] = Seq(LongType)
override def dataType: DataType = StringType
@@ -285,7 +287,7 @@ object Hex {
* Otherwise if the number is a STRING, it converts each character into its hex representation
* and returns the resulting STRING. Negative numbers would be treated as two's complement.
*/
-case class Hex(child: Expression) extends UnaryExpression with ExpectsInputTypes {
+case class Hex(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
// TODO: Create code-gen version.
override def inputTypes: Seq[AbstractDataType] =
@@ -329,7 +331,7 @@ case class Hex(child: Expression) extends UnaryExpression with ExpectsInputTypes
* Performs the inverse operation of HEX.
* Resulting characters are returned as a byte array.
*/
-case class Unhex(child: Expression) extends UnaryExpression with ExpectsInputTypes {
+case class Unhex(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
// TODO: Create code-gen version.
override def inputTypes: Seq[AbstractDataType] = Seq(StringType)
@@ -416,7 +418,7 @@ case class Pow(left: Expression, right: Expression)
* @param right number of bits to left shift.
*/
case class ShiftLeft(left: Expression, right: Expression)
- extends BinaryExpression with ExpectsInputTypes {
+ extends BinaryExpression with ImplicitCastInputTypes {
override def inputTypes: Seq[AbstractDataType] =
Seq(TypeCollection(IntegerType, LongType), IntegerType)
@@ -442,7 +444,7 @@ case class ShiftLeft(left: Expression, right: Expression)
* @param right number of bits to left shift.
*/
case class ShiftRight(left: Expression, right: Expression)
- extends BinaryExpression with ExpectsInputTypes {
+ extends BinaryExpression with ImplicitCastInputTypes {
override def inputTypes: Seq[AbstractDataType] =
Seq(TypeCollection(IntegerType, LongType), IntegerType)
@@ -468,7 +470,7 @@ case class ShiftRight(left: Expression, right: Expression)
* @param right the number of bits to right shift.
*/
case class ShiftRightUnsigned(left: Expression, right: Expression)
- extends BinaryExpression with ExpectsInputTypes {
+ extends BinaryExpression with ImplicitCastInputTypes {
override def inputTypes: Seq[AbstractDataType] =
Seq(TypeCollection(IntegerType, LongType), IntegerType)
@@ -520,3 +522,202 @@ case class Logarithm(left: Expression, right: Expression)
"""
}
}
+
+/**
+ * Rounds the `child`'s result to `scale` decimal places when `scale` >= 0,
+ * or rounds the integral part when `scale` < 0.
+ * For example, round(31.415, 2) evaluates to 31.42 and round(31.415, -1) evaluates to 30.
+ *
+ * A child of IntegralType evaluates to itself when `scale` >= 0.
+ * A child of FractionalType whose value is NaN or Infinite always evaluates to itself.
+ *
+ * Round's dataType always equals the `child`'s dataType, except for [[DecimalType.Fixed]],
+ * which leads to a scale update in the DecimalType's [[PrecisionInfo]].
+ *
+ * @param child the expression to be rounded; any [[NumericType]] is allowed as input
+ * @param scale the new scale to round to; this should be a constant int at runtime
+ */
+case class Round(child: Expression, scale: Expression)
+ extends BinaryExpression with ExpectsInputTypes {
+
+ import BigDecimal.RoundingMode.HALF_UP
+
+ def this(child: Expression) = this(child, Literal(0))
+
+ override def left: Expression = child
+ override def right: Expression = scale
+
+  // rounding a Decimal evaluates to null if `changePrecision` fails
+ override def nullable: Boolean = true
+
+ override def foldable: Boolean = child.foldable
+
+ override lazy val dataType: DataType = child.dataType match {
+ // if the new scale is bigger which means we are scaling up,
+ // keep the original scale as `Decimal` does
+ case DecimalType.Fixed(p, s) => DecimalType(p, if (_scale > s) s else _scale)
+ case t => t
+ }
+
+ override def inputTypes: Seq[AbstractDataType] = Seq(NumericType, IntegerType)
+
+ override def checkInputDataTypes(): TypeCheckResult = {
+ super.checkInputDataTypes() match {
+ case TypeCheckSuccess =>
+ if (scale.foldable) {
+ TypeCheckSuccess
+ } else {
+ TypeCheckFailure("Only foldable Expression is allowed for scale arguments")
+ }
+ case f => f
+ }
+ }
+
+  // Avoid repeated evaluation since `scale` is a constant int, and avoid unnecessary `child`
+  // evaluation in both the codegen and non-codegen paths by also checking whether scaleV == null.
+ private lazy val scaleV: Any = scale.eval(EmptyRow)
+ private lazy val _scale: Int = scaleV.asInstanceOf[Int]
+
+ override def eval(input: InternalRow): Any = {
+ if (scaleV == null) { // if scale is null, no need to eval its child at all
+ null
+ } else {
+ val evalE = child.eval(input)
+ if (evalE == null) {
+ null
+ } else {
+ nullSafeEval(evalE)
+ }
+ }
+ }
+
+  // Not overriding BinaryExpression.nullSafeEval, since _scale is a constant int at runtime.
+ def nullSafeEval(input1: Any): Any = {
+ child.dataType match {
+ case _: DecimalType =>
+ val decimal = input1.asInstanceOf[Decimal]
+ if (decimal.changePrecision(decimal.precision, _scale)) decimal else null
+ case ByteType =>
+ BigDecimal(input1.asInstanceOf[Byte]).setScale(_scale, HALF_UP).toByte
+ case ShortType =>
+ BigDecimal(input1.asInstanceOf[Short]).setScale(_scale, HALF_UP).toShort
+ case IntegerType =>
+ BigDecimal(input1.asInstanceOf[Int]).setScale(_scale, HALF_UP).toInt
+ case LongType =>
+ BigDecimal(input1.asInstanceOf[Long]).setScale(_scale, HALF_UP).toLong
+ case FloatType =>
+ val f = input1.asInstanceOf[Float]
+ if (f.isNaN || f.isInfinite) {
+ f
+ } else {
+ BigDecimal(f).setScale(_scale, HALF_UP).toFloat
+ }
+ case DoubleType =>
+ val d = input1.asInstanceOf[Double]
+ if (d.isNaN || d.isInfinite) {
+ d
+ } else {
+ BigDecimal(d).setScale(_scale, HALF_UP).toDouble
+ }
+ }
+ }
+
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+ val ce = child.gen(ctx)
+
+ val evaluationCode = child.dataType match {
+ case _: DecimalType =>
+ s"""
+ if (${ce.primitive}.changePrecision(${ce.primitive}.precision(), ${_scale})) {
+ ${ev.primitive} = ${ce.primitive};
+ } else {
+ ${ev.isNull} = true;
+ }"""
+ case ByteType =>
+ if (_scale < 0) {
+ s"""
+ ${ev.primitive} = new java.math.BigDecimal(${ce.primitive}).
+ setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).byteValue();"""
+ } else {
+ s"${ev.primitive} = ${ce.primitive};"
+ }
+ case ShortType =>
+ if (_scale < 0) {
+ s"""
+ ${ev.primitive} = new java.math.BigDecimal(${ce.primitive}).
+ setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).shortValue();"""
+ } else {
+ s"${ev.primitive} = ${ce.primitive};"
+ }
+ case IntegerType =>
+ if (_scale < 0) {
+ s"""
+ ${ev.primitive} = new java.math.BigDecimal(${ce.primitive}).
+ setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).intValue();"""
+ } else {
+ s"${ev.primitive} = ${ce.primitive};"
+ }
+ case LongType =>
+ if (_scale < 0) {
+ s"""
+ ${ev.primitive} = new java.math.BigDecimal(${ce.primitive}).
+ setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).longValue();"""
+ } else {
+ s"${ev.primitive} = ${ce.primitive};"
+ }
+ case FloatType => // if child eval to NaN or Infinity, just return it.
+ if (_scale == 0) {
+ s"""
+ if (Float.isNaN(${ce.primitive}) || Float.isInfinite(${ce.primitive})){
+ ${ev.primitive} = ${ce.primitive};
+ } else {
+ ${ev.primitive} = Math.round(${ce.primitive});
+ }"""
+ } else {
+ s"""
+ if (Float.isNaN(${ce.primitive}) || Float.isInfinite(${ce.primitive})){
+ ${ev.primitive} = ${ce.primitive};
+ } else {
+ ${ev.primitive} = java.math.BigDecimal.valueOf(${ce.primitive}).
+ setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).floatValue();
+ }"""
+ }
+ case DoubleType => // if child eval to NaN or Infinity, just return it.
+ if (_scale == 0) {
+ s"""
+ if (Double.isNaN(${ce.primitive}) || Double.isInfinite(${ce.primitive})){
+ ${ev.primitive} = ${ce.primitive};
+ } else {
+ ${ev.primitive} = Math.round(${ce.primitive});
+ }"""
+ } else {
+ s"""
+ if (Double.isNaN(${ce.primitive}) || Double.isInfinite(${ce.primitive})){
+ ${ev.primitive} = ${ce.primitive};
+ } else {
+ ${ev.primitive} = java.math.BigDecimal.valueOf(${ce.primitive}).
+ setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).doubleValue();
+ }"""
+ }
+ }
+
+ if (scaleV == null) { // if scale is null, no need to eval its child at all
+ s"""
+ boolean ${ev.isNull} = true;
+ ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
+ """
+ } else {
+ s"""
+ ${ce.code}
+ boolean ${ev.isNull} = ${ce.isNull};
+ ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
+ if (!${ev.isNull}) {
+ $evaluationCode
+ }
+ """
+ }
+ }
+
+ override def prettyName: String = "round"
+}
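
For reference, here is a standalone Scala sketch of the HALF_UP semantics that `nullSafeEval` implements above; `RoundSketch`, `roundDouble` and `roundLong` are hypothetical names introduced only for this illustration, not part of the patch.

{% highlight scala %}
import scala.math.BigDecimal.RoundingMode.HALF_UP

object RoundSketch {
  // Hypothetical helpers mirroring the non-codegen path of Round.nullSafeEval.
  def roundDouble(d: Double, scale: Int): Double =
    if (d.isNaN || d.isInfinite) d // NaN and Infinity pass through unchanged
    else BigDecimal(d).setScale(scale, HALF_UP).toDouble

  def roundLong(l: Long, scale: Int): Long =
    BigDecimal(l).setScale(scale, HALF_UP).toLong // a no-op for scale >= 0

  def main(args: Array[String]): Unit = {
    println(roundDouble(math.Pi, 2))           // 3.14
    println(roundDouble(Double.NaN, 2))        // NaN
    println(roundLong(31415926535897932L, -5)) // 31415926535900000
  }
}
{% endhighlight %}

Integral inputs only change for negative scales, which is why the generated code above simply copies the child value through when `_scale` is not negative.
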
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
index 3b59cd431b871..a269ec4a1e6dc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
@@ -31,7 +31,7 @@ import org.apache.spark.unsafe.types.UTF8String
* A function that calculates an MD5 128-bit checksum and returns it as a hex string
* For input of type [[BinaryType]]
*/
-case class Md5(child: Expression) extends UnaryExpression with ExpectsInputTypes {
+case class Md5(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
override def dataType: DataType = StringType
@@ -55,7 +55,7 @@ case class Md5(child: Expression) extends UnaryExpression with ExpectsInputTypes
* the hash length is not one of the permitted values, the return value is NULL.
*/
case class Sha2(left: Expression, right: Expression)
- extends BinaryExpression with Serializable with ExpectsInputTypes {
+ extends BinaryExpression with Serializable with ImplicitCastInputTypes {
override def dataType: DataType = StringType
@@ -118,7 +118,7 @@ case class Sha2(left: Expression, right: Expression)
* A function that calculates a sha1 hash value and returns it as a hex string
* For input of type [[BinaryType]] or [[StringType]]
*/
-case class Sha1(child: Expression) extends UnaryExpression with ExpectsInputTypes {
+case class Sha1(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
override def dataType: DataType = StringType
@@ -138,7 +138,7 @@ case class Sha1(child: Expression) extends UnaryExpression with ExpectsInputType
* A function that computes a cyclic redundancy check value and returns it as a bigint
* For input of type [[BinaryType]]
*/
-case class Crc32(child: Expression) extends UnaryExpression with ExpectsInputTypes {
+case class Crc32(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
override def dataType: DataType = LongType
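
The only change in this file is swapping `ExpectsInputTypes` for `ImplicitCastInputTypes`, which is meant to let the analyzer insert implicit casts on the inputs (so that, for example, a string argument can be cast to the expected BinaryType) rather than failing the type check. As a standalone, hedged illustration of the checksum these expressions compute, plain MD5 over the binary input (`ChecksumSketch` and `md5Hex` are made-up names):

{% highlight scala %}
import java.security.MessageDigest

object ChecksumSketch {
  // MD5 over raw bytes, rendered as a lowercase hex string.
  def md5Hex(input: Array[Byte]): String =
    MessageDigest.getInstance("MD5").digest(input).map("%02x".format(_)).mkString

  def main(args: Array[String]): Unit = {
    // A string argument is assumed to arrive as its UTF-8 bytes after the implicit cast.
    println(md5Hex("ABC".getBytes("UTF-8"))) // 902fbdd2b1df0c4f70b4a5d23525e932
  }
}
{% endhighlight %}
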
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index f74fd04619714..aa6c30e2f79f2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -33,12 +33,17 @@ object InterpretedPredicate {
}
}
+
+/**
+ * An [[Expression]] that returns a boolean value.
+ */
trait Predicate extends Expression {
self: Product =>
override def dataType: DataType = BooleanType
}
+
trait PredicateHelper {
protected def splitConjunctivePredicates(condition: Expression): Seq[Expression] = {
condition match {
@@ -70,7 +75,10 @@ trait PredicateHelper {
expr.references.subsetOf(plan.outputSet)
}
-case class Not(child: Expression) extends UnaryExpression with Predicate with ExpectsInputTypes {
+
+case class Not(child: Expression)
+ extends UnaryExpression with Predicate with ImplicitCastInputTypes {
+
override def toString: String = s"NOT $child"
override def inputTypes: Seq[DataType] = Seq(BooleanType)
@@ -82,6 +90,7 @@ case class Not(child: Expression) extends UnaryExpression with Predicate with Ex
}
}
+
/**
* Evaluates to `true` if `list` contains `value`.
*/
@@ -97,6 +106,7 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate {
}
}
+
/**
* Optimized version of In clause, when all filter values of In clause are
* static.
@@ -112,12 +122,12 @@ case class InSet(child: Expression, hset: Set[Any])
}
}
-case class And(left: Expression, right: Expression)
- extends BinaryExpression with Predicate with ExpectsInputTypes {
- override def toString: String = s"($left && $right)"
+case class And(left: Expression, right: Expression) extends BinaryOperator with Predicate {
+
+ override def inputType: AbstractDataType = BooleanType
- override def inputTypes: Seq[DataType] = Seq(BooleanType, BooleanType)
+ override def symbol: String = "&&"
override def eval(input: InternalRow): Any = {
val input1 = left.eval(input)
@@ -161,12 +171,12 @@ case class And(left: Expression, right: Expression)
}
}
-case class Or(left: Expression, right: Expression)
- extends BinaryExpression with Predicate with ExpectsInputTypes {
- override def toString: String = s"($left || $right)"
+case class Or(left: Expression, right: Expression) extends BinaryOperator with Predicate {
- override def inputTypes: Seq[DataType] = Seq(BooleanType, BooleanType)
+ override def inputType: AbstractDataType = BooleanType
+
+ override def symbol: String = "||"
override def eval(input: InternalRow): Any = {
val input1 = left.eval(input)
@@ -210,21 +220,10 @@ case class Or(left: Expression, right: Expression)
}
}
+
abstract class BinaryComparison extends BinaryOperator with Predicate {
self: Product =>
- override def checkInputDataTypes(): TypeCheckResult = {
- if (left.dataType != right.dataType) {
- TypeCheckResult.TypeCheckFailure(
- s"differing types in ${this.getClass.getSimpleName} " +
- s"(${left.dataType} and ${right.dataType}).")
- } else {
- checkTypesInternal(dataType)
- }
- }
-
- protected def checkTypesInternal(t: DataType): TypeCheckResult
-
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
if (ctx.isPrimitiveType(left.dataType)) {
// faster version
@@ -235,10 +234,12 @@ abstract class BinaryComparison extends BinaryOperator with Predicate {
}
}
+
private[sql] object BinaryComparison {
def unapply(e: BinaryComparison): Option[(Expression, Expression)] = Some((e.left, e.right))
}
+
/** An extractor that matches both standard 3VL equality and null-safe equality. */
private[sql] object Equality {
def unapply(e: BinaryComparison): Option[(Expression, Expression)] = e match {
@@ -248,10 +249,12 @@ private[sql] object Equality {
}
}
+
case class EqualTo(left: Expression, right: Expression) extends BinaryComparison {
- override def symbol: String = "="
- override protected def checkTypesInternal(t: DataType) = TypeCheckResult.TypeCheckSuccess
+ override def inputType: AbstractDataType = AnyDataType
+
+ override def symbol: String = "="
protected override def nullSafeEval(input1: Any, input2: Any): Any = {
if (left.dataType != BinaryType) input1 == input2
@@ -263,13 +266,15 @@ case class EqualTo(left: Expression, right: Expression) extends BinaryComparison
}
}
+
case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComparison {
+
+ override def inputType: AbstractDataType = AnyDataType
+
override def symbol: String = "<=>"
override def nullable: Boolean = false
- override protected def checkTypesInternal(t: DataType) = TypeCheckResult.TypeCheckSuccess
-
override def eval(input: InternalRow): Any = {
val input1 = left.eval(input)
val input2 = right.eval(input)
@@ -298,44 +303,48 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp
}
}
+
case class LessThan(left: Expression, right: Expression) extends BinaryComparison {
- override def symbol: String = "<"
- override protected def checkTypesInternal(t: DataType) =
- TypeUtils.checkForOrderingExpr(left.dataType, "operator " + symbol)
+ override def inputType: AbstractDataType = TypeCollection.Ordered
+
+ override def symbol: String = "<"
private lazy val ordering = TypeUtils.getOrdering(left.dataType)
protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.lt(input1, input2)
}
+
case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryComparison {
- override def symbol: String = "<="
- override protected def checkTypesInternal(t: DataType) =
- TypeUtils.checkForOrderingExpr(left.dataType, "operator " + symbol)
+ override def inputType: AbstractDataType = TypeCollection.Ordered
+
+ override def symbol: String = "<="
private lazy val ordering = TypeUtils.getOrdering(left.dataType)
protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.lteq(input1, input2)
}
+
case class GreaterThan(left: Expression, right: Expression) extends BinaryComparison {
- override def symbol: String = ">"
- override protected def checkTypesInternal(t: DataType) =
- TypeUtils.checkForOrderingExpr(left.dataType, "operator " + symbol)
+ override def inputType: AbstractDataType = TypeCollection.Ordered
+
+ override def symbol: String = ">"
private lazy val ordering = TypeUtils.getOrdering(left.dataType)
protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.gt(input1, input2)
}
+
case class GreaterThanOrEqual(left: Expression, right: Expression) extends BinaryComparison {
- override def symbol: String = ">="
- override protected def checkTypesInternal(t: DataType) =
- TypeUtils.checkForOrderingExpr(left.dataType, "operator " + symbol)
+ override def inputType: AbstractDataType = TypeCollection.Ordered
+
+ override def symbol: String = ">="
private lazy val ordering = TypeUtils.getOrdering(left.dataType)
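
The evaluation of `And`/`Or` is untouched here; only the type checking moves to `BinaryOperator` with `inputType = BooleanType`. For readers less familiar with the null handling these operators keep, here is a minimal standalone model of SQL three-valued logic, with `Option[Boolean]` standing in for a nullable boolean; the object and method names are invented for this sketch.

{% highlight scala %}
object ThreeValuedLogic {
  // None stands for a NULL boolean value.
  def and(l: Option[Boolean], r: Option[Boolean]): Option[Boolean] = (l, r) match {
    case (Some(false), _) | (_, Some(false)) => Some(false) // false AND anything = false
    case (Some(true), Some(true))            => Some(true)
    case _                                   => None        // true AND NULL, NULL AND NULL = NULL
  }

  def or(l: Option[Boolean], r: Option[Boolean]): Option[Boolean] = (l, r) match {
    case (Some(true), _) | (_, Some(true)) => Some(true)    // true OR anything = true
    case (Some(false), Some(false))        => Some(false)
    case _                                 => None          // false OR NULL, NULL OR NULL = NULL
  }
}
{% endhighlight %}
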
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala
index 6cdc3000382e2..e10ba55396664 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.TaskContext
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode}
import org.apache.spark.sql.types.{DataType, DoubleType}
import org.apache.spark.util.Utils
import org.apache.spark.util.random.XORShiftRandom
@@ -38,11 +39,7 @@ abstract class RDG(seed: Long) extends LeafExpression with Serializable {
* Record ID within each partition. By being transient, the Random Number Generator is
* reset every time we serialize and deserialize it.
*/
- @transient protected lazy val partitionId = TaskContext.get() match {
- case null => 0
- case _ => TaskContext.get().partitionId()
- }
- @transient protected lazy val rng = new XORShiftRandom(seed + partitionId)
+ @transient protected lazy val rng = new XORShiftRandom(seed + TaskContext.getPartitionId)
override def deterministic: Boolean = false
@@ -61,6 +58,17 @@ case class Rand(seed: Long) extends RDG(seed) {
case IntegerLiteral(s) => s
case _ => throw new AnalysisException("Input argument to rand must be an integer literal.")
})
+
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+ val rngTerm = ctx.freshName("rng")
+ val className = classOf[XORShiftRandom].getCanonicalName
+ ctx.addMutableState(className, rngTerm,
+ s"new $className($seed + org.apache.spark.TaskContext.getPartitionId())")
+ ev.isNull = "false"
+ s"""
+ final ${ctx.javaType(dataType)} ${ev.primitive} = $rngTerm.nextDouble();
+ """
+ }
}
/** Generate a random column with i.i.d. gaussian random distribution. */
@@ -73,4 +81,15 @@ case class Randn(seed: Long) extends RDG(seed) {
case IntegerLiteral(s) => s
case _ => throw new AnalysisException("Input argument to rand must be an integer literal.")
})
+
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+ val rngTerm = ctx.freshName("rng")
+ val className = classOf[XORShiftRandom].getCanonicalName
+ ctx.addMutableState(className, rngTerm,
+ s"new $className($seed + org.apache.spark.TaskContext.getPartitionId())")
+ ev.isNull = "false"
+ s"""
+ final ${ctx.javaType(dataType)} ${ev.primitive} = $rngTerm.nextGaussian();
+ """
+ }
}
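
The generated code seeds one generator per partition with `seed + partitionId`, mirroring the interpreted path above. A standalone sketch of that seeding scheme, using `java.util.Random` in place of `XORShiftRandom` (all names here are hypothetical):

{% highlight scala %}
import java.util.Random

object PartitionSeededRand {
  // One generator per partition, seeded with (user seed + partition id).
  def column(seed: Long, partitionId: Int, numRows: Int): Seq[Double] = {
    val rng = new Random(seed + partitionId) // stands in for XORShiftRandom in this sketch
    Seq.fill(numRows)(rng.nextDouble())
  }

  def main(args: Array[String]): Unit = {
    // Deterministic per (seed, partition): the two calls print identical sequences.
    println(column(seed = 42L, partitionId = 3, numRows = 3))
    println(column(seed = 42L, partitionId = 3, numRows = 3))
  }
}
{% endhighlight %}
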
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index f64899c1ed84c..c64afe7b3f19a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -17,11 +17,10 @@
package org.apache.spark.sql.catalyst.expressions
+import java.text.DecimalFormat
import java.util.Locale
import java.util.regex.Pattern
-import org.apache.commons.lang3.StringUtils
-
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.UnresolvedException
import org.apache.spark.sql.catalyst.expressions.codegen._
@@ -29,7 +28,7 @@ import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
-trait StringRegexExpression extends ExpectsInputTypes {
+trait StringRegexExpression extends ImplicitCastInputTypes {
self: BinaryExpression =>
def escape(v: String): String
@@ -105,7 +104,7 @@ case class RLike(left: Expression, right: Expression)
override def toString: String = s"$left RLIKE $right"
}
-trait String2StringExpression extends ExpectsInputTypes {
+trait String2StringExpression extends ImplicitCastInputTypes {
self: UnaryExpression =>
def convert(v: UTF8String): UTF8String
@@ -142,7 +141,7 @@ case class Lower(child: Expression) extends UnaryExpression with String2StringEx
}
/** A base trait for functions that compare two strings, returning a boolean. */
-trait StringComparison extends ExpectsInputTypes {
+trait StringComparison extends ImplicitCastInputTypes {
self: BinaryExpression =>
def compare(l: UTF8String, r: UTF8String): Boolean
@@ -241,7 +240,7 @@ case class StringTrimRight(child: Expression)
* NOTE: that this is not zero based, but 1-based index. The first character in str has index 1.
*/
case class StringInstr(str: Expression, substr: Expression)
- extends BinaryExpression with ExpectsInputTypes {
+ extends BinaryExpression with ImplicitCastInputTypes {
override def left: Expression = str
override def right: Expression = substr
@@ -265,7 +264,7 @@ case class StringInstr(str: Expression, substr: Expression)
* in given string after position pos.
*/
case class StringLocate(substr: Expression, str: Expression, start: Expression)
- extends Expression with ExpectsInputTypes {
+ extends Expression with ImplicitCastInputTypes {
def this(substr: Expression, str: Expression) = {
this(substr, str, Literal(0))
@@ -306,7 +305,7 @@ case class StringLocate(substr: Expression, str: Expression, start: Expression)
* Returns str, left-padded with pad to a length of len.
*/
case class StringLPad(str: Expression, len: Expression, pad: Expression)
- extends Expression with ExpectsInputTypes {
+ extends Expression with ImplicitCastInputTypes {
override def children: Seq[Expression] = str :: len :: pad :: Nil
override def foldable: Boolean = children.forall(_.foldable)
@@ -344,7 +343,7 @@ case class StringLPad(str: Expression, len: Expression, pad: Expression)
* Returns str, right-padded with pad to a length of len.
*/
case class StringRPad(str: Expression, len: Expression, pad: Expression)
- extends Expression with ExpectsInputTypes {
+ extends Expression with ImplicitCastInputTypes {
override def children: Seq[Expression] = str :: len :: pad :: Nil
override def foldable: Boolean = children.forall(_.foldable)
@@ -413,7 +412,7 @@ case class StringFormat(children: Expression*) extends Expression {
* Returns the string which repeat the given string value n times.
*/
case class StringRepeat(str: Expression, times: Expression)
- extends BinaryExpression with ExpectsInputTypes {
+ extends BinaryExpression with ImplicitCastInputTypes {
override def left: Expression = str
override def right: Expression = times
@@ -447,7 +446,7 @@ case class StringReverse(child: Expression) extends UnaryExpression with String2
/**
* Returns a n spaces string.
*/
-case class StringSpace(child: Expression) extends UnaryExpression with ExpectsInputTypes {
+case class StringSpace(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
override def dataType: DataType = StringType
override def inputTypes: Seq[DataType] = Seq(IntegerType)
@@ -467,7 +466,7 @@ case class StringSpace(child: Expression) extends UnaryExpression with ExpectsIn
* Splits str around pat (pattern is a regular expression).
*/
case class StringSplit(str: Expression, pattern: Expression)
- extends BinaryExpression with ExpectsInputTypes {
+ extends BinaryExpression with ImplicitCastInputTypes {
override def left: Expression = str
override def right: Expression = pattern
@@ -488,7 +487,7 @@ case class StringSplit(str: Expression, pattern: Expression)
* Defined for String and Binary types.
*/
case class Substring(str: Expression, pos: Expression, len: Expression)
- extends Expression with ExpectsInputTypes {
+ extends Expression with ImplicitCastInputTypes {
def this(str: Expression, pos: Expression) = {
this(str, pos, Literal(Integer.MAX_VALUE))
@@ -553,17 +552,22 @@ case class Substring(str: Expression, pos: Expression, len: Expression)
}
/**
- * A function that return the length of the given string expression.
+ * A function that returns the length of the given string or binary expression.
*/
-case class StringLength(child: Expression) extends UnaryExpression with ExpectsInputTypes {
+case class Length(child: Expression) extends UnaryExpression with ExpectsInputTypes {
override def dataType: DataType = IntegerType
- override def inputTypes: Seq[DataType] = Seq(StringType)
+ override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, BinaryType))
- protected override def nullSafeEval(string: Any): Any =
- string.asInstanceOf[UTF8String].numChars
+ protected override def nullSafeEval(value: Any): Any = child.dataType match {
+ case StringType => value.asInstanceOf[UTF8String].numChars
+ case BinaryType => value.asInstanceOf[Array[Byte]].length
+ }
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
- defineCodeGen(ctx, ev, c => s"($c).numChars()")
+ child.dataType match {
+ case StringType => defineCodeGen(ctx, ev, c => s"($c).numChars()")
+ case BinaryType => defineCodeGen(ctx, ev, c => s"($c).length")
+ }
}
override def prettyName: String = "length"
@@ -573,7 +577,7 @@ case class StringLength(child: Expression) extends UnaryExpression with ExpectsI
* A function that return the Levenshtein distance between the two given strings.
*/
case class Levenshtein(left: Expression, right: Expression) extends BinaryExpression
- with ExpectsInputTypes {
+ with ImplicitCastInputTypes {
override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType)
@@ -591,7 +595,7 @@ case class Levenshtein(left: Expression, right: Expression) extends BinaryExpres
/**
* Returns the numeric value of the first character of str.
*/
-case class Ascii(child: Expression) extends UnaryExpression with ExpectsInputTypes {
+case class Ascii(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
override def dataType: DataType = IntegerType
override def inputTypes: Seq[DataType] = Seq(StringType)
@@ -608,7 +612,7 @@ case class Ascii(child: Expression) extends UnaryExpression with ExpectsInputTyp
/**
* Converts the argument from binary to a base 64 string.
*/
-case class Base64(child: Expression) extends UnaryExpression with ExpectsInputTypes {
+case class Base64(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
override def dataType: DataType = StringType
override def inputTypes: Seq[DataType] = Seq(BinaryType)
@@ -622,7 +626,7 @@ case class Base64(child: Expression) extends UnaryExpression with ExpectsInputTy
/**
* Converts the argument from a base 64 string to BINARY.
*/
-case class UnBase64(child: Expression) extends UnaryExpression with ExpectsInputTypes {
+case class UnBase64(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
override def dataType: DataType = BinaryType
override def inputTypes: Seq[DataType] = Seq(StringType)
@@ -636,7 +640,7 @@ case class UnBase64(child: Expression) extends UnaryExpression with ExpectsInput
* If either argument is null, the result will also be null.
*/
case class Decode(bin: Expression, charset: Expression)
- extends BinaryExpression with ExpectsInputTypes {
+ extends BinaryExpression with ImplicitCastInputTypes {
override def left: Expression = bin
override def right: Expression = charset
@@ -655,7 +659,7 @@ case class Decode(bin: Expression, charset: Expression)
* If either argument is null, the result will also be null.
*/
case class Encode(value: Expression, charset: Expression)
- extends BinaryExpression with ExpectsInputTypes {
+ extends BinaryExpression with ImplicitCastInputTypes {
override def left: Expression = value
override def right: Expression = charset
@@ -668,3 +672,77 @@ case class Encode(value: Expression, charset: Expression)
}
}
+/**
+ * Formats the number X to a format like '#,###,###.##', rounded to D decimal places,
+ * and returns the result as a string. If D is 0, the result has no decimal point or
+ * fractional part.
+ */
+case class FormatNumber(x: Expression, d: Expression)
+ extends BinaryExpression with ExpectsInputTypes {
+
+ override def left: Expression = x
+ override def right: Expression = d
+ override def dataType: DataType = StringType
+ override def inputTypes: Seq[AbstractDataType] = Seq(NumericType, IntegerType)
+
+ // The last d value the cached pattern was built for; we rebuild the
+ // pattern (DecimalFormat) whenever an incoming d value differs from the last one.
+ @transient
+ private var lastDValue: Int = -100
+
+ // A cached DecimalFormat; for performance we rebuild it
+ // only when the d value changes.
+ @transient
+ private val pattern: StringBuffer = new StringBuffer()
+
+ @transient
+ private val numberFormat: DecimalFormat = new DecimalFormat("")
+
+ override def eval(input: InternalRow): Any = {
+ val xObject = x.eval(input)
+ if (xObject == null) {
+ return null
+ }
+
+ val dObject = d.eval(input)
+
+ if (dObject == null || dObject.asInstanceOf[Int] < 0) {
+ return null
+ }
+ val dValue = dObject.asInstanceOf[Int]
+
+ if (dValue != lastDValue) {
+ // construct a new format pattern only when a new dValue arrives
+ pattern.delete(0, pattern.length())
+ pattern.append("#,###,###,###,###,###,##0")
+
+ // decimal place
+ if (dValue > 0) {
+ pattern.append(".")
+
+ var i = 0
+ while (i < dValue) {
+ i += 1
+ pattern.append("0")
+ }
+ }
+ val dFormat = new DecimalFormat(pattern.toString())
+ lastDValue = dValue
+ numberFormat.applyPattern(dFormat.toPattern())
+ }
+
+ x.dataType match {
+ case ByteType => UTF8String.fromString(numberFormat.format(xObject.asInstanceOf[Byte]))
+ case ShortType => UTF8String.fromString(numberFormat.format(xObject.asInstanceOf[Short]))
+ case FloatType => UTF8String.fromString(numberFormat.format(xObject.asInstanceOf[Float]))
+ case IntegerType => UTF8String.fromString(numberFormat.format(xObject.asInstanceOf[Int]))
+ case LongType => UTF8String.fromString(numberFormat.format(xObject.asInstanceOf[Long]))
+ case DoubleType => UTF8String.fromString(numberFormat.format(xObject.asInstanceOf[Double]))
+ case _: DecimalType =>
+ UTF8String.fromString(numberFormat.format(xObject.asInstanceOf[Decimal].toJavaBigDecimal))
+ }
+ }
+
+ override def prettyName: String = "format_number"
+}
+
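The pattern `FormatNumber` builds is a grouped integer part followed by `d` zero placeholders after the decimal point. A standalone sketch of the same `DecimalFormat` usage, without the per-row caching (`FormatNumberSketch` and `format` are made-up names), reproducing two of the values asserted in the new tests below:

{% highlight scala %}
import java.text.DecimalFormat

object FormatNumberSketch {
  // Build "#,###,###,###,###,###,##0" plus d zero placeholders and apply it.
  def format(x: Double, d: Int): String = {
    require(d >= 0)
    val pattern = new StringBuilder("#,###,###,###,###,###,##0")
    if (d > 0) pattern.append(".").append("0" * d)
    new DecimalFormat(pattern.toString).format(x)
  }

  def main(args: Array[String]): Unit = {
    println(format(12831273.23481, 3)) // 12,831,273.235
    println(format(12831273.83421, 0)) // 12,831,274
  }
}
{% endhighlight %}

The production code additionally caches the `DecimalFormat` across rows and only rebuilds the pattern when `d` changes.
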
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index e911b907e8536..d7077a0ec907a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -291,6 +291,11 @@ abstract class UnaryNode extends LogicalPlan with trees.UnaryNode[LogicalPlan] {
/**
* A logical plan node with a left and right child.
*/
-abstract class BinaryNode extends LogicalPlan with trees.BinaryNode[LogicalPlan] {
+abstract class BinaryNode extends LogicalPlan {
self: Product =>
+
+ def left: LogicalPlan
+ def right: LogicalPlan
+
+ override def children: Seq[LogicalPlan] = Seq(left, right)
}
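
The same pattern is applied to `SparkPlan.BinaryNode` later in this patch: each plan hierarchy now declares `left`/`right` itself and derives `children` from them instead of mixing in the generic `trees.BinaryNode`. A toy standalone model of the shape (all names below are invented for this sketch, not Spark classes):

{% highlight scala %}
trait ToyPlan { def children: Seq[ToyPlan] }

// Binary nodes declare their two children and derive `children` from them.
abstract class ToyBinaryNode extends ToyPlan {
  def left: ToyPlan
  def right: ToyPlan
  override def children: Seq[ToyPlan] = Seq(left, right)
}

case class ToyLeaf(name: String) extends ToyPlan {
  override def children: Seq[ToyPlan] = Nil
}
case class ToyJoin(left: ToyPlan, right: ToyPlan) extends ToyBinaryNode

// ToyJoin(ToyLeaf("a"), ToyLeaf("b")).children == Seq(ToyLeaf("a"), ToyLeaf("b"))
{% endhighlight %}
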
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
index 09f6c6b0ec423..16844b2f4b680 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -453,15 +453,6 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
}
}
-/**
- * A [[TreeNode]] that has two children, [[left]] and [[right]].
- */
-trait BinaryNode[BaseType <: TreeNode[BaseType]] {
- def left: BaseType
- def right: BaseType
-
- def children: Seq[BaseType] = Seq(left, right)
-}
/**
* A [[TreeNode]] with no children.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
index 3148309a2166f..0103ddcf9cfb7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
@@ -32,14 +32,6 @@ object TypeUtils {
}
}
- def checkForBitwiseExpr(t: DataType, caller: String): TypeCheckResult = {
- if (t.isInstanceOf[IntegralType] || t == NullType) {
- TypeCheckResult.TypeCheckSuccess
- } else {
- TypeCheckResult.TypeCheckFailure(s"$caller accepts integral types, not $t")
- }
- }
-
def checkForOrderingExpr(t: DataType, caller: String): TypeCheckResult = {
if (t.isInstanceOf[AtomicType] || t == NullType) {
TypeCheckResult.TypeCheckSuccess
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala
index 32f87440b4e37..076d7b5a5118d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala
@@ -34,32 +34,18 @@ private[sql] abstract class AbstractDataType {
private[sql] def defaultConcreteType: DataType
/**
- * Returns true if this data type is the same type as `other`. This is different that equality
- * as equality will also consider data type parametrization, such as decimal precision.
+ * Returns true if `other` is an acceptable input type for a function that expects this,
+ * possibly abstract, DataType.
*
* {{{
* // this should return true
- * DecimalType.isSameType(DecimalType(10, 2))
- *
- * // this should return false
- * NumericType.isSameType(DecimalType(10, 2))
- * }}}
- */
- private[sql] def isSameType(other: DataType): Boolean
-
- /**
- * Returns true if `other` is an acceptable input type for a function that expectes this,
- * possibly abstract, DataType.
- *
- * {{{
- * // this should return true
- * DecimalType.isSameType(DecimalType(10, 2))
+ * DecimalType.acceptsType(DecimalType(10, 2))
*
* // this should return true as well
* NumericType.acceptsType(DecimalType(10, 2))
* }}}
*/
- private[sql] def acceptsType(other: DataType): Boolean = isSameType(other)
+ private[sql] def acceptsType(other: DataType): Boolean
/** Readable string representation for the type. */
private[sql] def simpleString: String
@@ -83,10 +69,8 @@ private[sql] class TypeCollection(private val types: Seq[AbstractDataType])
override private[sql] def defaultConcreteType: DataType = types.head.defaultConcreteType
- override private[sql] def isSameType(other: DataType): Boolean = false
-
override private[sql] def acceptsType(other: DataType): Boolean =
- types.exists(_.isSameType(other))
+ types.exists(_.acceptsType(other))
override private[sql] def simpleString: String = {
types.map(_.simpleString).mkString("(", " or ", ")")
@@ -96,6 +80,17 @@ private[sql] class TypeCollection(private val types: Seq[AbstractDataType])
private[sql] object TypeCollection {
+ /**
+ * Types that can be ordered/compared. In the long run we should probably make this a trait
+ * that can be mixed into each data type, and perhaps create an [[AbstractDataType]].
+ */
+ val Ordered = TypeCollection(
+ BooleanType,
+ ByteType, ShortType, IntegerType, LongType,
+ FloatType, DoubleType, DecimalType,
+ TimestampType, DateType,
+ StringType, BinaryType)
+
def apply(types: AbstractDataType*): TypeCollection = new TypeCollection(types)
def unapply(typ: AbstractDataType): Option[Seq[AbstractDataType]] = typ match {
@@ -105,6 +100,21 @@ private[sql] object TypeCollection {
}
+/**
+ * An [[AbstractDataType]] that matches any concrete data types.
+ */
+protected[sql] object AnyDataType extends AbstractDataType {
+
+ // Note that since AnyDataType matches any concrete type, defaultConcreteType should never
+ // be invoked.
+ override private[sql] def defaultConcreteType: DataType = throw new UnsupportedOperationException
+
+ override private[sql] def simpleString: String = "any"
+
+ override private[sql] def acceptsType(other: DataType): Boolean = true
+}
+
+
/**
* An internal type used to represent everything that is not null, UDTs, arrays, structs, and maps.
*/
@@ -148,13 +158,11 @@ private[sql] object NumericType extends AbstractDataType {
override private[sql] def simpleString: String = "numeric"
- override private[sql] def isSameType(other: DataType): Boolean = false
-
override private[sql] def acceptsType(other: DataType): Boolean = other.isInstanceOf[NumericType]
}
-private[sql] object IntegralType {
+private[sql] object IntegralType extends AbstractDataType {
/**
* Enables matching against IntegralType for expressions:
* {{{
@@ -163,6 +171,12 @@ private[sql] object IntegralType {
* }}}
*/
def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[IntegralType]
+
+ override private[sql] def defaultConcreteType: DataType = IntegerType
+
+ override private[sql] def simpleString: String = "integral"
+
+ override private[sql] def acceptsType(other: DataType): Boolean = other.isInstanceOf[IntegralType]
}
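
A toy standalone model of the single `acceptsType` contract that replaces `isSameType`: a concrete type accepts only itself, an abstract type accepts a family, and a collection accepts whatever any of its members accepts (all names below are invented for this sketch):

{% highlight scala %}
trait ToyAbstractType { def acceptsType(other: ToyConcreteType): Boolean }

// Concrete types accept only themselves.
trait ToyConcreteType extends ToyAbstractType {
  override def acceptsType(other: ToyConcreteType): Boolean = this == other
}
case object ToyInt extends ToyConcreteType
case object ToyDouble extends ToyConcreteType
case object ToyStringT extends ToyConcreteType

// An abstract type accepts a whole family of concrete types.
case object ToyNumeric extends ToyAbstractType {
  override def acceptsType(other: ToyConcreteType): Boolean =
    other == ToyInt || other == ToyDouble
}

// A collection accepts anything one of its members accepts.
case class ToyCollection(types: Seq[ToyAbstractType]) extends ToyAbstractType {
  override def acceptsType(other: ToyConcreteType): Boolean = types.exists(_.acceptsType(other))
}

// ToyCollection(Seq(ToyNumeric, ToyStringT)).acceptsType(ToyDouble) == true
// ToyNumeric.acceptsType(ToyStringT) == false
{% endhighlight %}
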
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala
index 76ca7a84c1d1a..5094058164b2f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala
@@ -28,7 +28,7 @@ object ArrayType extends AbstractDataType {
override private[sql] def defaultConcreteType: DataType = ArrayType(NullType, containsNull = true)
- override private[sql] def isSameType(other: DataType): Boolean = {
+ override private[sql] def acceptsType(other: DataType): Boolean = {
other.isInstanceOf[ArrayType]
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
index da83a7f0ba379..2d133eea19fe0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
@@ -79,7 +79,7 @@ abstract class DataType extends AbstractDataType {
override private[sql] def defaultConcreteType: DataType = this
- override private[sql] def isSameType(other: DataType): Boolean = this == other
+ override private[sql] def acceptsType(other: DataType): Boolean = this == other
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
index f5bd068d60dc4..a85af9e04aedb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
@@ -17,8 +17,6 @@
package org.apache.spark.sql.types
-import java.math.{MathContext, RoundingMode}
-
import org.apache.spark.annotation.DeveloperApi
/**
@@ -138,14 +136,6 @@ final class Decimal extends Ordered[Decimal] with Serializable {
}
def toBigDecimal: BigDecimal = {
- if (decimalVal.ne(null)) {
- decimalVal(MathContext.UNLIMITED)
- } else {
- BigDecimal(longVal, _scale)(MathContext.UNLIMITED)
- }
- }
-
- def toLimitedBigDecimal: BigDecimal = {
if (decimalVal.ne(null)) {
decimalVal
} else {
@@ -273,15 +263,8 @@ final class Decimal extends Ordered[Decimal] with Serializable {
def * (that: Decimal): Decimal = Decimal(toBigDecimal * that.toBigDecimal)
- def / (that: Decimal): Decimal = {
- if (that.isZero) {
- null
- } else {
- // To avoid non-terminating decimal expansion problem, we get scala's BigDecimal with limited
- // precision and scala.
- Decimal(toLimitedBigDecimal / that.toLimitedBigDecimal)
- }
- }
+ def / (that: Decimal): Decimal =
+ if (that.isZero) null else Decimal(toBigDecimal / that.toBigDecimal)
def % (that: Decimal): Decimal =
if (that.isZero) null else Decimal(toBigDecimal % that.toBigDecimal)
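
For context on the behaviour difference behind the simplified division (a standalone illustration, not part of the patch): `scala.math.BigDecimal` values created through its `apply` methods carry `MathContext.DECIMAL128` by default, so a non-terminating quotient such as 1/3 is rounded to 34 significant digits, whereas unlimited-precision `java.math.BigDecimal.divide` throws.

{% highlight scala %}
import java.math.{BigDecimal => JBigDecimal}

object DecimalDivisionSketch {
  def main(args: Array[String]): Unit = {
    // scala.math.BigDecimal defaults to MathContext.DECIMAL128, so this terminates:
    println(BigDecimal(1) / BigDecimal(3)) // 0.3333... (34 significant digits)

    // java.math.BigDecimal with no MathContext insists on an exact result and throws:
    try {
      new JBigDecimal(1).divide(new JBigDecimal(3))
    } catch {
      case e: ArithmeticException => println("java.math.BigDecimal: " + e.getMessage)
    }
  }
}
{% endhighlight %}
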
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
index a1cafeab1704d..377c75f6e85a5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
@@ -86,7 +86,7 @@ object DecimalType extends AbstractDataType {
override private[sql] def defaultConcreteType: DataType = Unlimited
- override private[sql] def isSameType(other: DataType): Boolean = {
+ override private[sql] def acceptsType(other: DataType): Boolean = {
other.isInstanceOf[DecimalType]
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala
index ddead10bc2171..ac34b642827ca 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala
@@ -71,7 +71,7 @@ object MapType extends AbstractDataType {
override private[sql] def defaultConcreteType: DataType = apply(NullType, NullType)
- override private[sql] def isSameType(other: DataType): Boolean = {
+ override private[sql] def acceptsType(other: DataType): Boolean = {
other.isInstanceOf[MapType]
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
index b8097403ec3cc..2ef97a427c37e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
@@ -307,7 +307,7 @@ object StructType extends AbstractDataType {
override private[sql] def defaultConcreteType: DataType = new StructType
- override private[sql] def isSameType(other: DataType): Boolean = {
+ override private[sql] def acceptsType(other: DataType): Boolean = {
other.isInstanceOf[StructType]
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
index 9d0c69a2451d1..f0f17103991ef 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
@@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.dsl.plans._
case class TestFunction(
children: Seq[Expression],
- inputTypes: Seq[AbstractDataType]) extends Expression with ExpectsInputTypes {
+ inputTypes: Seq[AbstractDataType]) extends Expression with ImplicitCastInputTypes {
override def nullable: Boolean = true
override def eval(input: InternalRow): Any = throw new UnsupportedOperationException
override def dataType: DataType = StringType
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
index 8e0551b23eea6..ed0d20e7de80e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
-import org.apache.spark.sql.types.StringType
+import org.apache.spark.sql.types.{TypeCollection, StringType}
class ExpressionTypeCheckingSuite extends SparkFunSuite {
@@ -49,13 +49,13 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
def assertErrorForDifferingTypes(expr: Expression): Unit = {
assertError(expr,
- s"differing types in ${expr.getClass.getSimpleName} (IntegerType and BooleanType).")
+ s"differing types in '${expr.prettyString}'")
}
test("check types for unary arithmetic") {
- assertError(UnaryMinus('stringField), "operator - accepts numeric type")
- assertError(Abs('stringField), "function abs accepts numeric type")
- assertError(BitwiseNot('stringField), "operator ~ accepts integral type")
+ assertError(UnaryMinus('stringField), "expected to be of type numeric")
+ assertError(Abs('stringField), "expected to be of type numeric")
+ assertError(BitwiseNot('stringField), "expected to be of type integral")
}
test("check types for binary arithmetic") {
@@ -78,18 +78,20 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
assertErrorForDifferingTypes(MaxOf('intField, 'booleanField))
assertErrorForDifferingTypes(MinOf('intField, 'booleanField))
- assertError(Add('booleanField, 'booleanField), "operator + accepts numeric type")
- assertError(Subtract('booleanField, 'booleanField), "operator - accepts numeric type")
- assertError(Multiply('booleanField, 'booleanField), "operator * accepts numeric type")
- assertError(Divide('booleanField, 'booleanField), "operator / accepts numeric type")
- assertError(Remainder('booleanField, 'booleanField), "operator % accepts numeric type")
+ assertError(Add('booleanField, 'booleanField), "accepts numeric type")
+ assertError(Subtract('booleanField, 'booleanField), "accepts numeric type")
+ assertError(Multiply('booleanField, 'booleanField), "accepts numeric type")
+ assertError(Divide('booleanField, 'booleanField), "accepts numeric type")
+ assertError(Remainder('booleanField, 'booleanField), "accepts numeric type")
- assertError(BitwiseAnd('booleanField, 'booleanField), "operator & accepts integral type")
- assertError(BitwiseOr('booleanField, 'booleanField), "operator | accepts integral type")
- assertError(BitwiseXor('booleanField, 'booleanField), "operator ^ accepts integral type")
+ assertError(BitwiseAnd('booleanField, 'booleanField), "accepts integral type")
+ assertError(BitwiseOr('booleanField, 'booleanField), "accepts integral type")
+ assertError(BitwiseXor('booleanField, 'booleanField), "accepts integral type")
- assertError(MaxOf('complexField, 'complexField), "function maxOf accepts non-complex type")
- assertError(MinOf('complexField, 'complexField), "function minOf accepts non-complex type")
+ assertError(MaxOf('complexField, 'complexField),
+ s"accepts ${TypeCollection.Ordered.simpleString} type")
+ assertError(MinOf('complexField, 'complexField),
+ s"accepts ${TypeCollection.Ordered.simpleString} type")
}
test("check types for predicates") {
@@ -105,25 +107,23 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
assertSuccess(EqualTo('intField, 'booleanField))
assertSuccess(EqualNullSafe('intField, 'booleanField))
- assertError(EqualTo('intField, 'complexField), "differing types")
- assertError(EqualNullSafe('intField, 'complexField), "differing types")
-
+ assertErrorForDifferingTypes(EqualTo('intField, 'complexField))
+ assertErrorForDifferingTypes(EqualNullSafe('intField, 'complexField))
assertErrorForDifferingTypes(LessThan('intField, 'booleanField))
assertErrorForDifferingTypes(LessThanOrEqual('intField, 'booleanField))
assertErrorForDifferingTypes(GreaterThan('intField, 'booleanField))
assertErrorForDifferingTypes(GreaterThanOrEqual('intField, 'booleanField))
- assertError(
- LessThan('complexField, 'complexField), "operator < accepts non-complex type")
- assertError(
- LessThanOrEqual('complexField, 'complexField), "operator <= accepts non-complex type")
- assertError(
- GreaterThan('complexField, 'complexField), "operator > accepts non-complex type")
- assertError(
- GreaterThanOrEqual('complexField, 'complexField), "operator >= accepts non-complex type")
+ assertError(LessThan('complexField, 'complexField),
+ s"accepts ${TypeCollection.Ordered.simpleString} type")
+ assertError(LessThanOrEqual('complexField, 'complexField),
+ s"accepts ${TypeCollection.Ordered.simpleString} type")
+ assertError(GreaterThan('complexField, 'complexField),
+ s"accepts ${TypeCollection.Ordered.simpleString} type")
+ assertError(GreaterThanOrEqual('complexField, 'complexField),
+ s"accepts ${TypeCollection.Ordered.simpleString} type")
- assertError(
- If('intField, 'stringField, 'stringField),
+ assertError(If('intField, 'stringField, 'stringField),
"type of predicate expression in If should be boolean")
assertErrorForDifferingTypes(If('booleanField, 'intField, 'booleanField))
@@ -171,4 +171,14 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
CreateNamedStruct(Seq('a.string.at(0), "a", "b", 2.0)),
"Odd position only allow foldable and not-null StringType expressions")
}
+
+ test("check types for ROUND") {
+ assertSuccess(Round(Literal(null), Literal(null)))
+ assertSuccess(Round('intField, Literal(1)))
+
+ assertError(Round('intField, 'intField), "Only foldable Expression is allowed")
+ assertError(Round('intField, 'booleanField), "expected to be of type int")
+ assertError(Round('intField, 'complexField), "expected to be of type int")
+ assertError(Round('booleanField, 'intField), "expected to be of type numeric")
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
index acb9a433de903..d0fd033b981c8 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
@@ -194,6 +194,30 @@ class HiveTypeCoercionSuite extends PlanTest {
Project(Seq(Alias(transformed, "a")()), testRelation))
}
+ test("cast NullType for expresions that implement ExpectsInputTypes") {
+ import HiveTypeCoercionSuite._
+
+ ruleTest(HiveTypeCoercion.ImplicitTypeCasts,
+ AnyTypeUnaryExpression(Literal.create(null, NullType)),
+ AnyTypeUnaryExpression(Literal.create(null, NullType)))
+
+ ruleTest(HiveTypeCoercion.ImplicitTypeCasts,
+ NumericTypeUnaryExpression(Literal.create(null, NullType)),
+ NumericTypeUnaryExpression(Literal.create(null, DoubleType)))
+ }
+
+ test("cast NullType for binary operators") {
+ import HiveTypeCoercionSuite._
+
+ ruleTest(HiveTypeCoercion.ImplicitTypeCasts,
+ AnyTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType)),
+ AnyTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType)))
+
+ ruleTest(HiveTypeCoercion.ImplicitTypeCasts,
+ NumericTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType)),
+ NumericTypeBinaryOperator(Literal.create(null, DoubleType), Literal.create(null, DoubleType)))
+ }
+
test("coalesce casts") {
ruleTest(HiveTypeCoercion.FunctionArgumentConversion,
Coalesce(Literal(1.0)
@@ -302,3 +326,33 @@ class HiveTypeCoercionSuite extends PlanTest {
)
}
}
+
+
+object HiveTypeCoercionSuite {
+
+ case class AnyTypeUnaryExpression(child: Expression)
+ extends UnaryExpression with ExpectsInputTypes {
+ override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)
+ override def dataType: DataType = NullType
+ }
+
+ case class NumericTypeUnaryExpression(child: Expression)
+ extends UnaryExpression with ExpectsInputTypes {
+ override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)
+ override def dataType: DataType = NullType
+ }
+
+ case class AnyTypeBinaryOperator(left: Expression, right: Expression)
+ extends BinaryOperator {
+ override def dataType: DataType = NullType
+ override def inputType: AbstractDataType = AnyDataType
+ override def symbol: String = "anytype"
+ }
+
+ case class NumericTypeBinaryOperator(left: Expression, right: Expression)
+ extends BinaryOperator {
+ override def dataType: DataType = NullType
+ override def inputType: AbstractDataType = NumericType
+ override def symbol: String = "numerictype"
+ }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
index 6c93698f8017b..e7e5231d32c9e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
@@ -21,7 +21,6 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.types.Decimal
-
class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
/**
@@ -158,4 +157,19 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(MinOf(Array(1.toByte, 2.toByte), Array(1.toByte, 3.toByte)),
Array(1.toByte, 2.toByte))
}
+
+ test("pmod") {
+ testNumericDataTypes { convert =>
+ val left = Literal(convert(7))
+ val right = Literal(convert(3))
+ checkEvaluation(Pmod(left, right), convert(1))
+ checkEvaluation(Pmod(Literal.create(null, left.dataType), right), null)
+ checkEvaluation(Pmod(left, Literal.create(null, right.dataType)), null)
+ checkEvaluation(Remainder(left, Literal(convert(0))), null) // mod by 0
+ }
+ checkEvaluation(Pmod(-7, 3), 2)
+ checkEvaluation(Pmod(7.2D, 4.1D), 3.1000000000000005)
+ checkEvaluation(Pmod(Decimal(0.7), Decimal(0.2)), Decimal(0.1))
+ checkEvaluation(Pmod(2L, Long.MaxValue), 2)
+ }
}
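
A standalone sketch of the positive-modulus semantics the new test asserts for a positive divisor (`PmodSketch` and `pmod` are names invented for this illustration):

{% highlight scala %}
object PmodSketch {
  // Like %, but shifted into the non-negative range when the divisor is positive.
  def pmod(a: Int, n: Int): Int = {
    val r = a % n
    if (r < 0) r + n else r
  }

  def main(args: Array[String]): Unit = {
    println(-7 % 3)      // -1, the sign of % follows the dividend
    println(pmod(-7, 3)) //  2, matching checkEvaluation(Pmod(-7, 3), 2) above
  }
}
{% endhighlight %}
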
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
index 7ca9e30b2bcd5..52a874a9d89ef 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.catalyst.expressions
+import scala.math.BigDecimal.RoundingMode
+
import com.google.common.math.LongMath
import org.apache.spark.SparkFunSuite
@@ -336,4 +338,46 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
null,
create_row(null))
}
+
+ test("round") {
+ val domain = -6 to 6
+ val doublePi: Double = math.Pi
+ val shortPi: Short = 31415
+ val intPi: Int = 314159265
+ val longPi: Long = 31415926535897932L
+ val bdPi: BigDecimal = BigDecimal(31415927L, 7)
+
+ val doubleResults: Seq[Double] = Seq(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 3.0, 3.1, 3.14, 3.142,
+ 3.1416, 3.14159, 3.141593)
+
+ val shortResults: Seq[Short] = Seq[Short](0, 0, 30000, 31000, 31400, 31420) ++
+ Seq.fill[Short](7)(31415)
+
+ val intResults: Seq[Int] = Seq(314000000, 314200000, 314160000, 314159000, 314159300,
+ 314159270) ++ Seq.fill(7)(314159265)
+
+ val longResults: Seq[Long] = Seq(31415926536000000L, 31415926535900000L,
+ 31415926535900000L, 31415926535898000L, 31415926535897900L, 31415926535897930L) ++
+ Seq.fill(7)(31415926535897932L)
+
+ val bdResults: Seq[BigDecimal] = Seq(BigDecimal(3.0), BigDecimal(3.1), BigDecimal(3.14),
+ BigDecimal(3.142), BigDecimal(3.1416), BigDecimal(3.14159),
+ BigDecimal(3.141593), BigDecimal(3.1415927))
+
+ domain.zipWithIndex.foreach { case (scale, i) =>
+ checkEvaluation(Round(doublePi, scale), doubleResults(i), EmptyRow)
+ checkEvaluation(Round(shortPi, scale), shortResults(i), EmptyRow)
+ checkEvaluation(Round(intPi, scale), intResults(i), EmptyRow)
+ checkEvaluation(Round(longPi, scale), longResults(i), EmptyRow)
+ }
+
+ // round_scale > current_scale would require a precision increase, which
+ // o.a.s.s.types.Decimal.changePrecision does not allow, hence null
+ (0 to 7).foreach { i =>
+ checkEvaluation(Round(bdPi, i), bdResults(i), EmptyRow)
+ }
+ (8 to 10).foreach { scale =>
+ checkEvaluation(Round(bdPi, scale), null, EmptyRow)
+ }
+ }
}
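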
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala
index b19f4ee37a109..5d7763bedf6bd 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.dsl.expressions._
-import org.apache.spark.sql.types.{BinaryType, IntegerType, StringType}
+import org.apache.spark.sql.types._
class StringFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
@@ -216,15 +216,6 @@ class StringFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
}
}
- test("length for string") {
- val a = 'a.string.at(0)
- checkEvaluation(StringLength(Literal("abc")), 3, create_row("abdef"))
- checkEvaluation(StringLength(a), 5, create_row("abdef"))
- checkEvaluation(StringLength(a), 0, create_row(""))
- checkEvaluation(StringLength(a), null, create_row(null))
- checkEvaluation(StringLength(Literal.create(null, StringType)), null, create_row("abdef"))
- }
-
test("ascii for string") {
val a = 'a.string.at(0)
checkEvaluation(Ascii(Literal("efg")), 101, create_row("abdef"))
@@ -426,4 +417,46 @@ class StringFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(
StringSplit(s1, s2), Seq("aa", "bb", "cc"), row1)
}
+
+ test("length for string / binary") {
+ val a = 'a.string.at(0)
+ val b = 'b.binary.at(0)
+ val bytes = Array[Byte](1, 2, 3, 1, 2)
+ val string = "abdef"
+
+ // scalastyle:off
+ // non-ASCII characters are not allowed in the source code, so we disable scalastyle here.
+ checkEvaluation(Length(Literal("a花花c")), 4, create_row(string))
+ // scalastyle:on
+ checkEvaluation(Length(Literal(bytes)), 5, create_row(Array[Byte]()))
+
+ checkEvaluation(Length(a), 5, create_row(string))
+ checkEvaluation(Length(b), 5, create_row(bytes))
+
+ checkEvaluation(Length(a), 0, create_row(""))
+ checkEvaluation(Length(b), 0, create_row(Array[Byte]()))
+
+ checkEvaluation(Length(a), null, create_row(null))
+ checkEvaluation(Length(b), null, create_row(null))
+
+ checkEvaluation(Length(Literal.create(null, StringType)), null, create_row(string))
+ checkEvaluation(Length(Literal.create(null, BinaryType)), null, create_row(bytes))
+ }
+
+ test("number format") {
+ checkEvaluation(FormatNumber(Literal(4.asInstanceOf[Byte]), Literal(3)), "4.000")
+ checkEvaluation(FormatNumber(Literal(4.asInstanceOf[Short]), Literal(3)), "4.000")
+ checkEvaluation(FormatNumber(Literal(4.0f), Literal(3)), "4.000")
+ checkEvaluation(FormatNumber(Literal(4), Literal(3)), "4.000")
+ checkEvaluation(FormatNumber(Literal(12831273.23481d), Literal(3)), "12,831,273.235")
+ checkEvaluation(FormatNumber(Literal(12831273.83421d), Literal(0)), "12,831,274")
+ checkEvaluation(FormatNumber(Literal(123123324123L), Literal(3)), "123,123,324,123.000")
+ checkEvaluation(FormatNumber(Literal(123123324123L), Literal(-1)), null)
+ checkEvaluation(
+ FormatNumber(
+ Literal(Decimal(123123324123L) * Decimal(123123.21234d)), Literal(4)),
+ "15,159,339,180,002,773.2778")
+ checkEvaluation(FormatNumber(Literal.create(null, IntegerType), Literal(3)), null)
+ checkEvaluation(FormatNumber(Literal.create(null, NullType), Literal(3)), null)
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala
index 030bb6d21b18b..1d297beb3868d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala
@@ -24,14 +24,14 @@ import org.scalatest.PrivateMethodTester
import scala.language.postfixOps
class DecimalSuite extends SparkFunSuite with PrivateMethodTester {
- test("creating decimals") {
- /** Check that a Decimal has the given string representation, precision and scale */
- def checkDecimal(d: Decimal, string: String, precision: Int, scale: Int): Unit = {
- assert(d.toString === string)
- assert(d.precision === precision)
- assert(d.scale === scale)
- }
+ /** Check that a Decimal has the given string representation, precision and scale */
+ private def checkDecimal(d: Decimal, string: String, precision: Int, scale: Int): Unit = {
+ assert(d.toString === string)
+ assert(d.precision === precision)
+ assert(d.scale === scale)
+ }
+ test("creating decimals") {
checkDecimal(new Decimal(), "0", 1, 0)
checkDecimal(Decimal(BigDecimal("10.030")), "10.030", 5, 3)
checkDecimal(Decimal(BigDecimal("10.030"), 4, 1), "10.0", 4, 1)
@@ -53,6 +53,15 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester {
intercept[IllegalArgumentException](Decimal(1e17.toLong, 17, 0))
}
+ test("creating decimals with negative scale") {
+ checkDecimal(Decimal(BigDecimal("98765"), 5, -3), "9.9E+4", 5, -3)
+ checkDecimal(Decimal(BigDecimal("314.159"), 6, -2), "3E+2", 6, -2)
+ checkDecimal(Decimal(BigDecimal(1.579e12), 4, -9), "1.579E+12", 4, -9)
+ checkDecimal(Decimal(BigDecimal(1.579e12), 4, -10), "1.58E+12", 4, -10)
+ checkDecimal(Decimal(103050709L, 9, -10), "1.03050709E+18", 9, -10)
+ checkDecimal(Decimal(1e8.toLong, 10, -10), "1.00000000E+18", 10, -10)
+ }
+
test("double and long values") {
/** Check that a Decimal converts to the given double and long values */
def checkValues(d: Decimal, doubleValue: Double, longValue: Long): Unit = {
@@ -162,22 +171,4 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester {
assert(new Decimal().set(100L, 10, 0).toUnscaledLong === 100L)
assert(Decimal(Long.MaxValue, 100, 0).toUnscaledLong === Long.MaxValue)
}
-
- test("accurate precision after multiplication") {
- val decimal = (Decimal(Long.MaxValue, 38, 0) * Decimal(Long.MaxValue, 38, 0)).toJavaBigDecimal
- assert(decimal.unscaledValue.toString === "85070591730234615847396907784232501249")
- }
-
- test("fix non-terminating decimal expansion problem") {
- val decimal = Decimal(1.0, 10, 3) / Decimal(3.0, 10, 3)
- // The difference between decimal should not be more than 0.001.
- assert(decimal.toDouble - 0.333 < 0.001)
- }
-
- test("fix loss of precision/scale when doing division operation") {
- val a = Decimal(2) / Decimal(3)
- assert(a.toDouble < 1.0 && a.toDouble > 0.6)
- val b = Decimal(1) / Decimal(8)
- assert(b.toDouble === 0.125)
- }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
index 43b62f0e822f8..92861ab038f19 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
@@ -47,6 +47,7 @@ private[r] object SQLUtils {
dataType match {
case "byte" => org.apache.spark.sql.types.ByteType
case "integer" => org.apache.spark.sql.types.IntegerType
+ case "float" => org.apache.spark.sql.types.FloatType
case "double" => org.apache.spark.sql.types.DoubleType
case "numeric" => org.apache.spark.sql.types.DoubleType
case "character" => org.apache.spark.sql.types.StringType
@@ -68,7 +69,7 @@ private[r] object SQLUtils {
def createDF(rdd: RDD[Array[Byte]], schema: StructType, sqlContext: SQLContext): DataFrame = {
val num = schema.fields.size
- val rowRDD = rdd.map(bytesToRow)
+ val rowRDD = rdd.map(bytesToRow(_, schema))
sqlContext.createDataFrame(rowRDD, schema)
}
@@ -76,12 +77,20 @@ private[r] object SQLUtils {
df.map(r => rowToRBytes(r))
}
- private[this] def bytesToRow(bytes: Array[Byte]): Row = {
+ private[this] def doConversion(data: Object, dataType: DataType): Object = {
+ data match {
+ case d: java.lang.Double if dataType == FloatType =>
+ new java.lang.Float(d)
+ case _ => data
+ }
+ }
+
+ private[this] def bytesToRow(bytes: Array[Byte], schema: StructType): Row = {
val bis = new ByteArrayInputStream(bytes)
val dis = new DataInputStream(bis)
val num = SerDe.readInt(dis)
Row.fromSeq((0 until num).map { i =>
- SerDe.readObject(dis)
+ doConversion(SerDe.readObject(dis), schema.fields(i).dataType)
}.toSeq)
}
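
For context on the SerDe change above: values deserialized from R arrive as `java.lang.Double`, and `doConversion` narrows them to `java.lang.Float` only when the schema declares the column as `FloatType`. Below is a minimal standalone sketch of that narrowing; the helper name and the call at the end are illustrative only and not part of the patch.

{% highlight scala %}
import org.apache.spark.sql.types.{DataType, FloatType}

// R-side numerics are read back as java.lang.Double; narrow to Float only when
// the declared schema type asks for it, otherwise pass the value through unchanged.
def narrowForSchema(data: Object, dataType: DataType): Object = data match {
  case d: java.lang.Double if dataType == FloatType => new java.lang.Float(d)
  case _ => data
}

narrowForSchema(java.lang.Double.valueOf(3.14), FloatType) // returns a java.lang.Float (3.14f)
{% endhighlight %}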
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index 4d7d8626a0ecc..9dc7879fa4a1a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -247,6 +247,11 @@ private[sql] trait UnaryNode extends SparkPlan with trees.UnaryNode[SparkPlan] {
override def outputPartitioning: Partitioning = child.outputPartitioning
}
-private[sql] trait BinaryNode extends SparkPlan with trees.BinaryNode[SparkPlan] {
+private[sql] trait BinaryNode extends SparkPlan {
self: Product =>
+
+ def left: SparkPlan
+ def right: SparkPlan
+
+ override def children: Seq[SparkPlan] = Seq(left, right)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala
index 437d143e53f3f..fec403fe2d348 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.expressions
import org.apache.spark.TaskContext
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.LeafExpression
+import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext}
import org.apache.spark.sql.types.{LongType, DataType}
/**
@@ -40,6 +41,10 @@ private[sql] case class MonotonicallyIncreasingID() extends LeafExpression {
*/
@transient private[this] var count: Long = 0L
+ @transient private lazy val partitionMask = TaskContext.getPartitionId().toLong << 33
+
+ override def deterministic: Boolean = false
+
override def nullable: Boolean = false
override def dataType: DataType = LongType
@@ -47,6 +52,20 @@ private[sql] case class MonotonicallyIncreasingID() extends LeafExpression {
override def eval(input: InternalRow): Long = {
val currentCount = count
count += 1
- (TaskContext.get().partitionId().toLong << 33) + currentCount
+ partitionMask + currentCount
+ }
+
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+ val countTerm = ctx.freshName("count")
+ val partitionMaskTerm = ctx.freshName("partitionMask")
+ ctx.addMutableState(ctx.JAVA_LONG, countTerm, "0L")
+ ctx.addMutableState(ctx.JAVA_LONG, partitionMaskTerm,
+ "((long) org.apache.spark.TaskContext.getPartitionId()) << 33")
+
+ ev.isNull = "false"
+ s"""
+ final ${ctx.javaType(dataType)} ${ev.primitive} = $partitionMaskTerm + $countTerm;
+ $countTerm++;
+ """
}
}
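
Both the interpreted `eval` and the generated code above compose the ID the same way: the partition id occupies the bits above position 33 and a per-partition counter fills the lower 33 bits. A toy sketch of that layout, illustrative only and not taken from the patch:

{% highlight scala %}
// Toy model of the 64-bit ID: the high bits carry the partition id,
// the low 33 bits carry a counter that increments once per row.
class ToyIdGenerator(partitionId: Int) {
  private val partitionMask: Long = partitionId.toLong << 33
  private var count: Long = 0L
  def next(): Long = { val id = partitionMask + count; count += 1; id }
}

val gen = new ToyIdGenerator(5)
gen.next()  // 42949672960L, i.e. 5L << 33
gen.next()  // 42949672961L
{% endhighlight %}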
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala
index 822d3d8c9108d..7c790c549a5d8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.expressions
import org.apache.spark.TaskContext
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.LeafExpression
+import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext}
import org.apache.spark.sql.types.{IntegerType, DataType}
@@ -28,9 +29,20 @@ import org.apache.spark.sql.types.{IntegerType, DataType}
*/
private[sql] case object SparkPartitionID extends LeafExpression {
+ override def deterministic: Boolean = false
+
override def nullable: Boolean = false
override def dataType: DataType = IntegerType
- override def eval(input: InternalRow): Int = TaskContext.get().partitionId()
+ @transient private lazy val partitionId = TaskContext.getPartitionId()
+
+ override def eval(input: InternalRow): Int = partitionId
+
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+ val idTerm = ctx.freshName("partitionId")
+ ctx.addMutableState(ctx.JAVA_INT, idTerm, "org.apache.spark.TaskContext.getPartitionId()")
+ ev.isNull = "false"
+ s"final ${ctx.javaType(dataType)} ${ev.primitive} = $idTerm;"
+ }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 0d4e160ed8057..d6da284a4c788 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -1371,6 +1371,23 @@ object functions {
*/
def pow(l: Double, rightName: String): Column = pow(l, Column(rightName))
+ /**
+ * Returns the positive value of dividend mod divisor.
+ *
+ * @group math_funcs
+ * @since 1.5.0
+ */
+ def pmod(dividend: Column, divisor: Column): Column = Pmod(dividend.expr, divisor.expr)
+
+ /**
+ * Returns the positive value of dividend mod divisor.
+ *
+ * @group math_funcs
+ * @since 1.5.0
+ */
+ def pmod(dividendColName: String, divisorColName: String): Column =
+ pmod(Column(dividendColName), Column(divisorColName))
+
/**
* Returns the double value that is closest in value to the argument and
* is equal to a mathematical integer.
@@ -1389,6 +1406,38 @@ object functions {
*/
def rint(columnName: String): Column = rint(Column(columnName))
+ /**
+ * Returns the value of the column `e` rounded to 0 decimal places.
+ *
+ * @group math_funcs
+ * @since 1.5.0
+ */
+ def round(e: Column): Column = round(e.expr, 0)
+
+ /**
+ * Returns the value of the given column rounded to 0 decimal places.
+ *
+ * @group math_funcs
+ * @since 1.5.0
+ */
+ def round(columnName: String): Column = round(Column(columnName), 0)
+
+ /**
+ * Returns the value of `e` rounded to `scale` decimal places.
+ *
+ * @group math_funcs
+ * @since 1.5.0
+ */
+ def round(e: Column, scale: Int): Column = Round(e.expr, Literal(scale))
+
+ /**
+ * Returns the value of the given column rounded to `scale` decimal places.
+ *
+ * @group math_funcs
+ * @since 1.5.0
+ */
+ def round(columnName: String, scale: Int): Column = round(Column(columnName), scale)
+
/**
   * Shift the given value numBits left. If the given value is a long value, this function
* will return a long value else it will return an integer value.
@@ -1636,20 +1685,44 @@ object functions {
//////////////////////////////////////////////////////////////////////////////////////////////
/**
- * Computes the length of a given string value.
+ * Computes the length of a given string / binary value.
+ *
+ * @group string_funcs
+ * @since 1.5.0
+ */
+ def length(e: Column): Column = Length(e.expr)
+
+ /**
+ * Computes the length of a given string / binary column.
+ *
+ * @group string_funcs
+ * @since 1.5.0
+ */
+ def length(columnName: String): Column = length(Column(columnName))
+
+ /**
+ * Formats the number X to a format like '#,###,###.##', rounded to d decimal places,
+ * and returns the result as a string.
+ * If d is 0, the result has no decimal point or fractional part.
+ * If d < 0, the result will be null.
*
* @group string_funcs
* @since 1.5.0
*/
- def strlen(e: Column): Column = StringLength(e.expr)
+ def format_number(x: Column, d: Int): Column = FormatNumber(x.expr, lit(d).expr)
/**
- * Computes the length of a given string column.
+ * Formats the number X to a format like '#,###,###.##', rounded to d decimal places,
+ * and returns the result as a string.
+ * If d is 0, the result has no decimal point or fractional part.
+ * If d < 0, the result will be null.
*
* @group string_funcs
* @since 1.5.0
*/
- def strlen(columnName: String): Column = strlen(Column(columnName))
+ def format_number(columnXName: String, d: Int): Column = {
+ format_number(Column(columnXName), d)
+ }
/**
* Computes the Levenshtein distance of the two given strings.
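
Taken together, the `functions.scala` changes above expose `pmod`, `round`, `format_number`, and the renamed `length` through the DataFrame API. A brief usage sketch follows; it assumes an existing `sqlContext` with implicits imported, and the column names are illustrative only.

{% highlight scala %}
import org.apache.spark.sql.functions.{pmod, round, format_number, length}
import sqlContext.implicits._

val df = Seq((7, 3, 1234.5678, "spark")).toDF("a", "b", "x", "s")

df.select(
  pmod($"a", $"b"),        // 1; pmod(-7, 3) would give 2 rather than -1
  round($"x", 2),          // 1234.57
  format_number($"x", 2),  // "1,234.57" as a grouped string with 2 decimal places
  length($"s")             // 5
).show()
{% endhighlight %}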
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 6cebec95d2850..6dccdd857b453 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -208,17 +208,6 @@ class DataFrameFunctionsSuite extends QueryTest {
Row(2743272264L, 2180413220L))
}
- test("string length function") {
- val df = Seq(("abc", "")).toDF("a", "b")
- checkAnswer(
- df.select(strlen($"a"), strlen("b")),
- Row(3, 0))
-
- checkAnswer(
- df.selectExpr("length(a)", "length(b)"),
- Row(3, 0))
- }
-
test("Levenshtein distance") {
val df = Seq(("kitten", "sitting"), ("frog", "fog")).toDF("l", "r")
checkAnswer(df.select(levenshtein("l", "r")), Seq(Row(3), Row(1)))
@@ -403,4 +392,121 @@ class DataFrameFunctionsSuite extends QueryTest {
Seq(Row(2), Row(2), Row(2), Row(2), Row(3), Row(3))
)
}
+
+ test("pmod") {
+ val intData = Seq((7, 3), (-7, 3)).toDF("a", "b")
+ checkAnswer(
+ intData.select(pmod('a, 'b)),
+ Seq(Row(1), Row(2))
+ )
+ checkAnswer(
+ intData.select(pmod('a, lit(3))),
+ Seq(Row(1), Row(2))
+ )
+ checkAnswer(
+ intData.select(pmod(lit(-7), 'b)),
+ Seq(Row(2), Row(2))
+ )
+ checkAnswer(
+ intData.selectExpr("pmod(a, b)"),
+ Seq(Row(1), Row(2))
+ )
+ checkAnswer(
+ intData.selectExpr("pmod(a, 3)"),
+ Seq(Row(1), Row(2))
+ )
+ checkAnswer(
+ intData.selectExpr("pmod(-7, b)"),
+ Seq(Row(2), Row(2))
+ )
+ val doubleData = Seq((7.2, 4.1)).toDF("a", "b")
+ checkAnswer(
+ doubleData.select(pmod('a, 'b)),
+ Seq(Row(3.1000000000000005)) // same as Hive
+ )
+ checkAnswer(
+ doubleData.select(pmod(lit(2), lit(Int.MaxValue))),
+ Seq(Row(2))
+ )
+ }
+
+ test("string / binary length function") {
+ val df = Seq(("123", Array[Byte](1, 2, 3, 4), 123)).toDF("a", "b", "c")
+ checkAnswer(
+ df.select(length($"a"), length("a"), length($"b"), length("b")),
+ Row(3, 3, 4, 4))
+
+ checkAnswer(
+ df.selectExpr("length(a)", "length(b)"),
+ Row(3, 4))
+
+ intercept[AnalysisException] {
+ checkAnswer(
+ df.selectExpr("length(c)"), // int type of the argument is unacceptable
+ Row("5.0000"))
+ }
+ }
+
+ test("number format function") {
+ val tuple =
+ ("aa", 1.asInstanceOf[Byte], 2.asInstanceOf[Short],
+ 3.13223f, 4, 5L, 6.48173d, Decimal(7.128381))
+ val df =
+ Seq(tuple)
+ .toDF(
+ "a", // string "aa"
+ "b", // byte 1
+ "c", // short 2
+ "d", // float 3.13223f
+ "e", // integer 4
+ "f", // long 5L
+ "g", // double 6.48173d
+ "h") // decimal 7.128381
+
+ checkAnswer(
+ df.select(
+ format_number($"f", 4),
+ format_number("f", 4)),
+ Row("5.0000", "5.0000"))
+
+ checkAnswer(
+ df.selectExpr("format_number(b, e)"), // convert the 1st argument to integer
+ Row("1.0000"))
+
+ checkAnswer(
+ df.selectExpr("format_number(c, e)"), // convert the 1st argument to integer
+ Row("2.0000"))
+
+ checkAnswer(
+ df.selectExpr("format_number(d, e)"), // convert the 1st argument to double
+ Row("3.1322"))
+
+ checkAnswer(
+ df.selectExpr("format_number(e, e)"), // no conversion needed
+ Row("4.0000"))
+
+ checkAnswer(
+ df.selectExpr("format_number(f, e)"), // no conversion needed
+ Row("5.0000"))
+
+ checkAnswer(
+ df.selectExpr("format_number(g, e)"), // no conversion needed
+ Row("6.4817"))
+
+ checkAnswer(
+ df.selectExpr("format_number(h, e)"), // no conversion needed
+ Row("7.1284"))
+
+ intercept[AnalysisException] {
+ checkAnswer(
+ df.selectExpr("format_number(a, e)"), // string type of the 1st argument is unacceptable
+ Row("5.0000"))
+ }
+
+ intercept[AnalysisException] {
+ checkAnswer(
+ df.selectExpr("format_number(e, g)"), // non-integral (double) type of the 2nd argument is unacceptable
+ Row("5.0000"))
+ }
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
index 24bef21b999ea..087126bb2e513 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
@@ -198,6 +198,21 @@ class MathExpressionsSuite extends QueryTest {
testOneToOneMathFunction(rint, math.rint)
}
+ test("round") {
+ val df = Seq(5, 55, 555).map(Tuple1(_)).toDF("a")
+ checkAnswer(
+ df.select(round('a), round('a, -1), round('a, -2)),
+ Seq(Row(5, 10, 0), Row(55, 60, 100), Row(555, 560, 600))
+ )
+
+ val pi = 3.1415
+ checkAnswer(
+ ctx.sql(s"SELECT round($pi, -3), round($pi, -2), round($pi, -1), " +
+ s"round($pi, 0), round($pi, 1), round($pi, 2), round($pi, 3)"),
+ Seq(Row(0.0, 0.0, 0.0, 3.0, 3.1, 3.14, 3.142))
+ )
+ }
+
test("exp") {
testOneToOneMathFunction(exp, math.exp)
}
@@ -375,6 +390,5 @@ class MathExpressionsSuite extends QueryTest {
val df = Seq((1, -1, "abc")).toDF("a", "b", "c")
checkAnswer(df.selectExpr("positive(a)"), Row(1))
checkAnswer(df.selectExpr("positive(b)"), Row(-1))
- checkAnswer(df.selectExpr("positive(c)"), Row("abc"))
}
}
diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
index c884c399281a8..4ada64bc21966 100644
--- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
+++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
@@ -221,9 +221,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
"udf_when",
"udf_case",
- // Needs constant object inspectors
- "udf_round",
-
// the table src(key INT, value STRING) is not the same as in the Hive unit tests, where it
// is src(key STRING, value STRING); in reflect.q this failed in
// Integer.valueOf, which expects the first argument to be passed as STRING type, not INT.
@@ -918,8 +915,8 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
"udf_regexp_replace",
"udf_repeat",
"udf_rlike",
- "udf_round",
- // "udf_round_3", TODO: FIX THIS failed due to cast exception
+ // "udf_round", turn this on after we figure out null vs. NaN vs. infinity
+ "udf_round_3",
"udf_rpad",
"udf_rtrim",
"udf_second",
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
index 5bdf68c83fca7..4b7a782c805a0 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
@@ -301,9 +301,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive
val result = if (metastoreRelation.hiveQlTable.isPartitioned) {
val partitionSchema = StructType.fromAttributes(metastoreRelation.partitionKeys)
val partitionColumnDataTypes = partitionSchema.map(_.dataType)
- // We're converting the entire table into ParquetRelation, so predicates to Hive metastore
- // are empty.
- val partitions = metastoreRelation.getHiveQlPartitions().map { p =>
+ val partitions = metastoreRelation.hiveQlPartitions.map { p =>
val location = p.getLocation
val values = InternalRow.fromSeq(p.getValues.zip(partitionColumnDataTypes).map {
case (rawValue, dataType) => Cast(Literal(rawValue), dataType).eval(null)
@@ -646,6 +644,32 @@ private[hive] case class MetastoreRelation
new Table(tTable)
}
+ @transient val hiveQlPartitions: Seq[Partition] = table.getAllPartitions.map { p =>
+ val tPartition = new org.apache.hadoop.hive.metastore.api.Partition
+ tPartition.setDbName(databaseName)
+ tPartition.setTableName(tableName)
+ tPartition.setValues(p.values)
+
+ val sd = new org.apache.hadoop.hive.metastore.api.StorageDescriptor()
+ tPartition.setSd(sd)
+ sd.setCols(table.schema.map(c => new FieldSchema(c.name, c.hiveType, c.comment)))
+
+ sd.setLocation(p.storage.location)
+ sd.setInputFormat(p.storage.inputFormat)
+ sd.setOutputFormat(p.storage.outputFormat)
+
+ val serdeInfo = new org.apache.hadoop.hive.metastore.api.SerDeInfo
+ sd.setSerdeInfo(serdeInfo)
+ serdeInfo.setSerializationLib(p.storage.serde)
+
+ val serdeParameters = new java.util.HashMap[String, String]()
+ serdeInfo.setParameters(serdeParameters)
+ table.serdeProperties.foreach { case (k, v) => serdeParameters.put(k, v) }
+ p.storage.serdeProperties.foreach { case (k, v) => serdeParameters.put(k, v) }
+
+ new Partition(hiveQlTable, tPartition)
+ }
+
@transient override lazy val statistics: Statistics = Statistics(
sizeInBytes = {
val totalSize = hiveQlTable.getParameters.get(StatsSetupConst.TOTAL_SIZE)
@@ -666,34 +690,6 @@ private[hive] case class MetastoreRelation
}
)
- def getHiveQlPartitions(predicates: Seq[Expression] = Nil): Seq[Partition] = {
- table.getPartitions(predicates).map { p =>
- val tPartition = new org.apache.hadoop.hive.metastore.api.Partition
- tPartition.setDbName(databaseName)
- tPartition.setTableName(tableName)
- tPartition.setValues(p.values)
-
- val sd = new org.apache.hadoop.hive.metastore.api.StorageDescriptor()
- tPartition.setSd(sd)
- sd.setCols(table.schema.map(c => new FieldSchema(c.name, c.hiveType, c.comment)))
-
- sd.setLocation(p.storage.location)
- sd.setInputFormat(p.storage.inputFormat)
- sd.setOutputFormat(p.storage.outputFormat)
-
- val serdeInfo = new org.apache.hadoop.hive.metastore.api.SerDeInfo
- sd.setSerdeInfo(serdeInfo)
- serdeInfo.setSerializationLib(p.storage.serde)
-
- val serdeParameters = new java.util.HashMap[String, String]()
- serdeInfo.setParameters(serdeParameters)
- table.serdeProperties.foreach { case (k, v) => serdeParameters.put(k, v) }
- p.storage.serdeProperties.foreach { case (k, v) => serdeParameters.put(k, v) }
-
- new Partition(hiveQlTable, tPartition)
- }
- }
-
/** Only compare database and tablename, not alias. */
override def sameResult(plan: LogicalPlan): Boolean = {
plan match {
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala
index a357bb39ca7fd..d08c594151654 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala
@@ -27,7 +27,6 @@ import scala.reflect.ClassTag
import com.esotericsoftware.kryo.Kryo
import com.esotericsoftware.kryo.io.{Input, Output}
-
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.hadoop.hive.ql.exec.{UDF, Utilities}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
index 9638a8201e190..ed359620a5f7f 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
@@ -125,7 +125,7 @@ private[hive] trait HiveStrategies {
InterpretedPredicate.create(castedPredicate)
}
- val partitions = relation.getHiveQlPartitions(pruningPredicates).filter { part =>
+ val partitions = relation.hiveQlPartitions.filter { part =>
val partitionValues = part.getValues
var i = 0
while (i < partitionValues.size()) {
@@ -213,7 +213,7 @@ private[hive] trait HiveStrategies {
projectList,
otherPredicates,
identity[Seq[Expression]],
- HiveTableScan(_, relation, pruningPredicates)(hiveContext)) :: Nil
+ HiveTableScan(_, relation, pruningPredicates.reduceLeftOption(And))(hiveContext)) :: Nil
case _ =>
Nil
}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala
index 1656587d14835..0a1d761a52f88 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala
@@ -21,7 +21,6 @@ import java.io.PrintStream
import java.util.{Map => JMap}
import org.apache.spark.sql.catalyst.analysis.{NoSuchDatabaseException, NoSuchTableException}
-import org.apache.spark.sql.catalyst.expressions.Expression
private[hive] case class HiveDatabase(
name: String,
@@ -72,12 +71,7 @@ private[hive] case class HiveTable(
def isPartitioned: Boolean = partitionColumns.nonEmpty
- def getPartitions(predicates: Seq[Expression]): Seq[HivePartition] = {
- predicates match {
- case Nil => client.getAllPartitions(this)
- case _ => client.getPartitionsByFilter(this, predicates)
- }
- }
+ def getAllPartitions: Seq[HivePartition] = client.getAllPartitions(this)
// Hive does not support backticks when passing names to the client.
def qualifiedName: String = s"$database.$name"
@@ -138,9 +132,6 @@ private[hive] trait ClientInterface {
/** Returns all partitions for the given table. */
def getAllPartitions(hTable: HiveTable): Seq[HivePartition]
- /** Returns partitions filtered by predicates for the given table. */
- def getPartitionsByFilter(hTable: HiveTable, predicates: Seq[Expression]): Seq[HivePartition]
-
/** Loads a static partition into an existing table. */
def loadPartition(
loadPath: String,
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala
index 8adda54754230..53f457ad4f3cc 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala
@@ -17,21 +17,25 @@
package org.apache.spark.sql.hive.client
-import java.io.{File, PrintStream}
-import java.util.{Map => JMap}
+import java.io.{BufferedReader, InputStreamReader, File, PrintStream}
+import java.net.URI
+import java.util.{ArrayList => JArrayList, Map => JMap, List => JList, Set => JSet}
import javax.annotation.concurrent.GuardedBy
import scala.collection.JavaConversions._
import scala.language.reflectiveCalls
import org.apache.hadoop.fs.Path
+import org.apache.hadoop.hive.metastore.api.Database
import org.apache.hadoop.hive.conf.HiveConf
-import org.apache.hadoop.hive.metastore.api.{Database, FieldSchema}
import org.apache.hadoop.hive.metastore.{TableType => HTableType}
+import org.apache.hadoop.hive.metastore.api
+import org.apache.hadoop.hive.metastore.api.FieldSchema
+import org.apache.hadoop.hive.ql.metadata
import org.apache.hadoop.hive.ql.metadata.Hive
-import org.apache.hadoop.hive.ql.processors._
import org.apache.hadoop.hive.ql.session.SessionState
-import org.apache.hadoop.hive.ql.{Driver, metadata}
+import org.apache.hadoop.hive.ql.processors._
+import org.apache.hadoop.hive.ql.Driver
import org.apache.spark.Logging
import org.apache.spark.sql.catalyst.expressions.Expression
@@ -312,13 +316,6 @@ private[hive] class ClientWrapper(
shim.getAllPartitions(client, qlTable).map(toHivePartition)
}
- override def getPartitionsByFilter(
- hTable: HiveTable,
- predicates: Seq[Expression]): Seq[HivePartition] = withHiveState {
- val qlTable = toQlTable(hTable)
- shim.getPartitionsByFilter(client, qlTable, predicates).map(toHivePartition)
- }
-
override def listTables(dbName: String): Seq[String] = withHiveState {
client.getAllTables(dbName)
}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala
index d12778c7583df..1fa9d278e2a57 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala
@@ -31,11 +31,6 @@ import org.apache.hadoop.hive.ql.Driver
import org.apache.hadoop.hive.ql.metadata.{Hive, Partition, Table}
import org.apache.hadoop.hive.ql.processors.{CommandProcessor, CommandProcessorFactory}
import org.apache.hadoop.hive.ql.session.SessionState
-import org.apache.hadoop.hive.serde.serdeConstants
-
-import org.apache.spark.Logging
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.types.{StringType, IntegralType}
/**
* A shim that defines the interface between ClientWrapper and the underlying Hive library used to
@@ -66,8 +61,6 @@ private[client] sealed abstract class Shim {
def getAllPartitions(hive: Hive, table: Table): Seq[Partition]
- def getPartitionsByFilter(hive: Hive, table: Table, predicates: Seq[Expression]): Seq[Partition]
-
def getCommandProcessor(token: String, conf: HiveConf): CommandProcessor
def getDriverResults(driver: Driver): Seq[String]
@@ -116,7 +109,7 @@ private[client] sealed abstract class Shim {
}
-private[client] class Shim_v0_12 extends Shim with Logging {
+private[client] class Shim_v0_12 extends Shim {
private lazy val startMethod =
findStaticMethod(
@@ -203,17 +196,6 @@ private[client] class Shim_v0_12 extends Shim with Logging {
override def getAllPartitions(hive: Hive, table: Table): Seq[Partition] =
getAllPartitionsMethod.invoke(hive, table).asInstanceOf[JSet[Partition]].toSeq
- override def getPartitionsByFilter(
- hive: Hive,
- table: Table,
- predicates: Seq[Expression]): Seq[Partition] = {
- // getPartitionsByFilter() doesn't support binary comparison ops in Hive 0.12.
- // See HIVE-4888.
- logDebug("Hive 0.12 doesn't support predicate pushdown to metastore. " +
- "Please use Hive 0.13 or higher.")
- getAllPartitions(hive, table)
- }
-
override def getCommandProcessor(token: String, conf: HiveConf): CommandProcessor =
getCommandProcessorMethod.invoke(null, token, conf).asInstanceOf[CommandProcessor]
@@ -285,12 +267,6 @@ private[client] class Shim_v0_13 extends Shim_v0_12 {
classOf[Hive],
"getAllPartitionsOf",
classOf[Table])
- private lazy val getPartitionsByFilterMethod =
- findMethod(
- classOf[Hive],
- "getPartitionsByFilter",
- classOf[Table],
- classOf[String])
private lazy val getCommandProcessorMethod =
findStaticMethod(
classOf[CommandProcessorFactory],
@@ -312,52 +288,6 @@ private[client] class Shim_v0_13 extends Shim_v0_12 {
override def getAllPartitions(hive: Hive, table: Table): Seq[Partition] =
getAllPartitionsMethod.invoke(hive, table).asInstanceOf[JSet[Partition]].toSeq
- /**
- * Converts catalyst expression to the format that Hive's getPartitionsByFilter() expects, i.e.
- * a string that represents partition predicates like "str_key=\"value\" and int_key=1 ...".
- *
- * Unsupported predicates are skipped.
- */
- def convertFilters(table: Table, filters: Seq[Expression]): String = {
- // hive varchar is treated as catalyst string, but hive varchar can't be pushed down.
- val varcharKeys = table.getPartitionKeys
- .filter(col => col.getType.startsWith(serdeConstants.VARCHAR_TYPE_NAME))
- .map(col => col.getName).toSet
-
- filters.collect {
- case op @ BinaryComparison(a: Attribute, Literal(v, _: IntegralType)) =>
- s"${a.name} ${op.symbol} $v"
- case op @ BinaryComparison(Literal(v, _: IntegralType), a: Attribute) =>
- s"$v ${op.symbol} ${a.name}"
-
- case op @ BinaryComparison(a: Attribute, Literal(v, _: StringType))
- if !varcharKeys.contains(a.name) =>
- s"""${a.name} ${op.symbol} "$v""""
- case op @ BinaryComparison(Literal(v, _: StringType), a: Attribute)
- if !varcharKeys.contains(a.name) =>
- s""""$v" ${op.symbol} ${a.name}"""
- }.mkString(" and ")
- }
-
- override def getPartitionsByFilter(
- hive: Hive,
- table: Table,
- predicates: Seq[Expression]): Seq[Partition] = {
-
- // Hive getPartitionsByFilter() takes a string that represents partition
- // predicates like "str_key=\"value\" and int_key=1 ..."
- val filter = convertFilters(table, predicates)
- val partitions =
- if (filter.isEmpty) {
- getAllPartitionsMethod.invoke(hive, table).asInstanceOf[JSet[Partition]]
- } else {
- logDebug(s"Hive metastore filter is '$filter'.")
- getPartitionsByFilterMethod.invoke(hive, table, filter).asInstanceOf[JArrayList[Partition]]
- }
-
- partitions.toSeq
- }
-
override def getCommandProcessor(token: String, conf: HiveConf): CommandProcessor =
getCommandProcessorMethod.invoke(null, Array(token), conf).asInstanceOf[CommandProcessor]
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala
index ba7eb15a1c0c6..d33da8242cc1d 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala
@@ -44,7 +44,7 @@ private[hive]
case class HiveTableScan(
requestedAttributes: Seq[Attribute],
relation: MetastoreRelation,
- partitionPruningPred: Seq[Expression])(
+ partitionPruningPred: Option[Expression])(
@transient val context: HiveContext)
extends LeafNode {
@@ -56,7 +56,7 @@ case class HiveTableScan(
// Bind all partition key attribute references in the partition pruning predicate for later
// evaluation.
- private[this] val boundPruningPred = partitionPruningPred.reduceLeftOption(And).map { pred =>
+ private[this] val boundPruningPred = partitionPruningPred.map { pred =>
require(
pred.dataType == BooleanType,
s"Data type of predicate $pred must be BooleanType rather than ${pred.dataType}.")
@@ -133,8 +133,7 @@ case class HiveTableScan(
protected override def doExecute(): RDD[InternalRow] = if (!relation.hiveQlTable.isPartitioned) {
hadoopReader.makeRDDForTable(relation.hiveQlTable)
} else {
- hadoopReader.makeRDDForPartitionedTable(
- prunePartitions(relation.getHiveQlPartitions(partitionPruningPred)))
+ hadoopReader.makeRDDForPartitionedTable(prunePartitions(relation.hiveQlPartitions))
}
override def output: Seq[Attribute] = attributes
diff --git a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaDataFrameSuite.java b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaDataFrameSuite.java
index c4828c4717643..741a3cd31c603 100644
--- a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaDataFrameSuite.java
+++ b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaDataFrameSuite.java
@@ -61,7 +61,9 @@ public void setUp() throws IOException {
@After
public void tearDown() throws IOException {
// Clean up tables.
- hc.sql("DROP TABLE IF EXISTS window_table");
+ if (hc != null) {
+ hc.sql("DROP TABLE IF EXISTS window_table");
+ }
}
@Test
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala
deleted file mode 100644
index 0efcf80bd4ea7..0000000000000
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala
+++ /dev/null
@@ -1,78 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.hive.client
-
-import scala.collection.JavaConversions._
-
-import org.apache.hadoop.hive.metastore.api.FieldSchema
-import org.apache.hadoop.hive.serde.serdeConstants
-
-import org.apache.spark.{Logging, SparkFunSuite}
-import org.apache.spark.sql.catalyst.dsl.expressions._
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.types._
-
-/**
- * A set of tests for the filter conversion logic used when pushing partition pruning into the
- * metastore
- */
-class FiltersSuite extends SparkFunSuite with Logging {
- private val shim = new Shim_v0_13
-
- private val testTable = new org.apache.hadoop.hive.ql.metadata.Table("default", "test")
- private val varCharCol = new FieldSchema()
- varCharCol.setName("varchar")
- varCharCol.setType(serdeConstants.VARCHAR_TYPE_NAME)
- testTable.setPartCols(varCharCol :: Nil)
-
- filterTest("string filter",
- (a("stringcol", StringType) > Literal("test")) :: Nil,
- "stringcol > \"test\"")
-
- filterTest("string filter backwards",
- (Literal("test") > a("stringcol", StringType)) :: Nil,
- "\"test\" > stringcol")
-
- filterTest("int filter",
- (a("intcol", IntegerType) === Literal(1)) :: Nil,
- "intcol = 1")
-
- filterTest("int filter backwards",
- (Literal(1) === a("intcol", IntegerType)) :: Nil,
- "1 = intcol")
-
- filterTest("int and string filter",
- (Literal(1) === a("intcol", IntegerType)) :: (Literal("a") === a("strcol", IntegerType)) :: Nil,
- "1 = intcol and \"a\" = strcol")
-
- filterTest("skip varchar",
- (Literal("") === a("varchar", StringType)) :: Nil,
- "")
-
- private def filterTest(name: String, filters: Seq[Expression], result: String) = {
- test(name){
- val converted = shim.convertFilters(testTable, filters)
- if (converted != result) {
- fail(
- s"Expected filters ${filters.mkString(",")} to convert to '$result' but got '$converted'")
- }
- }
- }
-
- private def a(name: String, dataType: DataType) = AttributeReference(name, dataType)()
-}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala
index 3eb127e23d486..d52e162acbd04 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala
@@ -20,9 +20,7 @@ package org.apache.spark.sql.hive.client
import java.io.File
import org.apache.spark.{Logging, SparkFunSuite}
-import org.apache.spark.sql.catalyst.expressions.{NamedExpression, Literal, AttributeReference, EqualTo}
import org.apache.spark.sql.catalyst.util.quietly
-import org.apache.spark.sql.types.IntegerType
import org.apache.spark.util.Utils
/**
@@ -153,12 +151,6 @@ class VersionsSuite extends SparkFunSuite with Logging {
client.getAllPartitions(client.getTable("default", "src_part"))
}
- test(s"$version: getPartitionsByFilter") {
- client.getPartitionsByFilter(client.getTable("default", "src_part"), Seq(EqualTo(
- AttributeReference("key", IntegerType, false)(NamedExpression.newExprId),
- Literal(1))))
- }
-
test(s"$version: loadPartition") {
client.loadPartition(
emptyDir,
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala
index e83a7dc77e329..de6a41ce5bfcb 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala
@@ -151,7 +151,7 @@ class PruningSuite extends HiveComparisonTest with BeforeAndAfter {
case p @ HiveTableScan(columns, relation, _) =>
val columnNames = columns.map(_.name)
val partValues = if (relation.table.isPartitioned) {
- p.prunePartitions(relation.getHiveQlPartitions()).map(_.getValues)
+ p.prunePartitions(relation.hiveQlPartitions).map(_.getValues)
} else {
Seq.empty
}
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/Interval.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/Interval.java
index eb7475e9df869..905ea0b7b878c 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/types/Interval.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/Interval.java
@@ -62,6 +62,7 @@ public static Interval fromString(String s) {
if (s == null) {
return null;
}
+ s = s.trim();
Matcher m = p.matcher(s);
if (!m.matches() || s.equals("interval")) {
return null;
diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/IntervalSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/types/IntervalSuite.java
index 44a949a371f2b..1832d0bc65551 100644
--- a/unsafe/src/test/java/org/apache/spark/unsafe/types/IntervalSuite.java
+++ b/unsafe/src/test/java/org/apache/spark/unsafe/types/IntervalSuite.java
@@ -75,6 +75,12 @@ public void fromStringTest() {
Interval result = new Interval(-5 * 12 + 23, 0);
assertEquals(Interval.fromString(input), result);
+ input = "interval -5 years 23 month ";
+ assertEquals(Interval.fromString(input), result);
+
+ input = " interval -5 years 23 month ";
+ assertEquals(Interval.fromString(input), result);
+
// Error cases
input = "interval 3month 1 hour";
assertEquals(Interval.fromString(input), null);
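
As the new test cases above show, the `trim` added to `Interval.fromString` makes parsing tolerant of leading and trailing whitespace. A quick check in the same spirit, illustrative only and assuming the patched `Interval` class is on the classpath:

{% highlight scala %}
import org.apache.spark.unsafe.types.Interval

// Padded input now parses to the same interval as the canonical form.
val canonical = Interval.fromString("interval -5 years 23 month")
val padded = Interval.fromString("  interval -5 years 23 month  ")
assert(padded == canonical)
{% endhighlight %}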