...
@@ -54,7 +153,7 @@ object MLUtils {
sc.textFile(dir).map { line =>
val parts = line.split(',')
val label = parts(0).toDouble
- val features = parts(1).trim().split(' ').map(_.toDouble)
+ val features = Vectors.dense(parts(1).trim().split(' ').map(_.toDouble))
LabeledPoint(label, features)
}
}
@@ -68,7 +167,7 @@ object MLUtils {
* @param dir Directory to save the data.
*/
def saveLabeledData(data: RDD[LabeledPoint], dir: String) {
- val dataStr = data.map(x => x.label + "," + x.features.mkString(" "))
+ val dataStr = data.map(x => x.label + "," + x.features.toArray.mkString(" "))
dataStr.saveAsTextFile(dir)
}
@@ -76,44 +175,52 @@ object MLUtils {
* Utility function to compute mean and standard deviation on a given dataset.
*
* @param data - input data set whose statistics are computed
- * @param nfeatures - number of features
- * @param nexamples - number of examples in input dataset
+ * @param numFeatures - number of features
+ * @param numExamples - number of examples in input dataset
*
* @return (yMean, xColMean, xColSd) - Tuple consisting of
* yMean - mean of the labels
* xColMean - Row vector with mean for every column (or feature) of the input data
* xColSd - Row vector with the standard deviation for every column (or feature) of the input data.
*/
- def computeStats(data: RDD[LabeledPoint], nfeatures: Int, nexamples: Long):
- (Double, DoubleMatrix, DoubleMatrix) = {
- val yMean: Double = data.map { labeledPoint => labeledPoint.label }.reduce(_ + _) / nexamples
-
- // NOTE: We shuffle X by column here to compute column sum and sum of squares.
- val xColSumSq: RDD[(Int, (Double, Double))] = data.flatMap { labeledPoint =>
- val nCols = labeledPoint.features.length
- // Traverse over every column and emit (col, value, value^2)
- Iterator.tabulate(nCols) { i =>
- (i, (labeledPoint.features(i), labeledPoint.features(i)*labeledPoint.features(i)))
- }
- }.reduceByKey { case(x1, x2) =>
- (x1._1 + x2._1, x1._2 + x2._2)
+ def computeStats(
+ data: RDD[LabeledPoint],
+ numFeatures: Int,
+ numExamples: Long): (Double, Vector, Vector) = {
+ val brzData = data.map { case LabeledPoint(label, features) =>
+ (label, features.toBreeze)
}
- val xColSumsMap = xColSumSq.collectAsMap()
-
- val xColMean = DoubleMatrix.zeros(nfeatures, 1)
- val xColSd = DoubleMatrix.zeros(nfeatures, 1)
-
- // Compute mean and unbiased variance using column sums
- var col = 0
- while (col < nfeatures) {
- xColMean.put(col, xColSumsMap(col)._1 / nexamples)
- val variance =
- (xColSumsMap(col)._2 - (math.pow(xColSumsMap(col)._1, 2) / nexamples)) / nexamples
- xColSd.put(col, math.sqrt(variance))
- col += 1
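+ // Single pass over the data: the accumulator is (example count, label sum,
+ // per-feature value sum, per-feature sum of squares).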
+ val aggStats = brzData.aggregate(
+ (0L, 0.0, BDV.zeros[Double](numFeatures), BDV.zeros[Double](numFeatures))
+ )(
+ seqOp = (c, v) => (c, v) match {
+ case ((n, sumLabel, sum, sumSq), (label, features)) =>
+ features.activeIterator.foreach { case (i, x) =>
+ sumSq(i) += x * x
+ }
+ (n + 1L, sumLabel + label, sum += features, sumSq)
+ },
+ combOp = (c1, c2) => (c1, c2) match {
+ case ((n1, sumLabel1, sum1, sumSq1), (n2, sumLabel2, sum2, sumSq2)) =>
+ (n1 + n2, sumLabel1 + sumLabel2, sum1 += sum2, sumSq1 += sumSq2)
+ }
+ )
+ val (nl, sumLabel, sum, sumSq) = aggStats
+
+ require(nl > 0, "Input data is empty.")
+ require(nl == numExamples)
+
+ val n = nl.toDouble
+ val yMean = sumLabel / n
+ val mean = sum / n
+ val std = new Array[Double](sum.length)
+ var i = 0
+ while (i < numFeatures) {
+ std(i) = math.sqrt(sumSq(i) / n - mean(i) * mean(i))
+ i += 1
}
- (yMean, xColMean, xColSd)
+ (yMean, Vectors.fromBreeze(mean), Vectors.dense(std))
}
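The per-feature statistics above follow from the accumulated sums via Var(x) = E[x^2] - (E[x])^2. A minimal standalone Scala sketch of that identity (not part of the patch; the object name and values are invented):

object MomentSketch {
  def main(args: Array[String]): Unit = {
    val xs = Array(1.0, 3.0, 5.0)
    val n = xs.length.toDouble
    val sum = xs.sum                        // 9.0
    val sumSq = xs.map(x => x * x).sum      // 35.0
    val mean = sum / n                      // 3.0
    val variance = sumSq / n - mean * mean  // 35/3 - 9 = 2.666...
    println(s"mean = $mean, std = ${math.sqrt(variance)}")
  }
}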
/**
@@ -144,6 +251,18 @@ object MLUtils {
val sumSquaredNorm = norm1 * norm1 + norm2 * norm2
val normDiff = norm1 - norm2
var sqDist = 0.0
+ /*
+ * The relative error is
+ *
* EPSILON * ( \|a\|_2^2 + \|b\|_2^2 + 2 |a^T b|) / ( \|a - b\|_2^2 ),
+ *
+ * which is bounded by
+ *
+ * 2.0 * EPSILON * ( \|a\|_2^2 + \|b\|_2^2 ) / ( (\|a\|_2 - \|b\|_2)^2 ).
+ *
+ * The bound doesn't need the inner product, so we can use it as a sufficient condition to
+ * check quickly whether the inner product approach is accurate.
+ */
val precisionBound1 = 2.0 * EPSILON * sumSquaredNorm / (normDiff * normDiff + EPSILON)
if (precisionBound1 < precision) {
sqDist = sumSquaredNorm - 2.0 * v1.dot(v2)
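A small standalone illustration of the trade-off this bound guards against, comparing the norm-expansion formula with the direct computation (plain Scala, not part of the patch):

object SquaredDistanceSketch {
  def main(args: Array[String]): Unit = {
    val a = Array(1e8, 1.0)
    val b = Array(1e8, 1.0 + 1e-3)
    val normSq = (v: Array[Double]) => v.map(x => x * x).sum
    val dot = a.zip(b).map { case (x, y) => x * y }.sum
    // Cheap when norms are precomputed, but prone to catastrophic cancellation
    // when ||a - b|| is tiny relative to ||a|| and ||b||.
    val viaNorms = normSq(a) + normSq(b) - 2.0 * dot
    // Accurate, but touches both vectors element by element.
    val direct = a.zip(b).map { case (x, y) => (x - y) * (x - y) }.sum
    println(s"viaNorms = $viaNorms, direct = $direct")  // direct is about 1.0e-6
  }
}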
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/SVMDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/SVMDataGenerator.scala
index c96c94f70eef7..e300c3dbe1fe0 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/util/SVMDataGenerator.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/util/SVMDataGenerator.scala
@@ -23,6 +23,7 @@ import org.jblas.DoubleMatrix
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
+import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
/**
@@ -58,7 +59,7 @@ object SVMDataGenerator {
}
val yD = new DoubleMatrix(1, x.length, x: _*).dot(trueWeights) + rnd.nextGaussian() * 0.1
val y = if (yD < 0) 0.0 else 1.0
- LabeledPoint(y, x)
+ LabeledPoint(y, Vectors.dense(x))
}
MLUtils.saveLabeledData(data, outputPath)
diff --git a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java
index 073ded6f36933..c80b1134ed1b2 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java
@@ -19,6 +19,7 @@
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.junit.After;
import org.junit.Assert;
@@ -45,12 +46,12 @@ public void tearDown() {
}
private static final List<LabeledPoint> POINTS = Arrays.asList(
- new LabeledPoint(0, new double[] {1.0, 0.0, 0.0}),
- new LabeledPoint(0, new double[] {2.0, 0.0, 0.0}),
- new LabeledPoint(1, new double[] {0.0, 1.0, 0.0}),
- new LabeledPoint(1, new double[] {0.0, 2.0, 0.0}),
- new LabeledPoint(2, new double[] {0.0, 0.0, 1.0}),
- new LabeledPoint(2, new double[] {0.0, 0.0, 2.0})
+ new LabeledPoint(0, Vectors.dense(1.0, 0.0, 0.0)),
+ new LabeledPoint(0, Vectors.dense(2.0, 0.0, 0.0)),
+ new LabeledPoint(1, Vectors.dense(0.0, 1.0, 0.0)),
+ new LabeledPoint(1, Vectors.dense(0.0, 2.0, 0.0)),
+ new LabeledPoint(2, Vectors.dense(0.0, 0.0, 1.0)),
+ new LabeledPoint(2, Vectors.dense(0.0, 0.0, 2.0))
);
private int validatePrediction(List<LabeledPoint> points, NaiveBayesModel model) {
diff --git a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaSVMSuite.java b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaSVMSuite.java
index 117e5eaa8b78e..4701a5e545020 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaSVMSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaSVMSuite.java
@@ -17,7 +17,6 @@
package org.apache.spark.mllib.classification;
-
import java.io.Serializable;
import java.util.List;
@@ -28,7 +27,6 @@
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
-
import org.apache.spark.mllib.regression.LabeledPoint;
public class JavaSVMSuite implements Serializable {
@@ -94,5 +92,4 @@ public void runSVMUsingStaticMethods() {
int numAccurate = validatePrediction(validationData, model);
Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0);
}
-
}
diff --git a/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaVectorsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaVectorsSuite.java
index 2c4d795f96e4e..c6d8425ffc38d 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaVectorsSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaVectorsSuite.java
@@ -19,10 +19,10 @@
import java.io.Serializable;
-import com.google.common.collect.Lists;
-
import scala.Tuple2;
+import com.google.common.collect.Lists;
+
import org.junit.Test;
import static org.junit.Assert.*;
@@ -36,7 +36,7 @@ public void denseArrayConstruction() {
@Test
public void sparseArrayConstruction() {
- Vector v = Vectors.sparse(3, Lists.newArrayList(
+ Vector v = Vectors.sparse(3, Lists.<Tuple2<Integer, Double>>newArrayList(
new Tuple2<Integer, Double>(0, 2.0),
new Tuple2<Integer, Double>(2, 3.0)));
assertArrayEquals(new double[]{2.0, 0.0, 3.0}, v.toArray(), 0.0);
diff --git a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLassoSuite.java b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLassoSuite.java
index f44b25cd44d19..f725924a2d971 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLassoSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLassoSuite.java
@@ -59,7 +59,7 @@ int validatePrediction(List validationData, LassoModel model) {
@Test
public void runLassoUsingConstructor() {
int nPoints = 10000;
- double A = 2.0;
+ double A = 0.0;
double[] weights = {-1.5, 1.0e-2};
JavaRDD<LabeledPoint> testRDD = sc.parallelize(LinearDataGenerator.generateLinearInputAsList(A,
@@ -80,7 +80,7 @@ public void runLassoUsingConstructor() {
@Test
public void runLassoUsingStaticMethods() {
int nPoints = 10000;
- double A = 2.0;
+ double A = 0.0;
double[] weights = {-1.5, 1.0e-2};
JavaRDD<LabeledPoint> testRDD = sc.parallelize(LinearDataGenerator.generateLinearInputAsList(A,
diff --git a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaRidgeRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaRidgeRegressionSuite.java
index 2fdd5fc8fdca6..03714ae7e4d00 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaRidgeRegressionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaRidgeRegressionSuite.java
@@ -55,30 +55,27 @@ public void tearDown() {
return errorSum / validationData.size();
}
- List<LabeledPoint> generateRidgeData(int numPoints, int nfeatures, double eps) {
+ List<LabeledPoint> generateRidgeData(int numPoints, int numFeatures, double std) {
org.jblas.util.Random.seed(42);
// Pick weights as random values distributed uniformly in [-0.5, 0.5]
- DoubleMatrix w = DoubleMatrix.rand(nfeatures, 1).subi(0.5);
- // Set first two weights to eps
- w.put(0, 0, eps);
- w.put(1, 0, eps);
- return LinearDataGenerator.generateLinearInputAsList(0.0, w.data, numPoints, 42, eps);
+ DoubleMatrix w = DoubleMatrix.rand(numFeatures, 1).subi(0.5);
+ return LinearDataGenerator.generateLinearInputAsList(0.0, w.data, numPoints, 42, std);
}
@Test
public void runRidgeRegressionUsingConstructor() {
- int nexamples = 200;
- int nfeatures = 20;
- double eps = 10.0;
- List<LabeledPoint> data = generateRidgeData(2*nexamples, nfeatures, eps);
+ int numExamples = 50;
+ int numFeatures = 20;
+ List<LabeledPoint> data = generateRidgeData(2*numExamples, numFeatures, 10.0);
- JavaRDD<LabeledPoint> testRDD = sc.parallelize(data.subList(0, nexamples));
- List<LabeledPoint> validationData = data.subList(nexamples, 2*nexamples);
+ JavaRDD<LabeledPoint> testRDD = sc.parallelize(data.subList(0, numExamples));
+ List<LabeledPoint> validationData = data.subList(numExamples, 2 * numExamples);
RidgeRegressionWithSGD ridgeSGDImpl = new RidgeRegressionWithSGD();
- ridgeSGDImpl.optimizer().setStepSize(1.0)
- .setRegParam(0.0)
- .setNumIterations(200);
+ ridgeSGDImpl.optimizer()
+ .setStepSize(1.0)
+ .setRegParam(0.0)
+ .setNumIterations(200);
RidgeRegressionModel model = ridgeSGDImpl.run(testRDD.rdd());
double unRegularizedErr = predictionError(validationData, model);
@@ -91,13 +88,12 @@ public void runRidgeRegressionUsingConstructor() {
@Test
public void runRidgeRegressionUsingStaticMethods() {
- int nexamples = 200;
- int nfeatures = 20;
- double eps = 10.0;
- List<LabeledPoint> data = generateRidgeData(2*nexamples, nfeatures, eps);
+ int numExamples = 50;
+ int numFeatures = 20;
+ List<LabeledPoint> data = generateRidgeData(2 * numExamples, numFeatures, 10.0);
- JavaRDD<LabeledPoint> testRDD = sc.parallelize(data.subList(0, nexamples));
- List<LabeledPoint> validationData = data.subList(nexamples, 2*nexamples);
+ JavaRDD<LabeledPoint> testRDD = sc.parallelize(data.subList(0, numExamples));
+ List<LabeledPoint> validationData = data.subList(numExamples, 2 * numExamples);
RidgeRegressionModel model = RidgeRegressionWithSGD.train(testRDD.rdd(), 200, 1.0, 0.0);
double unRegularizedErr = predictionError(validationData, model);
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
index 05322b024d5f6..1e03c9df820b0 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
@@ -20,11 +20,10 @@ package org.apache.spark.mllib.classification
import scala.util.Random
import scala.collection.JavaConversions._
-import org.scalatest.BeforeAndAfterAll
import org.scalatest.FunSuite
import org.scalatest.matchers.ShouldMatchers
-import org.apache.spark.SparkContext
+import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression._
import org.apache.spark.mllib.util.LocalSparkContext
@@ -61,7 +60,7 @@ object LogisticRegressionSuite {
if (yVal > 0) 1 else 0
}
- val testData = (0 until nPoints).map(i => LabeledPoint(y(i), Array(x1(i))))
+ val testData = (0 until nPoints).map(i => LabeledPoint(y(i), Vectors.dense(Array(x1(i)))))
testData
}
@@ -113,7 +112,7 @@ class LogisticRegressionSuite extends FunSuite with LocalSparkContext with Shoul
val testData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 42)
val initialB = -1.0
- val initialWeights = Array(initialB)
+ val initialWeights = Vectors.dense(initialB)
val testRDD = sc.parallelize(testData, 2)
testRDD.cache()
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
index 9dd6c79ee6ad8..516895d04222d 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
@@ -19,9 +19,9 @@ package org.apache.spark.mllib.classification
import scala.util.Random
-import org.scalatest.BeforeAndAfterAll
import org.scalatest.FunSuite
+import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.LocalSparkContext
@@ -54,7 +54,7 @@ object NaiveBayesSuite {
if (rnd.nextDouble() < _theta(y)(j)) 1 else 0
}
- LabeledPoint(y, xi)
+ LabeledPoint(y, Vectors.dense(xi))
}
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala
index bc7abb568a172..dfacbfeee6fb4 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala
@@ -20,7 +20,6 @@ package org.apache.spark.mllib.classification
import scala.util.Random
import scala.collection.JavaConversions._
-import org.scalatest.BeforeAndAfterAll
import org.scalatest.FunSuite
import org.jblas.DoubleMatrix
@@ -28,6 +27,7 @@ import org.jblas.DoubleMatrix
import org.apache.spark.SparkException
import org.apache.spark.mllib.regression._
import org.apache.spark.mllib.util.LocalSparkContext
+import org.apache.spark.mllib.linalg.Vectors
object SVMSuite {
@@ -54,7 +54,7 @@ object SVMSuite {
intercept + 0.01 * rnd.nextGaussian()
if (yD < 0) 0.0 else 1.0
}
- y.zip(x).map(p => LabeledPoint(p._1, p._2))
+ y.zip(x).map(p => LabeledPoint(p._1, Vectors.dense(p._2)))
}
}
@@ -110,7 +110,7 @@ class SVMSuite extends FunSuite with LocalSparkContext {
val initialB = -1.0
val initialC = -1.0
- val initialWeights = Array(initialB,initialC)
+ val initialWeights = Vectors.dense(initialB, initialC)
val testRDD = sc.parallelize(testData, 2)
testRDD.cache()
@@ -150,10 +150,10 @@ class SVMSuite extends FunSuite with LocalSparkContext {
}
intercept[SparkException] {
- val model = SVMWithSGD.train(testRDDInvalid, 100)
+ SVMWithSGD.train(testRDDInvalid, 100)
}
// Turning off data validation should not throw an exception
- val noValidationModel = new SVMWithSGD().setValidateData(false).run(testRDDInvalid)
+ new SVMWithSGD().setValidateData(false).run(testRDDInvalid)
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala
index 631d0e2ad9cdb..c4b433499a091 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala
@@ -20,13 +20,12 @@ package org.apache.spark.mllib.optimization
import scala.util.Random
import scala.collection.JavaConversions._
-import org.scalatest.BeforeAndAfterAll
import org.scalatest.FunSuite
import org.scalatest.matchers.ShouldMatchers
-import org.apache.spark.SparkContext
import org.apache.spark.mllib.regression._
import org.apache.spark.mllib.util.LocalSparkContext
+import org.apache.spark.mllib.linalg.Vectors
object GradientDescentSuite {
@@ -58,8 +57,7 @@ object GradientDescentSuite {
if (yVal > 0) 1 else 0
}
- val testData = (0 until nPoints).map(i => LabeledPoint(y(i), Array(x1(i))))
- testData
+ (0 until nPoints).map(i => LabeledPoint(y(i), Vectors.dense(x1(i))))
}
}
@@ -83,11 +81,11 @@ class GradientDescentSuite extends FunSuite with LocalSparkContext with ShouldMa
// Add an extra variable consisting of all 1.0's for the intercept.
val testData = GradientDescentSuite.generateGDInput(A, B, nPoints, 42)
val data = testData.map { case LabeledPoint(label, features) =>
- label -> Array(1.0, features: _*)
+ label -> Vectors.dense(1.0, features.toArray: _*)
}
val dataRDD = sc.parallelize(data, 2).cache()
- val initialWeightsWithIntercept = Array(1.0, initialWeights: _*)
+ val initialWeightsWithIntercept = Vectors.dense(1.0, initialWeights: _*)
val (_, loss) = GradientDescent.runMiniBatchSGD(
dataRDD,
@@ -113,13 +111,13 @@ class GradientDescentSuite extends FunSuite with LocalSparkContext with ShouldMa
// Add an extra variable consisting of all 1.0's for the intercept.
val testData = GradientDescentSuite.generateGDInput(2.0, -1.5, 10000, 42)
val data = testData.map { case LabeledPoint(label, features) =>
- label -> Array(1.0, features: _*)
+ label -> Vectors.dense(1.0, features.toArray: _*)
}
val dataRDD = sc.parallelize(data, 2).cache()
// Prepare non-zero weights
- val initialWeightsWithIntercept = Array(1.0, 0.5)
+ val initialWeightsWithIntercept = Vectors.dense(1.0, 0.5)
val regParam0 = 0
val (newWeights0, loss0) = GradientDescent.runMiniBatchSGD(
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala
index 2cebac943e15f..6aad9eb84e13c 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.mllib.regression
import org.scalatest.FunSuite
+import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.{LinearDataGenerator, LocalSparkContext}
class LassoSuite extends FunSuite with LocalSparkContext {
@@ -33,29 +34,33 @@ class LassoSuite extends FunSuite with LocalSparkContext {
}
test("Lasso local random SGD") {
- val nPoints = 10000
+ val nPoints = 1000
val A = 2.0
val B = -1.5
val C = 1.0e-2
- val testData = LinearDataGenerator.generateLinearInput(A, Array[Double](B,C), nPoints, 42)
-
- val testRDD = sc.parallelize(testData, 2)
- testRDD.cache()
+ val testData = LinearDataGenerator.generateLinearInput(A, Array[Double](B, C), nPoints, 42)
+ .map { case LabeledPoint(label, features) =>
+ LabeledPoint(label, Vectors.dense(1.0 +: features.toArray))
+ }
+ val testRDD = sc.parallelize(testData, 2).cache()
val ls = new LassoWithSGD()
- ls.optimizer.setStepSize(1.0).setRegParam(0.01).setNumIterations(20)
+ ls.optimizer.setStepSize(1.0).setRegParam(0.01).setNumIterations(40)
val model = ls.run(testRDD)
-
val weight0 = model.weights(0)
val weight1 = model.weights(1)
- assert(model.intercept >= 1.9 && model.intercept <= 2.1, model.intercept + " not in [1.9, 2.1]")
- assert(weight0 >= -1.60 && weight0 <= -1.40, weight0 + " not in [-1.6, -1.4]")
- assert(weight1 >= -1.0e-3 && weight1 <= 1.0e-3, weight1 + " not in [-0.001, 0.001]")
+ val weight2 = model.weights(2)
+ assert(weight0 >= 1.9 && weight0 <= 2.1, weight0 + " not in [1.9, 2.1]")
+ assert(weight1 >= -1.60 && weight1 <= -1.40, weight1 + " not in [-1.6, -1.4]")
+ assert(weight2 >= -1.0e-3 && weight2 <= 1.0e-3, weight2 + " not in [-0.001, 0.001]")
val validationData = LinearDataGenerator.generateLinearInput(A, Array[Double](B,C), nPoints, 17)
+ .map { case LabeledPoint(label, features) =>
+ LabeledPoint(label, Vectors.dense(1.0 +: features.toArray))
+ }
val validationRDD = sc.parallelize(validationData, 2)
// Test prediction on RDD.
@@ -66,33 +71,39 @@ class LassoSuite extends FunSuite with LocalSparkContext {
}
test("Lasso local random SGD with initial weights") {
- val nPoints = 10000
+ val nPoints = 1000
val A = 2.0
val B = -1.5
val C = 1.0e-2
- val testData = LinearDataGenerator.generateLinearInput(A, Array[Double](B,C), nPoints, 42)
+ val testData = LinearDataGenerator.generateLinearInput(A, Array[Double](B, C), nPoints, 42)
+ .map { case LabeledPoint(label, features) =>
+ LabeledPoint(label, Vectors.dense(1.0 +: features.toArray))
+ }
+ val initialA = -1.0
val initialB = -1.0
val initialC = -1.0
- val initialWeights = Array(initialB,initialC)
+ val initialWeights = Vectors.dense(initialA, initialB, initialC)
- val testRDD = sc.parallelize(testData, 2)
- testRDD.cache()
+ val testRDD = sc.parallelize(testData, 2).cache()
val ls = new LassoWithSGD()
- ls.optimizer.setStepSize(1.0).setRegParam(0.01).setNumIterations(20)
+ ls.optimizer.setStepSize(1.0).setRegParam(0.01).setNumIterations(40)
val model = ls.run(testRDD, initialWeights)
-
val weight0 = model.weights(0)
val weight1 = model.weights(1)
- assert(model.intercept >= 1.9 && model.intercept <= 2.1, model.intercept + " not in [1.9, 2.1]")
- assert(weight0 >= -1.60 && weight0 <= -1.40, weight0 + " not in [-1.6, -1.4]")
- assert(weight1 >= -1.0e-3 && weight1 <= 1.0e-3, weight1 + " not in [-0.001, 0.001]")
+ val weight2 = model.weights(2)
+ assert(weight0 >= 1.9 && weight0 <= 2.1, weight0 + " not in [1.9, 2.1]")
+ assert(weight1 >= -1.60 && weight1 <= -1.40, weight1 + " not in [-1.6, -1.4]")
+ assert(weight2 >= -1.0e-3 && weight2 <= 1.0e-3, weight2 + " not in [-0.001, 0.001]")
val validationData = LinearDataGenerator.generateLinearInput(A, Array[Double](B,C), nPoints, 17)
+ .map { case LabeledPoint(label, features) =>
+ LabeledPoint(label, Vectors.dense(1.0 +: features.toArray))
+ }
val validationRDD = sc.parallelize(validationData,2)
// Test prediction on RDD.
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala
index 5d251bcbf35db..2f7d30708ce17 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.mllib.regression
import org.scalatest.FunSuite
+import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.{LinearDataGenerator, LocalSparkContext}
class LinearRegressionSuite extends FunSuite with LocalSparkContext {
@@ -40,11 +41,12 @@ class LinearRegressionSuite extends FunSuite with LocalSparkContext {
linReg.optimizer.setNumIterations(1000).setStepSize(1.0)
val model = linReg.run(testRDD)
-
assert(model.intercept >= 2.5 && model.intercept <= 3.5)
- assert(model.weights.length === 2)
- assert(model.weights(0) >= 9.0 && model.weights(0) <= 11.0)
- assert(model.weights(1) >= 9.0 && model.weights(1) <= 11.0)
+
+ val weights = model.weights
+ assert(weights.size === 2)
+ assert(weights(0) >= 9.0 && weights(0) <= 11.0)
+ assert(weights(1) >= 9.0 && weights(1) <= 11.0)
val validationData = LinearDataGenerator.generateLinearInput(
3.0, Array(10.0, 10.0), 100, 17)
@@ -67,9 +69,11 @@ class LinearRegressionSuite extends FunSuite with LocalSparkContext {
val model = linReg.run(testRDD)
assert(model.intercept === 0.0)
- assert(model.weights.length === 2)
- assert(model.weights(0) >= 9.0 && model.weights(0) <= 11.0)
- assert(model.weights(1) >= 9.0 && model.weights(1) <= 11.0)
+
+ val weights = model.weights
+ assert(weights.size === 2)
+ assert(weights(0) >= 9.0 && weights(0) <= 11.0)
+ assert(weights(1) >= 9.0 && weights(1) <= 11.0)
val validationData = LinearDataGenerator.generateLinearInput(
0.0, Array(10.0, 10.0), 100, 17)
@@ -81,4 +85,40 @@ class LinearRegressionSuite extends FunSuite with LocalSparkContext {
// Test prediction on Array.
validatePrediction(validationData.map(row => model.predict(row.features)), validationData)
}
+
+ // Test if we can correctly learn Y = 10*X1 + 10*X10000
+ test("sparse linear regression without intercept") {
+ val denseRDD = sc.parallelize(
+ LinearDataGenerator.generateLinearInput(0.0, Array(10.0, 10.0), 100, 42), 2)
+ val sparseRDD = denseRDD.map { case LabeledPoint(label, v) =>
+ val sv = Vectors.sparse(10000, Seq((0, v(0)), (9999, v(1))))
+ LabeledPoint(label, sv)
+ }.cache()
+ val linReg = new LinearRegressionWithSGD().setIntercept(false)
+ linReg.optimizer.setNumIterations(1000).setStepSize(1.0)
+
+ val model = linReg.run(sparseRDD)
+
+ assert(model.intercept === 0.0)
+
+ val weights = model.weights
+ assert(weights.size === 10000)
+ assert(weights(0) >= 9.0 && weights(0) <= 11.0)
+ assert(weights(9999) >= 9.0 && weights(9999) <= 11.0)
+
+ val validationData = LinearDataGenerator.generateLinearInput(0.0, Array(10.0, 10.0), 100, 17)
+ val sparseValidationData = validationData.map { case LabeledPoint(label, v) =>
+ val sv = Vectors.sparse(10000, Seq((0, v(0)), (9999, v(1))))
+ LabeledPoint(label, sv)
+ }
+ val sparseValidationRDD = sc.parallelize(sparseValidationData, 2)
+
+ // Test prediction on RDD.
+ validatePrediction(
+ model.predict(sparseValidationRDD.map(_.features)).collect(), sparseValidationData)
+
+ // Test prediction on Array.
+ validatePrediction(
+ sparseValidationData.map(row => model.predict(row.features)), sparseValidationData)
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala
index b2044ed0d8066..f66fc6ea6c1ec 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala
@@ -17,9 +17,10 @@
package org.apache.spark.mllib.regression
-import org.jblas.DoubleMatrix
import org.scalatest.FunSuite
+import org.jblas.DoubleMatrix
+
import org.apache.spark.mllib.util.{LinearDataGenerator, LocalSparkContext}
class RidgeRegressionSuite extends FunSuite with LocalSparkContext {
@@ -30,22 +31,22 @@ class RidgeRegressionSuite extends FunSuite with LocalSparkContext {
}.reduceLeft(_ + _) / predictions.size
}
- test("regularization with skewed weights") {
- val nexamples = 200
- val nfeatures = 20
- val eps = 10
+ test("ridge regression can help avoid overfitting") {
+
+ // For a small number of examples and a large error variance, ridge regression
+ // should give a smaller generalization error than linear regression.
+
+ val numExamples = 50
+ val numFeatures = 20
org.jblas.util.Random.seed(42)
// Pick weights as random values distributed uniformly in [-0.5, 0.5]
- val w = DoubleMatrix.rand(nfeatures, 1).subi(0.5)
- // Set first two weights to eps
- w.put(0, 0, eps)
- w.put(1, 0, eps)
+ val w = DoubleMatrix.rand(numFeatures, 1).subi(0.5)
// Use half of data for training and other half for validation
- val data = LinearDataGenerator.generateLinearInput(3.0, w.toArray, 2*nexamples, 42, eps)
- val testData = data.take(nexamples)
- val validationData = data.takeRight(nexamples)
+ val data = LinearDataGenerator.generateLinearInput(3.0, w.toArray, 2 * numExamples, 42, 10.0)
+ val testData = data.take(numExamples)
+ val validationData = data.takeRight(numExamples)
val testRDD = sc.parallelize(testData, 2).cache()
val validationRDD = sc.parallelize(validationData, 2).cache()
@@ -67,7 +68,7 @@ class RidgeRegressionSuite extends FunSuite with LocalSparkContext {
val ridgeErr = predictionError(
ridgeModel.predict(validationRDD.map(_.features)).collect(), validationData)
- // Ridge CV-error should be lower than linear regression
+ // Ridge validation error should be lower than linear regression.
assert(ridgeErr < linearErr,
"ridgeError (" + ridgeErr + ") was not less than linearError(" + linearErr + ")")
}
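For reference, the comparison in this suite is between ordinary least squares and the L2-penalized objective (up to MLlib's exact scaling conventions)

  \min_w \frac{1}{n} \lVert Xw - y \rVert_2^2 + \lambda \lVert w \rVert_2^2,

so with only 50 training examples and a noise scale of 10, a positive regularization parameter is expected to give lower validation error than \lambda = 0.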
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
new file mode 100644
index 0000000000000..350130c914f26
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -0,0 +1,426 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree
+
+import org.scalatest.BeforeAndAfterAll
+import org.scalatest.FunSuite
+
+import org.apache.spark.SparkContext
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance}
+import org.apache.spark.mllib.tree.model.Filter
+import org.apache.spark.mllib.tree.configuration.Strategy
+import org.apache.spark.mllib.tree.configuration.Algo._
+import org.apache.spark.mllib.tree.configuration.FeatureType._
+import org.apache.spark.mllib.linalg.Vectors
+
+class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
+
+ @transient private var sc: SparkContext = _
+
+ override def beforeAll() {
+ sc = new SparkContext("local", "test")
+ }
+
+ override def afterAll() {
+ sc.stop()
+ System.clearProperty("spark.driver.port")
+ }
+
+ test("split and bin calculation") {
+ val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1()
+ assert(arr.length === 1000)
+ val rdd = sc.parallelize(arr)
+ val strategy = new Strategy(Classification, Gini, 3, 100)
+ val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
+ assert(splits.length === 2)
+ assert(bins.length === 2)
+ assert(splits(0).length === 99)
+ assert(bins(0).length === 100)
+ }
+
+ test("split and bin calculation for categorical variables") {
+ val arr = DecisionTreeSuite.generateCategoricalDataPoints()
+ assert(arr.length === 1000)
+ val rdd = sc.parallelize(arr)
+ val strategy = new Strategy(
+ Classification,
+ Gini,
+ maxDepth = 3,
+ maxBins = 100,
+ categoricalFeaturesInfo = Map(0 -> 2, 1-> 2))
+ val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
+ assert(splits.length === 2)
+ assert(bins.length === 2)
+ assert(splits(0).length === 99)
+ assert(bins(0).length === 100)
+
+ // Check splits.
+
+ assert(splits(0)(0).feature === 0)
+ assert(splits(0)(0).threshold === Double.MinValue)
+ assert(splits(0)(0).featureType === Categorical)
+ assert(splits(0)(0).categories.length === 1)
+ assert(splits(0)(0).categories.contains(1.0))
+
+ assert(splits(0)(1).feature === 0)
+ assert(splits(0)(1).threshold === Double.MinValue)
+ assert(splits(0)(1).featureType === Categorical)
+ assert(splits(0)(1).categories.length === 2)
+ assert(splits(0)(1).categories.contains(1.0))
+ assert(splits(0)(1).categories.contains(0.0))
+
+ assert(splits(0)(2) === null)
+
+ assert(splits(1)(0).feature === 1)
+ assert(splits(1)(0).threshold === Double.MinValue)
+ assert(splits(1)(0).featureType === Categorical)
+ assert(splits(1)(0).categories.length === 1)
+ assert(splits(1)(0).categories.contains(0.0))
+
+ assert(splits(1)(1).feature === 1)
+ assert(splits(1)(1).threshold === Double.MinValue)
+ assert(splits(1)(1).featureType === Categorical)
+ assert(splits(1)(1).categories.length === 2)
+ assert(splits(1)(1).categories.contains(1.0))
+ assert(splits(1)(1).categories.contains(0.0))
+
+ assert(splits(1)(2) === null)
+
+ // Check bins.
+
+ assert(bins(0)(0).category === 1.0)
+ assert(bins(0)(0).lowSplit.categories.length === 0)
+ assert(bins(0)(0).highSplit.categories.length === 1)
+ assert(bins(0)(0).highSplit.categories.contains(1.0))
+
+ assert(bins(0)(1).category === 0.0)
+ assert(bins(0)(1).lowSplit.categories.length === 1)
+ assert(bins(0)(1).lowSplit.categories.contains(1.0))
+ assert(bins(0)(1).highSplit.categories.length === 2)
+ assert(bins(0)(1).highSplit.categories.contains(1.0))
+ assert(bins(0)(1).highSplit.categories.contains(0.0))
+
+ assert(bins(0)(2) === null)
+
+ assert(bins(1)(0).category === 0.0)
+ assert(bins(1)(0).lowSplit.categories.length === 0)
+ assert(bins(1)(0).highSplit.categories.length === 1)
+ assert(bins(1)(0).highSplit.categories.contains(0.0))
+
+ assert(bins(1)(1).category === 1.0)
+ assert(bins(1)(1).lowSplit.categories.length === 1)
+ assert(bins(1)(1).lowSplit.categories.contains(0.0))
+ assert(bins(1)(1).highSplit.categories.length === 2)
+ assert(bins(1)(1).highSplit.categories.contains(0.0))
+ assert(bins(1)(1).highSplit.categories.contains(1.0))
+
+ assert(bins(1)(2) === null)
+ }
+
+ test("split and bin calculations for categorical variables with no sample for one category") {
+ val arr = DecisionTreeSuite.generateCategoricalDataPoints()
+ assert(arr.length === 1000)
+ val rdd = sc.parallelize(arr)
+ val strategy = new Strategy(
+ Classification,
+ Gini,
+ maxDepth = 3,
+ maxBins = 100,
+ categoricalFeaturesInfo = Map(0 -> 3, 1-> 3))
+ val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
+
+ // Check splits.
+
+ assert(splits(0)(0).feature === 0)
+ assert(splits(0)(0).threshold === Double.MinValue)
+ assert(splits(0)(0).featureType === Categorical)
+ assert(splits(0)(0).categories.length === 1)
+ assert(splits(0)(0).categories.contains(1.0))
+
+ assert(splits(0)(1).feature === 0)
+ assert(splits(0)(1).threshold === Double.MinValue)
+ assert(splits(0)(1).featureType === Categorical)
+ assert(splits(0)(1).categories.length === 2)
+ assert(splits(0)(1).categories.contains(1.0))
+ assert(splits(0)(1).categories.contains(0.0))
+
+ assert(splits(0)(2).feature === 0)
+ assert(splits(0)(2).threshold === Double.MinValue)
+ assert(splits(0)(2).featureType === Categorical)
+ assert(splits(0)(2).categories.length === 3)
+ assert(splits(0)(2).categories.contains(1.0))
+ assert(splits(0)(2).categories.contains(0.0))
+ assert(splits(0)(2).categories.contains(2.0))
+
+ assert(splits(0)(3) === null)
+
+ assert(splits(1)(0).feature === 1)
+ assert(splits(1)(0).threshold === Double.MinValue)
+ assert(splits(1)(0).featureType === Categorical)
+ assert(splits(1)(0).categories.length === 1)
+ assert(splits(1)(0).categories.contains(0.0))
+
+ assert(splits(1)(1).feature === 1)
+ assert(splits(1)(1).threshold === Double.MinValue)
+ assert(splits(1)(1).featureType === Categorical)
+ assert(splits(1)(1).categories.length === 2)
+ assert(splits(1)(1).categories.contains(1.0))
+ assert(splits(1)(1).categories.contains(0.0))
+
+ assert(splits(1)(2).feature === 1)
+ assert(splits(1)(2).threshold === Double.MinValue)
+ assert(splits(1)(2).featureType === Categorical)
+ assert(splits(1)(2).categories.length === 3)
+ assert(splits(1)(2).categories.contains(1.0))
+ assert(splits(1)(2).categories.contains(0.0))
+ assert(splits(1)(2).categories.contains(2.0))
+
+ assert(splits(1)(3) === null)
+
+ // Check bins.
+
+ assert(bins(0)(0).category === 1.0)
+ assert(bins(0)(0).lowSplit.categories.length === 0)
+ assert(bins(0)(0).highSplit.categories.length === 1)
+ assert(bins(0)(0).highSplit.categories.contains(1.0))
+
+ assert(bins(0)(1).category === 0.0)
+ assert(bins(0)(1).lowSplit.categories.length === 1)
+ assert(bins(0)(1).lowSplit.categories.contains(1.0))
+ assert(bins(0)(1).highSplit.categories.length === 2)
+ assert(bins(0)(1).highSplit.categories.contains(1.0))
+ assert(bins(0)(1).highSplit.categories.contains(0.0))
+
+ assert(bins(0)(2).category === 2.0)
+ assert(bins(0)(2).lowSplit.categories.length === 2)
+ assert(bins(0)(2).lowSplit.categories.contains(1.0))
+ assert(bins(0)(2).lowSplit.categories.contains(0.0))
+ assert(bins(0)(2).highSplit.categories.length === 3)
+ assert(bins(0)(2).highSplit.categories.contains(1.0))
+ assert(bins(0)(2).highSplit.categories.contains(0.0))
+ assert(bins(0)(2).highSplit.categories.contains(2.0))
+
+ assert(bins(0)(3) === null)
+
+ assert(bins(1)(0).category === 0.0)
+ assert(bins(1)(0).lowSplit.categories.length === 0)
+ assert(bins(1)(0).highSplit.categories.length === 1)
+ assert(bins(1)(0).highSplit.categories.contains(0.0))
+
+ assert(bins(1)(1).category === 1.0)
+ assert(bins(1)(1).lowSplit.categories.length === 1)
+ assert(bins(1)(1).lowSplit.categories.contains(0.0))
+ assert(bins(1)(1).highSplit.categories.length === 2)
+ assert(bins(1)(1).highSplit.categories.contains(0.0))
+ assert(bins(1)(1).highSplit.categories.contains(1.0))
+
+ assert(bins(1)(2).category === 2.0)
+ assert(bins(1)(2).lowSplit.categories.length === 2)
+ assert(bins(1)(2).lowSplit.categories.contains(0.0))
+ assert(bins(1)(2).lowSplit.categories.contains(1.0))
+ assert(bins(1)(2).highSplit.categories.length === 3)
+ assert(bins(1)(2).highSplit.categories.contains(0.0))
+ assert(bins(1)(2).highSplit.categories.contains(1.0))
+ assert(bins(1)(2).highSplit.categories.contains(2.0))
+
+ assert(bins(1)(3) === null)
+ }
+
+ test("classification stump with all categorical variables") {
+ val arr = DecisionTreeSuite.generateCategoricalDataPoints()
+ assert(arr.length === 1000)
+ val rdd = sc.parallelize(arr)
+ val strategy = new Strategy(
+ Classification,
+ Gini,
+ maxDepth = 3,
+ maxBins = 100,
+ categoricalFeaturesInfo = Map(0 -> 3, 1-> 3))
+ val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
+ val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0,
+ Array[List[Filter]](), splits, bins)
+
+ val split = bestSplits(0)._1
+ assert(split.categories.length === 1)
+ assert(split.categories.contains(1.0))
+ assert(split.featureType === Categorical)
+ assert(split.threshold === Double.MinValue)
+
+ val stats = bestSplits(0)._2
+ assert(stats.gain > 0)
+ assert(stats.predict > 0.4)
+ assert(stats.predict < 0.5)
+ assert(stats.impurity > 0.2)
+ }
+
+ test("regression stump with all categorical variables") {
+ val arr = DecisionTreeSuite.generateCategoricalDataPoints()
+ assert(arr.length === 1000)
+ val rdd = sc.parallelize(arr)
+ val strategy = new Strategy(
+ Regression,
+ Variance,
+ maxDepth = 3,
+ maxBins = 100,
+ categoricalFeaturesInfo = Map(0 -> 3, 1-> 3))
+ val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy)
+ val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0,
+ Array[List[Filter]](), splits, bins)
+
+ val split = bestSplits(0)._1
+ assert(split.categories.length === 1)
+ assert(split.categories.contains(1.0))
+ assert(split.featureType === Categorical)
+ assert(split.threshold === Double.MinValue)
+
+ val stats = bestSplits(0)._2
+ assert(stats.gain > 0)
+ assert(stats.predict > 0.4)
+ assert(stats.predict < 0.5)
+ assert(stats.impurity > 0.2)
+ }
+
+ test("stump with fixed label 0 for Gini") {
+ val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0()
+ assert(arr.length === 1000)
+ val rdd = sc.parallelize(arr)
+ val strategy = new Strategy(Classification, Gini, 3, 100)
+ val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
+ assert(splits.length === 2)
+ assert(splits(0).length === 99)
+ assert(bins.length === 2)
+ assert(bins(0).length === 100)
+ assert(splits(0).length === 99)
+ assert(bins(0).length === 100)
+
+ val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0,
+ Array[List[Filter]](), splits, bins)
+ assert(bestSplits.length === 1)
+ assert(bestSplits(0)._1.feature === 0)
+ assert(bestSplits(0)._1.threshold === 10)
+ assert(bestSplits(0)._2.gain === 0)
+ assert(bestSplits(0)._2.leftImpurity === 0)
+ assert(bestSplits(0)._2.rightImpurity === 0)
+ }
+
+ test("stump with fixed label 1 for Gini") {
+ val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1()
+ assert(arr.length === 1000)
+ val rdd = sc.parallelize(arr)
+ val strategy = new Strategy(Classification, Gini, 3, 100)
+ val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
+ assert(splits.length === 2)
+ assert(splits(0).length === 99)
+ assert(bins.length === 2)
+ assert(bins(0).length === 100)
+ assert(splits(0).length === 99)
+ assert(bins(0).length === 100)
+
+ val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0,
+ Array[List[Filter]](), splits, bins)
+ assert(bestSplits.length === 1)
+ assert(bestSplits(0)._1.feature === 0)
+ assert(bestSplits(0)._1.threshold === 10)
+ assert(bestSplits(0)._2.gain === 0)
+ assert(bestSplits(0)._2.leftImpurity === 0)
+ assert(bestSplits(0)._2.rightImpurity === 0)
+ assert(bestSplits(0)._2.predict === 1)
+ }
+
+ test("stump with fixed label 0 for Entropy") {
+ val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0()
+ assert(arr.length === 1000)
+ val rdd = sc.parallelize(arr)
+ val strategy = new Strategy(Classification, Entropy, 3, 100)
+ val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
+ assert(splits.length === 2)
+ assert(splits(0).length === 99)
+ assert(bins.length === 2)
+ assert(bins(0).length === 100)
+ assert(splits(0).length === 99)
+ assert(bins(0).length === 100)
+
+ val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0,
+ Array[List[Filter]](), splits, bins)
+ assert(bestSplits.length === 1)
+ assert(bestSplits(0)._1.feature === 0)
+ assert(bestSplits(0)._1.threshold === 10)
+ assert(bestSplits(0)._2.gain === 0)
+ assert(bestSplits(0)._2.leftImpurity === 0)
+ assert(bestSplits(0)._2.rightImpurity === 0)
+ assert(bestSplits(0)._2.predict === 0)
+ }
+
+ test("stump with fixed label 1 for Entropy") {
+ val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1()
+ assert(arr.length === 1000)
+ val rdd = sc.parallelize(arr)
+ val strategy = new Strategy(Classification, Entropy, 3, 100)
+ val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
+ assert(splits.length === 2)
+ assert(splits(0).length === 99)
+ assert(bins.length === 2)
+ assert(bins(0).length === 100)
+ assert(splits(0).length === 99)
+ assert(bins(0).length === 100)
+
+ val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0,
+ Array[List[Filter]](), splits, bins)
+ assert(bestSplits.length === 1)
+ assert(bestSplits(0)._1.feature === 0)
+ assert(bestSplits(0)._1.threshold === 10)
+ assert(bestSplits(0)._2.gain === 0)
+ assert(bestSplits(0)._2.leftImpurity === 0)
+ assert(bestSplits(0)._2.rightImpurity === 0)
+ assert(bestSplits(0)._2.predict === 1)
+ }
+}
+
+object DecisionTreeSuite {
+
+ def generateOrderedLabeledPointsWithLabel0(): Array[LabeledPoint] = {
+ val arr = new Array[LabeledPoint](1000)
+ for (i <- 0 until 1000){
+ val lp = new LabeledPoint(0.0, Vectors.dense(i.toDouble, 1000.0 - i))
+ arr(i) = lp
+ }
+ arr
+ }
+
+ def generateOrderedLabeledPointsWithLabel1(): Array[LabeledPoint] = {
+ val arr = new Array[LabeledPoint](1000)
+ for (i <- 0 until 1000){
+ val lp = new LabeledPoint(1.0, Vectors.dense(i.toDouble, 999.0 - i))
+ arr(i) = lp
+ }
+ arr
+ }
+
+ def generateCategoricalDataPoints(): Array[LabeledPoint] = {
+ val arr = new Array[LabeledPoint](1000)
+ for (i <- 0 until 1000){
+ if (i < 600){
+ arr(i) = new LabeledPoint(1.0, Vectors.dense(0.0, 1.0))
+ } else {
+ arr(i) = new LabeledPoint(0.0, Vectors.dense(1.0, 0.0))
+ }
+ }
+ arr
+ }
+}
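A hedged end-to-end sketch of how the pieces exercised by this suite are meant to be used; it assumes a DecisionTree.train(input, strategy) entry point and a Vector-based predict, neither of which is shown in this patch:

import org.apache.spark.SparkContext
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.tree.impurity.Gini

object DecisionTreeSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext("local", "decision-tree-sketch")
    // Same toy data as generateCategoricalDataPoints above.
    val points = (0 until 1000).map { i =>
      if (i < 600) LabeledPoint(1.0, Vectors.dense(0.0, 1.0))
      else LabeledPoint(0.0, Vectors.dense(1.0, 0.0))
    }
    val input = sc.parallelize(points)
    val strategy = new Strategy(Classification, Gini, 3, 100)
    val model = DecisionTree.train(input, strategy)  // assumed API
    println(model.predict(Vectors.dense(0.0, 1.0)))  // expected: 1.0
    sc.stop()
  }
}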
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala
index 60f053b381305..27d41c7869aa0 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala
@@ -17,14 +17,20 @@
package org.apache.spark.mllib.util
+import java.io.File
+
import org.scalatest.FunSuite
import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, norm => breezeNorm,
squaredDistance => breezeSquaredDistance}
+import com.google.common.base.Charsets
+import com.google.common.io.Files
+import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.MLUtils._
-class MLUtilsSuite extends FunSuite {
+class MLUtilsSuite extends FunSuite with LocalSparkContext {
test("epsilon computation") {
assert(1.0 + EPSILON > 1.0, s"EPSILON is too small: $EPSILON.")
@@ -49,4 +55,55 @@ class MLUtilsSuite extends FunSuite {
assert((fastSquaredDist2 - squaredDist) <= precision * squaredDist, s"failed with m = $m")
}
}
+
+ test("compute stats") {
+ val data = Seq.fill(3)(Seq(
+ LabeledPoint(1.0, Vectors.dense(1.0, 2.0, 3.0)),
+ LabeledPoint(0.0, Vectors.dense(3.0, 4.0, 5.0))
+ )).flatten
+ val rdd = sc.parallelize(data, 2)
+ val (meanLabel, mean, std) = MLUtils.computeStats(rdd, 3, 6)
+ assert(meanLabel === 0.5)
+ assert(mean === Vectors.dense(2.0, 3.0, 4.0))
+ assert(std === Vectors.dense(1.0, 1.0, 1.0))
+ }
+
+ test("loadLibSVMData") {
+ val lines =
+ """
+ |+1 1:1.0 3:2.0 5:3.0
+ |-1
+ |-1 2:4.0 4:5.0 6:6.0
+ """.stripMargin
+ val tempDir = Files.createTempDir()
+ val file = new File(tempDir.getPath, "part-00000")
+ Files.write(lines, file, Charsets.US_ASCII)
+ val path = tempDir.toURI.toString
+
+ val pointsWithNumFeatures = MLUtils.loadLibSVMData(sc, path, 6).collect()
+ val pointsWithoutNumFeatures = MLUtils.loadLibSVMData(sc, path).collect()
+
+ for (points <- Seq(pointsWithNumFeatures, pointsWithoutNumFeatures)) {
+ assert(points.length === 3)
+ assert(points(0).label === 1.0)
+ assert(points(0).features === Vectors.sparse(6, Seq((0, 1.0), (2, 2.0), (4, 3.0))))
+ assert(points(1).label == 0.0)
+ assert(points(1).features == Vectors.sparse(6, Seq()))
+ assert(points(2).label === 0.0)
+ assert(points(2).features === Vectors.sparse(6, Seq((1, 4.0), (3, 5.0), (5, 6.0))))
+ }
+
+ val multiclassPoints = MLUtils.loadLibSVMData(sc, path, MLUtils.multiclassLabelParser).collect()
+ assert(multiclassPoints.length === 3)
+ assert(multiclassPoints(0).label === 1.0)
+ assert(multiclassPoints(1).label === -1.0)
+ assert(multiclassPoints(2).label === -1.0)
+
+ try {
+ file.delete()
+ tempDir.delete()
+ } catch {
+ case t: Throwable =>
+ }
+ }
}
diff --git a/pom.xml b/pom.xml
index 09a449d81453f..7d58060cba606 100644
--- a/pom.xml
+++ b/pom.xml
@@ -110,7 +110,7 @@
1.6
- <scala.version>2.10.3</scala.version>
+ <scala.version>2.10.4</scala.version>
2.10
0.13.0
org.spark-project.akka
@@ -380,7 +380,7 @@
lift-json_${scala.binary.version}
2.5.1
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index 7457ff456ade4..c5c697e8e2427 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -152,7 +152,7 @@ object SparkBuild extends Build {
def sharedSettings = Defaults.defaultSettings ++ MimaBuild.mimaSettings(file(sparkHome)) ++ Seq(
organization := "org.apache.spark",
version := SPARK_VERSION,
- scalaVersion := "2.10.3",
+ scalaVersion := "2.10.4",
scalacOptions := Seq("-Xmax-classfile-name", "120", "-unchecked", "-deprecation",
"-target:" + SCALAC_JVM_VERSION),
javacOptions := Seq("-target", JAVAC_JVM_VERSION, "-source", JAVAC_JVM_VERSION),
diff --git a/project/plugins.sbt b/project/plugins.sbt
index 5aa8a1ec2409b..d787237ddc540 100644
--- a/project/plugins.sbt
+++ b/project/plugins.sbt
@@ -1,4 +1,4 @@
-scalaVersion := "2.10.3"
+scalaVersion := "2.10.4"
resolvers += Resolver.url("artifactory", url("http://scalasbt.artifactoryonline.com/scalasbt/sbt-plugin-releases"))(Resolver.ivyStylePatterns)
diff --git a/project/project/SparkPluginBuild.scala b/project/project/SparkPluginBuild.scala
index 5a307044ba123..0142256e90fb7 100644
--- a/project/project/SparkPluginBuild.scala
+++ b/project/project/SparkPluginBuild.scala
@@ -32,7 +32,7 @@ object SparkPluginDef extends Build {
name := "spark-style",
organization := "org.apache.spark",
version := sparkVersion,
- scalaVersion := "2.10.3",
+ scalaVersion := "2.10.4",
scalacOptions := Seq("-unchecked", "-deprecation"),
libraryDependencies ++= Dependencies.scalaStyle
)
diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py
index 19b90dfd6e167..d2f9cdb3f4298 100644
--- a/python/pyspark/mllib/classification.py
+++ b/python/pyspark/mllib/classification.py
@@ -87,18 +87,19 @@ class NaiveBayesModel(object):
>>> data = array([0.0, 0.0, 1.0, 0.0, 0.0, 2.0, 1.0, 1.0, 0.0]).reshape(3,3)
>>> model = NaiveBayes.train(sc.parallelize(data))
>>> model.predict(array([0.0, 1.0]))
- 0
+ 0.0
>>> model.predict(array([1.0, 0.0]))
- 1
+ 1.0
"""
- def __init__(self, pi, theta):
+ def __init__(self, labels, pi, theta):
+ self.labels = labels
self.pi = pi
self.theta = theta
def predict(self, x):
"""Return the most likely class for a data vector x"""
- return numpy.argmax(self.pi + dot(x, self.theta))
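+ # pi holds the log class priors and theta the log conditional probabilities;
+ # labels maps the argmax index back to the original class label.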
+ return self.labels[numpy.argmax(self.pi + dot(x, self.theta))]
class NaiveBayes(object):
@classmethod
@@ -122,7 +123,8 @@ def train(cls, data, lambda_=1.0):
ans = sc._jvm.PythonMLLibAPI().trainNaiveBayes(dataBytes._jrdd, lambda_)
return NaiveBayesModel(
_deserialize_double_vector(ans[0]),
- _deserialize_double_matrix(ans[1]))
+ _deserialize_double_vector(ans[1]),
+ _deserialize_double_matrix(ans[2]))
def _test():
diff --git a/sql/README.md b/sql/README.md
index 4192fecb92fb0..14d5555f0c713 100644
--- a/sql/README.md
+++ b/sql/README.md
@@ -38,7 +38,7 @@ import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.execution
import org.apache.spark.sql.hive._
import org.apache.spark.sql.hive.TestHive._
-Welcome to Scala version 2.10.3 (Java HotSpot(TM) 64-Bit Server VM, Java 1.7.0_45).
+Welcome to Scala version 2.10.4 (Java HotSpot(TM) 64-Bit Server VM, Java 1.7.0_45).
Type in expressions to have them evaluated.
Type :help for more information.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 976dda8d7e59a..5aaa63bf3b4b4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -43,15 +43,25 @@ object ScalaReflection {
val params = t.member("": TermName).asMethod.paramss
StructType(
params.head.map(p => StructField(p.name.toString, schemaFor(p.typeSignature), true)))
+ // Need to decide if we actually need a special type here.
+ case t if t <:< typeOf[Array[Byte]] => BinaryType
+ case t if t <:< typeOf[Array[_]] =>
+ sys.error(s"Only Array[Byte] supported now, use Seq instead of $t")
case t if t <:< typeOf[Seq[_]] =>
val TypeRef(_, _, Seq(elementType)) = t
ArrayType(schemaFor(elementType))
+ case t if t <:< typeOf[Map[_,_]] =>
+ val TypeRef(_, _, Seq(keyType, valueType)) = t
+ MapType(schemaFor(keyType), schemaFor(valueType))
case t if t <:< typeOf[String] => StringType
case t if t <:< definitions.IntTpe => IntegerType
case t if t <:< definitions.LongTpe => LongType
+ case t if t <:< definitions.FloatTpe => FloatType
case t if t <:< definitions.DoubleTpe => DoubleType
case t if t <:< definitions.ShortTpe => ShortType
case t if t <:< definitions.ByteTpe => ByteType
+ case t if t <:< definitions.BooleanTpe => BooleanType
+ case t if t <:< typeOf[BigDecimal] => DecimalType
}
implicit class CaseClassRelation[A <: Product : TypeTag](data: Seq[A]) {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
index 0c851c2ee2183..8de87594c8ab9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
@@ -181,7 +181,7 @@ class SqlParser extends StandardTokenParsers {
val withDistinct = d.map(_ => Distinct(withProjection)).getOrElse(withProjection)
val withHaving = h.map(h => Filter(h, withDistinct)).getOrElse(withDistinct)
val withOrder = o.map(o => Sort(o, withHaving)).getOrElse(withHaving)
- val withLimit = l.map { l => StopAfter(l, withOrder) }.getOrElse(withOrder)
+ val withLimit = l.map { l => Limit(l, withOrder) }.getOrElse(withOrder)
withLimit
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala
index e09182dd8d5df..6b58b9322c4bf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala
@@ -31,6 +31,7 @@ trait Catalog {
alias: Option[String] = None): LogicalPlan
def registerTable(databaseName: Option[String], tableName: String, plan: LogicalPlan): Unit
+ def unregisterTable(databaseName: Option[String], tableName: String): Unit
}
class SimpleCatalog extends Catalog {
@@ -40,7 +41,7 @@ class SimpleCatalog extends Catalog {
tables += ((tableName, plan))
}
- def dropTable(tableName: String) = tables -= tableName
+ def unregisterTable(databaseName: Option[String], tableName: String) = { tables -= tableName }
def lookupRelation(
databaseName: Option[String],
@@ -87,6 +88,10 @@ trait OverrideCatalog extends Catalog {
plan: LogicalPlan): Unit = {
overrides.put((databaseName, tableName), plan)
}
+
+ override def unregisterTable(databaseName: Option[String], tableName: String): Unit = {
+ overrides.remove((databaseName, tableName))
+ }
}
/**
@@ -104,4 +109,8 @@ object EmptyCatalog extends Catalog {
def registerTable(databaseName: Option[String], tableName: String, plan: LogicalPlan): Unit = {
throw new UnsupportedOperationException
}
+
+ def unregisterTable(databaseName: Option[String], tableName: String): Unit = {
+ throw new UnsupportedOperationException
+ }
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 9d16189deedfe..b39c2b32cc42c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -130,7 +130,7 @@ case class Aggregate(
def references = child.references
}
-case class StopAfter(limit: Expression, child: LogicalPlan) extends UnaryNode {
+case class Limit(limit: Expression, child: LogicalPlan) extends UnaryNode {
def output = child.output
def references = limit.references
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index f950ea08ec57a..f4bf00f4cffa6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -26,8 +26,9 @@ import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.dsl
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.optimizer.Optimizer
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.plans.logical.{Subquery, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
+import org.apache.spark.sql.columnar.InMemoryColumnarTableScan
import org.apache.spark.sql.execution._
/**
@@ -111,11 +112,40 @@ class SQLContext(@transient val sparkContext: SparkContext)
result
}
+ /** Returns the specified table as a SchemaRDD */
+ def table(tableName: String): SchemaRDD =
+ new SchemaRDD(this, catalog.lookupRelation(None, tableName))
+
+ /** Caches the specified table in-memory. */
+ def cacheTable(tableName: String): Unit = {
+ val currentTable = catalog.lookupRelation(None, tableName)
+ val asInMemoryRelation =
+ InMemoryColumnarTableScan(currentTable.output, executePlan(currentTable).executedPlan)
+
+ catalog.registerTable(None, tableName, SparkLogicalPlan(asInMemoryRelation))
+ }
+
+ /** Removes the specified table from the in-memory cache. */
+ def uncacheTable(tableName: String): Unit = {
+ EliminateAnalysisOperators(catalog.lookupRelation(None, tableName)) match {
+ // This is kind of a hack to make sure that if this was just an RDD registered as a table,
+ // we reregister the RDD as a table.
+ case SparkLogicalPlan(inMem @ InMemoryColumnarTableScan(_, e: ExistingRdd)) =>
+ inMem.cachedColumnBuffers.unpersist()
+ catalog.unregisterTable(None, tableName)
+ catalog.registerTable(None, tableName, SparkLogicalPlan(e))
+ case SparkLogicalPlan(inMem: InMemoryColumnarTableScan) =>
+ inMem.cachedColumnBuffers.unpersist()
+ catalog.unregisterTable(None, tableName)
+ case plan => throw new IllegalArgumentException(s"Table $tableName is not cached: $plan")
+ }
+ }
+
protected[sql] class SparkPlanner extends SparkStrategies {
val sparkContext = self.sparkContext
val strategies: Seq[Strategy] =
- TopK ::
+ TakeOrdered ::
PartialAggregation ::
HashJoin ::
ParquetOperations ::
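
The new caching API is driven entirely through the catalog: cacheTable swaps the registered plan for an InMemoryColumnarTableScan, and uncacheTable unpersists the buffers and either restores the original RDD plan or drops the registration. A hypothetical usage sketch, assuming a SQLContext named sqlContext and an already-registered table "logs" (neither is defined in this patch):

// Hypothetical usage; `sqlContext` and the registered table "logs" are assumed to exist.
sqlContext.cacheTable("logs")           // re-register the plan as an InMemoryColumnarTableScan
val cached = sqlContext.table("logs")   // SchemaRDD now reads from the cached column buffers
println(cached.count())
sqlContext.uncacheTable("logs")         // unpersist the buffers; throws if "logs" was never cached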
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala
index e0c98ecdf8f22..ffd4894b5213d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala
@@ -21,7 +21,7 @@ import java.nio.{ByteOrder, ByteBuffer}
import org.apache.spark.sql.catalyst.types.{BinaryType, NativeType, DataType}
import org.apache.spark.sql.catalyst.expressions.MutableRow
-import org.apache.spark.sql.execution.SparkSqlSerializer
+import org.apache.spark.sql.columnar.compression.CompressibleColumnAccessor
/**
* An `Iterator` like trait used to extract values from columnar byte buffer. When a value is
@@ -41,121 +41,66 @@ private[sql] trait ColumnAccessor {
protected def underlyingBuffer: ByteBuffer
}
-private[sql] abstract class BasicColumnAccessor[T <: DataType, JvmType](buffer: ByteBuffer)
+private[sql] abstract class BasicColumnAccessor[T <: DataType, JvmType](
+ protected val buffer: ByteBuffer,
+ protected val columnType: ColumnType[T, JvmType])
extends ColumnAccessor {
protected def initialize() {}
- def columnType: ColumnType[T, JvmType]
-
def hasNext = buffer.hasRemaining
def extractTo(row: MutableRow, ordinal: Int) {
- doExtractTo(row, ordinal)
+ columnType.setField(row, ordinal, extractSingle(buffer))
}
- protected def doExtractTo(row: MutableRow, ordinal: Int)
+ def extractSingle(buffer: ByteBuffer): JvmType = columnType.extract(buffer)
protected def underlyingBuffer = buffer
}
private[sql] abstract class NativeColumnAccessor[T <: NativeType](
- buffer: ByteBuffer,
- val columnType: NativeColumnType[T])
- extends BasicColumnAccessor[T, T#JvmType](buffer)
+ override protected val buffer: ByteBuffer,
+ override protected val columnType: NativeColumnType[T])
+ extends BasicColumnAccessor(buffer, columnType)
with NullableColumnAccessor
+ with CompressibleColumnAccessor[T]
private[sql] class BooleanColumnAccessor(buffer: ByteBuffer)
- extends NativeColumnAccessor(buffer, BOOLEAN) {
-
- override protected def doExtractTo(row: MutableRow, ordinal: Int) {
- row.setBoolean(ordinal, columnType.extract(buffer))
- }
-}
+ extends NativeColumnAccessor(buffer, BOOLEAN)
private[sql] class IntColumnAccessor(buffer: ByteBuffer)
- extends NativeColumnAccessor(buffer, INT) {
-
- override protected def doExtractTo(row: MutableRow, ordinal: Int) {
- row.setInt(ordinal, columnType.extract(buffer))
- }
-}
+ extends NativeColumnAccessor(buffer, INT)
private[sql] class ShortColumnAccessor(buffer: ByteBuffer)
- extends NativeColumnAccessor(buffer, SHORT) {
-
- override protected def doExtractTo(row: MutableRow, ordinal: Int) {
- row.setShort(ordinal, columnType.extract(buffer))
- }
-}
+ extends NativeColumnAccessor(buffer, SHORT)
private[sql] class LongColumnAccessor(buffer: ByteBuffer)
- extends NativeColumnAccessor(buffer, LONG) {
-
- override protected def doExtractTo(row: MutableRow, ordinal: Int) {
- row.setLong(ordinal, columnType.extract(buffer))
- }
-}
+ extends NativeColumnAccessor(buffer, LONG)
private[sql] class ByteColumnAccessor(buffer: ByteBuffer)
- extends NativeColumnAccessor(buffer, BYTE) {
-
- override protected def doExtractTo(row: MutableRow, ordinal: Int) {
- row.setByte(ordinal, columnType.extract(buffer))
- }
-}
+ extends NativeColumnAccessor(buffer, BYTE)
private[sql] class DoubleColumnAccessor(buffer: ByteBuffer)
- extends NativeColumnAccessor(buffer, DOUBLE) {
-
- override protected def doExtractTo(row: MutableRow, ordinal: Int) {
- row.setDouble(ordinal, columnType.extract(buffer))
- }
-}
+ extends NativeColumnAccessor(buffer, DOUBLE)
private[sql] class FloatColumnAccessor(buffer: ByteBuffer)
- extends NativeColumnAccessor(buffer, FLOAT) {
-
- override protected def doExtractTo(row: MutableRow, ordinal: Int) {
- row.setFloat(ordinal, columnType.extract(buffer))
- }
-}
+ extends NativeColumnAccessor(buffer, FLOAT)
private[sql] class StringColumnAccessor(buffer: ByteBuffer)
- extends NativeColumnAccessor(buffer, STRING) {
-
- override protected def doExtractTo(row: MutableRow, ordinal: Int) {
- row.setString(ordinal, columnType.extract(buffer))
- }
-}
+ extends NativeColumnAccessor(buffer, STRING)
private[sql] class BinaryColumnAccessor(buffer: ByteBuffer)
- extends BasicColumnAccessor[BinaryType.type, Array[Byte]](buffer)
- with NullableColumnAccessor {
-
- def columnType = BINARY
-
- override protected def doExtractTo(row: MutableRow, ordinal: Int) {
- row(ordinal) = columnType.extract(buffer)
- }
-}
+ extends BasicColumnAccessor[BinaryType.type, Array[Byte]](buffer, BINARY)
+ with NullableColumnAccessor
private[sql] class GenericColumnAccessor(buffer: ByteBuffer)
- extends BasicColumnAccessor[DataType, Array[Byte]](buffer)
- with NullableColumnAccessor {
-
- def columnType = GENERIC
-
- override protected def doExtractTo(row: MutableRow, ordinal: Int) {
- val serialized = columnType.extract(buffer)
- row(ordinal) = SparkSqlSerializer.deserialize[Any](serialized)
- }
-}
+ extends BasicColumnAccessor[DataType, Array[Byte]](buffer, GENERIC)
+ with NullableColumnAccessor
private[sql] object ColumnAccessor {
- def apply(b: ByteBuffer): ColumnAccessor = {
- // The first 4 bytes in the buffer indicates the column type.
- val buffer = b.duplicate().order(ByteOrder.nativeOrder())
+ def apply(buffer: ByteBuffer): ColumnAccessor = {
+ // The first 4 bytes in the buffer indicate the column type.
val columnTypeId = buffer.getInt()
columnTypeId match {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala
index 3e622adfd3d6a..048ee66bff44b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala
@@ -22,7 +22,7 @@ import java.nio.{ByteBuffer, ByteOrder}
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.types._
import org.apache.spark.sql.columnar.ColumnBuilder._
-import org.apache.spark.sql.execution.SparkSqlSerializer
+import org.apache.spark.sql.columnar.compression.{AllCompressionSchemes, CompressibleColumnBuilder}
private[sql] trait ColumnBuilder {
/**
@@ -30,37 +30,44 @@ private[sql] trait ColumnBuilder {
*/
def initialize(initialSize: Int, columnName: String = "")
+ /**
+ * Appends `row(ordinal)` to the column builder.
+ */
def appendFrom(row: Row, ordinal: Int)
+ /**
+ * Column statistics information
+ */
+ def columnStats: ColumnStats[_, _]
+
+ /**
+ * Returns the final columnar byte buffer.
+ */
def build(): ByteBuffer
}
-private[sql] abstract class BasicColumnBuilder[T <: DataType, JvmType] extends ColumnBuilder {
+private[sql] class BasicColumnBuilder[T <: DataType, JvmType](
+ val columnStats: ColumnStats[T, JvmType],
+ val columnType: ColumnType[T, JvmType])
+ extends ColumnBuilder {
- private var columnName: String = _
- protected var buffer: ByteBuffer = _
+ protected var columnName: String = _
- def columnType: ColumnType[T, JvmType]
+ protected var buffer: ByteBuffer = _
override def initialize(initialSize: Int, columnName: String = "") = {
val size = if (initialSize == 0) DEFAULT_INITIAL_BUFFER_SIZE else initialSize
this.columnName = columnName
- buffer = ByteBuffer.allocate(4 + 4 + size * columnType.defaultSize)
+
+ // Reserves 4 bytes for column type ID
+ buffer = ByteBuffer.allocate(4 + size * columnType.defaultSize)
buffer.order(ByteOrder.nativeOrder()).putInt(columnType.typeId)
}
- // Have to give a concrete implementation to make mixin possible
override def appendFrom(row: Row, ordinal: Int) {
- doAppendFrom(row, ordinal)
- }
-
- // Concrete `ColumnBuilder`s can override this method to append values
- protected def doAppendFrom(row: Row, ordinal: Int)
-
- // Helper method to append primitive values (to avoid boxing cost)
- protected def appendValue(v: JvmType) {
- buffer = ensureFreeSpace(buffer, columnType.actualSize(v))
- columnType.append(v, buffer)
+ val field = columnType.getField(row, ordinal)
+ buffer = ensureFreeSpace(buffer, columnType.actualSize(field))
+ columnType.append(field, buffer)
}
override def build() = {
@@ -69,83 +76,39 @@ private[sql] abstract class BasicColumnBuilder[T <: DataType, JvmType] extends C
}
}
-private[sql] abstract class NativeColumnBuilder[T <: NativeType](
- val columnType: NativeColumnType[T])
- extends BasicColumnBuilder[T, T#JvmType]
+private[sql] abstract class ComplexColumnBuilder[T <: DataType, JvmType](
+ columnType: ColumnType[T, JvmType])
+ extends BasicColumnBuilder[T, JvmType](new NoopColumnStats[T, JvmType], columnType)
with NullableColumnBuilder
-private[sql] class BooleanColumnBuilder extends NativeColumnBuilder(BOOLEAN) {
- override def doAppendFrom(row: Row, ordinal: Int) {
- appendValue(row.getBoolean(ordinal))
- }
-}
-
-private[sql] class IntColumnBuilder extends NativeColumnBuilder(INT) {
- override def doAppendFrom(row: Row, ordinal: Int) {
- appendValue(row.getInt(ordinal))
- }
-}
+private[sql] abstract class NativeColumnBuilder[T <: NativeType](
+ override val columnStats: NativeColumnStats[T],
+ override val columnType: NativeColumnType[T])
+ extends BasicColumnBuilder[T, T#JvmType](columnStats, columnType)
+ with NullableColumnBuilder
+ with AllCompressionSchemes
+ with CompressibleColumnBuilder[T]
-private[sql] class ShortColumnBuilder extends NativeColumnBuilder(SHORT) {
- override def doAppendFrom(row: Row, ordinal: Int) {
- appendValue(row.getShort(ordinal))
- }
-}
+private[sql] class BooleanColumnBuilder extends NativeColumnBuilder(new BooleanColumnStats, BOOLEAN)
-private[sql] class LongColumnBuilder extends NativeColumnBuilder(LONG) {
- override def doAppendFrom(row: Row, ordinal: Int) {
- appendValue(row.getLong(ordinal))
- }
-}
+private[sql] class IntColumnBuilder extends NativeColumnBuilder(new IntColumnStats, INT)
-private[sql] class ByteColumnBuilder extends NativeColumnBuilder(BYTE) {
- override def doAppendFrom(row: Row, ordinal: Int) {
- appendValue(row.getByte(ordinal))
- }
-}
+private[sql] class ShortColumnBuilder extends NativeColumnBuilder(new ShortColumnStats, SHORT)
-private[sql] class DoubleColumnBuilder extends NativeColumnBuilder(DOUBLE) {
- override def doAppendFrom(row: Row, ordinal: Int) {
- appendValue(row.getDouble(ordinal))
- }
-}
+private[sql] class LongColumnBuilder extends NativeColumnBuilder(new LongColumnStats, LONG)
-private[sql] class FloatColumnBuilder extends NativeColumnBuilder(FLOAT) {
- override def doAppendFrom(row: Row, ordinal: Int) {
- appendValue(row.getFloat(ordinal))
- }
-}
+private[sql] class ByteColumnBuilder extends NativeColumnBuilder(new ByteColumnStats, BYTE)
-private[sql] class StringColumnBuilder extends NativeColumnBuilder(STRING) {
- override def doAppendFrom(row: Row, ordinal: Int) {
- appendValue(row.getString(ordinal))
- }
-}
+private[sql] class DoubleColumnBuilder extends NativeColumnBuilder(new DoubleColumnStats, DOUBLE)
-private[sql] class BinaryColumnBuilder
- extends BasicColumnBuilder[BinaryType.type, Array[Byte]]
- with NullableColumnBuilder {
+private[sql] class FloatColumnBuilder extends NativeColumnBuilder(new FloatColumnStats, FLOAT)
- def columnType = BINARY
+private[sql] class StringColumnBuilder extends NativeColumnBuilder(new StringColumnStats, STRING)
- override def doAppendFrom(row: Row, ordinal: Int) {
- appendValue(row(ordinal).asInstanceOf[Array[Byte]])
- }
-}
+private[sql] class BinaryColumnBuilder extends ComplexColumnBuilder(BINARY)
// TODO (lian) Add support for array, struct and map
-private[sql] class GenericColumnBuilder
- extends BasicColumnBuilder[DataType, Array[Byte]]
- with NullableColumnBuilder {
-
- def columnType = GENERIC
-
- override def doAppendFrom(row: Row, ordinal: Int) {
- val serialized = SparkSqlSerializer.serialize(row(ordinal))
- buffer = ColumnBuilder.ensureFreeSpace(buffer, columnType.actualSize(serialized))
- columnType.append(serialized, buffer)
- }
-}
+private[sql] class GenericColumnBuilder extends ComplexColumnBuilder(GENERIC)
private[sql] object ColumnBuilder {
val DEFAULT_INITIAL_BUFFER_SIZE = 10 * 1024 * 104
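
The rewritten appendFrom above is a getField → ensureFreeSpace → append loop over a ByteBuffer that starts with a 4-byte type ID. A self-contained sketch of that buffer-growing pattern for an Int column, using plain JDK ByteBuffer calls (ensureFreeSpace here is an illustrative re-implementation, not the one in ColumnBuilder):

import java.nio.{ByteBuffer, ByteOrder}

object IntColumnSketch {
  def ensureFreeSpace(buf: ByteBuffer, size: Int): ByteBuffer =
    if (buf.remaining() >= size) {
      buf
    } else {
      // Grow the buffer and copy over everything written so far.
      val grown = ByteBuffer.allocate(buf.capacity() * 2 + size).order(ByteOrder.nativeOrder())
      buf.flip()
      grown.put(buf)
    }

  def main(args: Array[String]): Unit = {
    val intTypeId = 0
    // Reserve 4 bytes for the column type ID, then room for two Int values.
    var buffer = ByteBuffer.allocate(4 + 4 * 2).order(ByteOrder.nativeOrder()).putInt(intTypeId)

    Seq(1, 2, 3, 4).foreach { v =>
      buffer = ensureFreeSpace(buffer, 4)
      buffer.putInt(v)
    }

    buffer.flip()
    println(s"typeId=${buffer.getInt()}")                        // typeId=0
    while (buffer.hasRemaining) print(s"${buffer.getInt()} ")    // 1 2 3 4
  }
}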
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala
new file mode 100644
index 0000000000000..30c6bdc7912fc
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala
@@ -0,0 +1,360 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.columnar
+
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.types._
+
+private[sql] sealed abstract class ColumnStats[T <: DataType, JvmType] extends Serializable {
+ /**
+ * Closed lower bound of this column.
+ */
+ def lowerBound: JvmType
+
+ /**
+ * Closed upper bound of this column.
+ */
+ def upperBound: JvmType
+
+ /**
+ * Gathers statistics information from `row(ordinal)`.
+ */
+ def gatherStats(row: Row, ordinal: Int)
+
+ /**
+ * Returns `true` if `lower <= row(ordinal) <= upper`.
+ */
+ def contains(row: Row, ordinal: Int): Boolean
+
+ /**
+ * Returns `true` if `row(ordinal) < upper` holds.
+ */
+ def isAbove(row: Row, ordinal: Int): Boolean
+
+ /**
+ * Returns `true` if `lower < row(ordinal)` holds.
+ */
+ def isBelow(row: Row, ordinal: Int): Boolean
+
+ /**
+ * Returns `true` if `row(ordinal) <= upper` holds.
+ */
+ def isAtOrAbove(row: Row, ordinal: Int): Boolean
+
+ /**
+ * Returns `true` if `lower <= row(ordinal)` holds.
+ */
+ def isAtOrBelow(row: Row, ordinal: Int): Boolean
+}
+
+private[sql] sealed abstract class NativeColumnStats[T <: NativeType]
+ extends ColumnStats[T, T#JvmType] {
+
+ type JvmType = T#JvmType
+
+ protected var (_lower, _upper) = initialBounds
+
+ def initialBounds: (JvmType, JvmType)
+
+ protected def columnType: NativeColumnType[T]
+
+ override def lowerBound: T#JvmType = _lower
+
+ override def upperBound: T#JvmType = _upper
+
+ override def isAtOrAbove(row: Row, ordinal: Int) = {
+ contains(row, ordinal) || isAbove(row, ordinal)
+ }
+
+ override def isAtOrBelow(row: Row, ordinal: Int) = {
+ contains(row, ordinal) || isBelow(row, ordinal)
+ }
+}
+
+private[sql] class NoopColumnStats[T <: DataType, JvmType] extends ColumnStats[T, JvmType] {
+ override def isAtOrBelow(row: Row, ordinal: Int) = true
+
+ override def isAtOrAbove(row: Row, ordinal: Int) = true
+
+ override def isBelow(row: Row, ordinal: Int) = true
+
+ override def isAbove(row: Row, ordinal: Int) = true
+
+ override def contains(row: Row, ordinal: Int) = true
+
+ override def gatherStats(row: Row, ordinal: Int) {}
+
+ override def upperBound = null.asInstanceOf[JvmType]
+
+ override def lowerBound = null.asInstanceOf[JvmType]
+}
+
+private[sql] abstract class BasicColumnStats[T <: NativeType](
+ protected val columnType: NativeColumnType[T])
+ extends NativeColumnStats[T]
+
+private[sql] class BooleanColumnStats extends BasicColumnStats(BOOLEAN) {
+ override def initialBounds = (true, false)
+
+ override def isBelow(row: Row, ordinal: Int) = {
+ lowerBound < columnType.getField(row, ordinal)
+ }
+
+ override def isAbove(row: Row, ordinal: Int) = {
+ columnType.getField(row, ordinal) < upperBound
+ }
+
+ override def contains(row: Row, ordinal: Int) = {
+ val field = columnType.getField(row, ordinal)
+ lowerBound <= field && field <= upperBound
+ }
+
+ override def gatherStats(row: Row, ordinal: Int) {
+ val field = columnType.getField(row, ordinal)
+ if (field > upperBound) _upper = field
+ if (field < lowerBound) _lower = field
+ }
+}
+
+private[sql] class ByteColumnStats extends BasicColumnStats(BYTE) {
+ override def initialBounds = (Byte.MaxValue, Byte.MinValue)
+
+ override def isBelow(row: Row, ordinal: Int) = {
+ lowerBound < columnType.getField(row, ordinal)
+ }
+
+ override def isAbove(row: Row, ordinal: Int) = {
+ columnType.getField(row, ordinal) < upperBound
+ }
+
+ override def contains(row: Row, ordinal: Int) = {
+ val field = columnType.getField(row, ordinal)
+ lowerBound <= field && field <= upperBound
+ }
+
+ override def gatherStats(row: Row, ordinal: Int) {
+ val field = columnType.getField(row, ordinal)
+ if (field > upperBound) _upper = field
+ if (field < lowerBound) _lower = field
+ }
+}
+
+private[sql] class ShortColumnStats extends BasicColumnStats(SHORT) {
+ override def initialBounds = (Short.MaxValue, Short.MinValue)
+
+ override def isBelow(row: Row, ordinal: Int) = {
+ lowerBound < columnType.getField(row, ordinal)
+ }
+
+ override def isAbove(row: Row, ordinal: Int) = {
+ columnType.getField(row, ordinal) < upperBound
+ }
+
+ override def contains(row: Row, ordinal: Int) = {
+ val field = columnType.getField(row, ordinal)
+ lowerBound <= field && field <= upperBound
+ }
+
+ override def gatherStats(row: Row, ordinal: Int) {
+ val field = columnType.getField(row, ordinal)
+ if (field > upperBound) _upper = field
+ if (field < lowerBound) _lower = field
+ }
+}
+
+private[sql] class LongColumnStats extends BasicColumnStats(LONG) {
+ override def initialBounds = (Long.MaxValue, Long.MinValue)
+
+ override def isBelow(row: Row, ordinal: Int) = {
+ lowerBound < columnType.getField(row, ordinal)
+ }
+
+ override def isAbove(row: Row, ordinal: Int) = {
+ columnType.getField(row, ordinal) < upperBound
+ }
+
+ override def contains(row: Row, ordinal: Int) = {
+ val field = columnType.getField(row, ordinal)
+ lowerBound <= field && field <= upperBound
+ }
+
+ override def gatherStats(row: Row, ordinal: Int) {
+ val field = columnType.getField(row, ordinal)
+ if (field > upperBound) _upper = field
+ if (field < lowerBound) _lower = field
+ }
+}
+
+private[sql] class DoubleColumnStats extends BasicColumnStats(DOUBLE) {
+ override def initialBounds = (Double.MaxValue, Double.MinValue)
+
+ override def isBelow(row: Row, ordinal: Int) = {
+ lowerBound < columnType.getField(row, ordinal)
+ }
+
+ override def isAbove(row: Row, ordinal: Int) = {
+ columnType.getField(row, ordinal) < upperBound
+ }
+
+ override def contains(row: Row, ordinal: Int) = {
+ val field = columnType.getField(row, ordinal)
+ lowerBound <= field && field <= upperBound
+ }
+
+ override def gatherStats(row: Row, ordinal: Int) {
+ val field = columnType.getField(row, ordinal)
+ if (field > upperBound) _upper = field
+ if (field < lowerBound) _lower = field
+ }
+}
+
+private[sql] class FloatColumnStats extends BasicColumnStats(FLOAT) {
+ override def initialBounds = (Float.MaxValue, Float.MinValue)
+
+ override def isBelow(row: Row, ordinal: Int) = {
+ lowerBound < columnType.getField(row, ordinal)
+ }
+
+ override def isAbove(row: Row, ordinal: Int) = {
+ columnType.getField(row, ordinal) < upperBound
+ }
+
+ override def contains(row: Row, ordinal: Int) = {
+ val field = columnType.getField(row, ordinal)
+ lowerBound <= field && field <= upperBound
+ }
+
+ override def gatherStats(row: Row, ordinal: Int) {
+ val field = columnType.getField(row, ordinal)
+ if (field > upperBound) _upper = field
+ if (field < lowerBound) _lower = field
+ }
+}
+
+private[sql] object IntColumnStats {
+ val UNINITIALIZED = 0
+ val INITIALIZED = 1
+ val ASCENDING = 2
+ val DESCENDING = 3
+ val UNORDERED = 4
+}
+
+/**
+ * Statistical information for `Int` columns. More information is collected since `Int` is
+ * frequently used. Extra information includes:

+ *
+ * - Ordering state (ascending/descending/unordered), may be used to decide whether binary search
+ * is applicable when searching elements.
+ * - Maximum delta between adjacent elements, may be used to guide the `IntDelta` compression
+ * scheme.
+ *
+ * (These two kinds of information are not used anywhere yet and might be removed later.)
+ */
+private[sql] class IntColumnStats extends BasicColumnStats(INT) {
+ import IntColumnStats._
+
+ private var orderedState = UNINITIALIZED
+ private var lastValue: Int = _
+ private var _maxDelta: Int = _
+
+ def isAscending = orderedState != DESCENDING && orderedState != UNORDERED
+ def isDescending = orderedState != ASCENDING && orderedState != UNORDERED
+ def isOrdered = isAscending || isDescending
+ def maxDelta = _maxDelta
+
+ override def initialBounds = (Int.MaxValue, Int.MinValue)
+
+ override def isBelow(row: Row, ordinal: Int) = {
+ lowerBound < columnType.getField(row, ordinal)
+ }
+
+ override def isAbove(row: Row, ordinal: Int) = {
+ columnType.getField(row, ordinal) < upperBound
+ }
+
+ override def contains(row: Row, ordinal: Int) = {
+ val field = columnType.getField(row, ordinal)
+ lowerBound <= field && field <= upperBound
+ }
+
+ override def gatherStats(row: Row, ordinal: Int) {
+ val field = columnType.getField(row, ordinal)
+
+ if (field > upperBound) _upper = field
+ if (field < lowerBound) _lower = field
+
+ orderedState = orderedState match {
+ case UNINITIALIZED =>
+ lastValue = field
+ INITIALIZED
+
+ case INITIALIZED =>
+ // If all the integers in the column are the same, ordered state is set to Ascending.
+ // TODO (lian) Confirm whether this is the standard behaviour.
+ val nextState = if (field >= lastValue) ASCENDING else DESCENDING
+ _maxDelta = math.abs(field - lastValue)
+ lastValue = field
+ nextState
+
+ case ASCENDING if field < lastValue =>
+ UNORDERED
+
+ case DESCENDING if field > lastValue =>
+ UNORDERED
+
+ case state @ (ASCENDING | DESCENDING) =>
+ _maxDelta = _maxDelta.max(field - lastValue)
+ lastValue = field
+ state
+
+ case _ =>
+ orderedState
+ }
+ }
+}
+
+private[sql] class StringColumnStats extends BasicColumnStats(STRING) {
+ override def initialBounds = (null, null)
+
+ override def gatherStats(row: Row, ordinal: Int) {
+ val field = columnType.getField(row, ordinal)
+ if ((upperBound eq null) || field.compareTo(upperBound) > 0) _upper = field
+ if ((lowerBound eq null) || field.compareTo(lowerBound) < 0) _lower = field
+ }
+
+ override def contains(row: Row, ordinal: Int) = {
+ !(upperBound eq null) && {
+ val field = columnType.getField(row, ordinal)
+ lowerBound.compareTo(field) <= 0 && field.compareTo(upperBound) <= 0
+ }
+ }
+
+ override def isAbove(row: Row, ordinal: Int) = {
+ !(upperBound eq null) && {
+ val field = columnType.getField(row, ordinal)
+ field.compareTo(upperBound) < 0
+ }
+ }
+
+ override def isBelow(row: Row, ordinal: Int) = {
+ !(lowerBound eq null) && {
+ val field = columnType.getField(row, ordinal)
+ lowerBound.compareTo(field) < 0
+ }
+ }
+}
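
All of the numeric ColumnStats above follow the same closed-bound pattern; only the initial bounds differ by type. A Row-free sketch of that pattern, specialized to Int purely for illustration:

// Illustrative bound tracker mirroring BasicColumnStats; not Spark code.
class IntBounds {
  private var lower = Int.MaxValue   // closed lower bound, as in initialBounds
  private var upper = Int.MinValue   // closed upper bound

  def gather(v: Int): Unit = {
    if (v > upper) upper = v
    if (v < lower) lower = v
  }

  def contains(v: Int): Boolean = lower <= v && v <= upper
  def isAbove(v: Int): Boolean = v < upper   // some column value lies above v
  def isBelow(v: Int): Boolean = lower < v   // some column value lies below v
}

object IntBoundsDemo extends App {
  val stats = new IntBounds
  Seq(7, 3, 9, 5).foreach(stats.gather)
  assert(stats.contains(5) && !stats.contains(10))
  assert(stats.isAbove(3) && !stats.isAbove(9))
  assert(stats.isBelow(9) && !stats.isBelow(3))
}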
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
index a452b86f0cda3..5be76890afe31 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
@@ -19,7 +19,12 @@ package org.apache.spark.sql.columnar
import java.nio.ByteBuffer
+import scala.reflect.runtime.universe.TypeTag
+
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.expressions.MutableRow
import org.apache.spark.sql.catalyst.types._
+import org.apache.spark.sql.execution.SparkSqlSerializer
/**
* An abstract class that represents type of a column. Used to append/extract Java objects into/from
@@ -50,10 +55,24 @@ private[sql] sealed abstract class ColumnType[T <: DataType, JvmType](
*/
def actualSize(v: JvmType): Int = defaultSize
+ /**
+ * Returns `row(ordinal)`. Subclasses should override this method to avoid boxing/unboxing costs
+ * whenever possible.
+ */
+ def getField(row: Row, ordinal: Int): JvmType
+
+ /**
+ * Sets `row(ordinal)` to `value`. Subclasses should override this method to avoid boxing/unboxing
+ * costs whenever possible.
+ */
+ def setField(row: MutableRow, ordinal: Int, value: JvmType)
+
/**
* Creates a duplicated copy of the value.
*/
def clone(v: JvmType): JvmType = v
+
+ override def toString = getClass.getSimpleName.stripSuffix("$")
}
private[sql] abstract class NativeColumnType[T <: NativeType](
@@ -65,7 +84,7 @@ private[sql] abstract class NativeColumnType[T <: NativeType](
/**
* Scala TypeTag. Can be used to create primitive arrays and hash tables.
*/
- def scalaTag = dataType.tag
+ def scalaTag: TypeTag[dataType.JvmType] = dataType.tag
}
private[sql] object INT extends NativeColumnType(IntegerType, 0, 4) {
@@ -76,6 +95,12 @@ private[sql] object INT extends NativeColumnType(IntegerType, 0, 4) {
def extract(buffer: ByteBuffer) = {
buffer.getInt()
}
+
+ override def setField(row: MutableRow, ordinal: Int, value: Int) {
+ row.setInt(ordinal, value)
+ }
+
+ override def getField(row: Row, ordinal: Int) = row.getInt(ordinal)
}
private[sql] object LONG extends NativeColumnType(LongType, 1, 8) {
@@ -86,6 +111,12 @@ private[sql] object LONG extends NativeColumnType(LongType, 1, 8) {
override def extract(buffer: ByteBuffer) = {
buffer.getLong()
}
+
+ override def setField(row: MutableRow, ordinal: Int, value: Long) {
+ row.setLong(ordinal, value)
+ }
+
+ override def getField(row: Row, ordinal: Int) = row.getLong(ordinal)
}
private[sql] object FLOAT extends NativeColumnType(FloatType, 2, 4) {
@@ -96,6 +127,12 @@ private[sql] object FLOAT extends NativeColumnType(FloatType, 2, 4) {
override def extract(buffer: ByteBuffer) = {
buffer.getFloat()
}
+
+ override def setField(row: MutableRow, ordinal: Int, value: Float) {
+ row.setFloat(ordinal, value)
+ }
+
+ override def getField(row: Row, ordinal: Int) = row.getFloat(ordinal)
}
private[sql] object DOUBLE extends NativeColumnType(DoubleType, 3, 8) {
@@ -106,6 +143,12 @@ private[sql] object DOUBLE extends NativeColumnType(DoubleType, 3, 8) {
override def extract(buffer: ByteBuffer) = {
buffer.getDouble()
}
+
+ override def setField(row: MutableRow, ordinal: Int, value: Double) {
+ row.setDouble(ordinal, value)
+ }
+
+ override def getField(row: Row, ordinal: Int) = row.getDouble(ordinal)
}
private[sql] object BOOLEAN extends NativeColumnType(BooleanType, 4, 1) {
@@ -116,6 +159,12 @@ private[sql] object BOOLEAN extends NativeColumnType(BooleanType, 4, 1) {
override def extract(buffer: ByteBuffer) = {
if (buffer.get() == 1) true else false
}
+
+ override def setField(row: MutableRow, ordinal: Int, value: Boolean) {
+ row.setBoolean(ordinal, value)
+ }
+
+ override def getField(row: Row, ordinal: Int) = row.getBoolean(ordinal)
}
private[sql] object BYTE extends NativeColumnType(ByteType, 5, 1) {
@@ -126,6 +175,12 @@ private[sql] object BYTE extends NativeColumnType(ByteType, 5, 1) {
override def extract(buffer: ByteBuffer) = {
buffer.get()
}
+
+ override def setField(row: MutableRow, ordinal: Int, value: Byte) {
+ row.setByte(ordinal, value)
+ }
+
+ override def getField(row: Row, ordinal: Int) = row.getByte(ordinal)
}
private[sql] object SHORT extends NativeColumnType(ShortType, 6, 2) {
@@ -136,6 +191,12 @@ private[sql] object SHORT extends NativeColumnType(ShortType, 6, 2) {
override def extract(buffer: ByteBuffer) = {
buffer.getShort()
}
+
+ override def setField(row: MutableRow, ordinal: Int, value: Short) {
+ row.setShort(ordinal, value)
+ }
+
+ override def getField(row: Row, ordinal: Int) = row.getShort(ordinal)
}
private[sql] object STRING extends NativeColumnType(StringType, 7, 8) {
@@ -152,6 +213,12 @@ private[sql] object STRING extends NativeColumnType(StringType, 7, 8) {
buffer.get(stringBytes, 0, length)
new String(stringBytes)
}
+
+ override def setField(row: MutableRow, ordinal: Int, value: String) {
+ row.setString(ordinal, value)
+ }
+
+ override def getField(row: Row, ordinal: Int) = row.getString(ordinal)
}
private[sql] sealed abstract class ByteArrayColumnType[T <: DataType](
@@ -173,15 +240,27 @@ private[sql] sealed abstract class ByteArrayColumnType[T <: DataType](
}
}
-private[sql] object BINARY extends ByteArrayColumnType[BinaryType.type](8, 16)
+private[sql] object BINARY extends ByteArrayColumnType[BinaryType.type](8, 16) {
+ override def setField(row: MutableRow, ordinal: Int, value: Array[Byte]) {
+ row(ordinal) = value
+ }
+
+ override def getField(row: Row, ordinal: Int) = row(ordinal).asInstanceOf[Array[Byte]]
+}
// Used to process generic objects (all types other than those listed above). Objects should be
// serialized before being appended to the column `ByteBuffer`, and are also extracted as serialized
// byte arrays.
-private[sql] object GENERIC extends ByteArrayColumnType[DataType](9, 16)
+private[sql] object GENERIC extends ByteArrayColumnType[DataType](9, 16) {
+ override def setField(row: MutableRow, ordinal: Int, value: Array[Byte]) {
+ row(ordinal) = SparkSqlSerializer.deserialize[Any](value)
+ }
+
+ override def getField(row: Row, ordinal: Int) = SparkSqlSerializer.serialize(row(ordinal))
+}
private[sql] object ColumnType {
- implicit def dataTypeToColumnType(dataType: DataType): ColumnType[_, _] = {
+ def apply(dataType: DataType): ColumnType[_, _] = {
dataType match {
case IntegerType => INT
case LongType => LONG
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/inMemoryColumnarOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
similarity index 93%
rename from sql/core/src/main/scala/org/apache/spark/sql/columnar/inMemoryColumnarOperators.scala
rename to sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
index f853759e5a306..8a24733047423 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/inMemoryColumnarOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
@@ -21,9 +21,6 @@ import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, Attribute}
import org.apache.spark.sql.execution.{SparkPlan, LeafNode}
import org.apache.spark.sql.Row
-/* Implicit conversions */
-import org.apache.spark.sql.columnar.ColumnType._
-
private[sql] case class InMemoryColumnarTableScan(attributes: Seq[Attribute], child: SparkPlan)
extends LeafNode {
@@ -32,8 +29,8 @@ private[sql] case class InMemoryColumnarTableScan(attributes: Seq[Attribute], ch
lazy val cachedColumnBuffers = {
val output = child.output
val cached = child.execute().mapPartitions { iterator =>
- val columnBuilders = output.map { a =>
- ColumnBuilder(a.dataType.typeId, 0, a.name)
+ val columnBuilders = output.map { attribute =>
+ ColumnBuilder(ColumnType(attribute.dataType).typeId, 0, attribute.name)
}.toArray
var row: Row = null
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnAccessor.scala
index 2970c609b928d..7d49ab07f7a53 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnAccessor.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnAccessor.scala
@@ -29,7 +29,7 @@ private[sql] trait NullableColumnAccessor extends ColumnAccessor {
private var nextNullIndex: Int = _
private var pos: Int = 0
- abstract override def initialize() {
+ abstract override protected def initialize() {
nullsBuffer = underlyingBuffer.duplicate().order(ByteOrder.nativeOrder())
nullCount = nullsBuffer.getInt()
nextNullIndex = if (nullCount > 0) nullsBuffer.getInt() else -1
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala
index 048d1f05c7df2..2a3b6fc1e46d3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala
@@ -22,10 +22,18 @@ import java.nio.{ByteBuffer, ByteOrder}
import org.apache.spark.sql.Row
/**
- * Builds a nullable column. The byte buffer of a nullable column contains:
- * - 4 bytes for the null count (number of nulls)
- * - positions for each null, in ascending order
- * - the non-null data (column data type, compression type, data...)
+ * A stackable trait used for building byte buffer for a column containing null values. Memory
+ * layout of the final byte buffer is:
+ * {{{
+ * .----------------------- Column type ID (4 bytes)
+ * | .------------------- Null count N (4 bytes)
+ * | | .--------------- Null positions (4 x N bytes, empty if null count is zero)
+ * | | | .--------- Non-null elements
+ * V V V V
+ * +---+---+-----+---------+
+ * | | | ... | ... ... |
+ * +---+---+-----+---------+
+ * }}}
*/
private[sql] trait NullableColumnBuilder extends ColumnBuilder {
private var nulls: ByteBuffer = _
@@ -59,19 +67,8 @@ private[sql] trait NullableColumnBuilder extends ColumnBuilder {
nulls.limit(nullDataLen)
nulls.rewind()
- // Column type ID is moved to the front, follows the null count, then non-null data
- //
- // +---------+
- // | 4 bytes | Column type ID
- // +---------+
- // | 4 bytes | Null count
- // +---------+
- // | ... | Null positions (if null count is not zero)
- // +---------+
- // | ... | Non-null part (without column type ID)
- // +---------+
val buffer = ByteBuffer
- .allocate(4 + nullDataLen + nonNulls.limit)
+ .allocate(4 + 4 + nullDataLen + nonNulls.remaining())
.order(ByteOrder.nativeOrder())
.putInt(typeId)
.putInt(nullCount)
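
The buffer produced by build() therefore carries the null metadata in front of the data. A standalone sketch that assembles the documented layout by hand for an INT column with two nulls (plain ByteBuffer code, not the builder itself):

import java.nio.{ByteBuffer, ByteOrder}

// Hand-rolled illustration of [type ID][null count][null positions][non-null data].
object NullableLayoutDemo extends App {
  val values        = Seq(Some(1), None, Some(3), None, Some(5))
  val nullPositions = values.zipWithIndex.collect { case (None, i) => i }
  val nonNulls      = values.flatten

  val buffer = ByteBuffer
    .allocate(4 + 4 + 4 * nullPositions.size + 4 * nonNulls.size)
    .order(ByteOrder.nativeOrder())
    .putInt(0)                    // column type ID (0 = INT)
    .putInt(nullPositions.size)   // null count
  nullPositions.foreach(p => buffer.putInt(p))
  nonNulls.foreach(v => buffer.putInt(v))
  buffer.rewind()

  println(s"typeId=${buffer.getInt()} nullCount=${buffer.getInt()}")  // typeId=0 nullCount=2
}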
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnAccessor.scala
new file mode 100644
index 0000000000000..878cb84de106f
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnAccessor.scala
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.columnar.compression
+
+import java.nio.ByteBuffer
+
+import org.apache.spark.sql.catalyst.types.NativeType
+import org.apache.spark.sql.columnar.{ColumnAccessor, NativeColumnAccessor}
+
+private[sql] trait CompressibleColumnAccessor[T <: NativeType] extends ColumnAccessor {
+ this: NativeColumnAccessor[T] =>
+
+ private var decoder: Decoder[T] = _
+
+ abstract override protected def initialize() = {
+ super.initialize()
+ decoder = CompressionScheme(underlyingBuffer.getInt()).decoder(buffer, columnType)
+ }
+
+ abstract override def extractSingle(buffer: ByteBuffer): T#JvmType = decoder.next()
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala
new file mode 100644
index 0000000000000..3ac4b358ddf83
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.columnar.compression
+
+import java.nio.{ByteBuffer, ByteOrder}
+
+import org.apache.spark.sql.{Logging, Row}
+import org.apache.spark.sql.catalyst.types.NativeType
+import org.apache.spark.sql.columnar.{ColumnBuilder, NativeColumnBuilder}
+
+/**
+ * A stackable trait that builds optionally compressed byte buffer for a column. Memory layout of
+ * the final byte buffer is:
+ * {{{
+ * .--------------------------- Column type ID (4 bytes)
+ * | .----------------------- Null count N (4 bytes)
+ * | | .------------------- Null positions (4 x N bytes, empty if null count is zero)
+ * | | | .------------- Compression scheme ID (4 bytes)
+ * | | | | .--------- Compressed non-null elements
+ * V V V V V
+ * +---+---+-----+---+---------+
+ * | | | ... | | ... ... |
+ * +---+---+-----+---+---------+
+ * \-----------/ \-----------/
+ * header body
+ * }}}
+ */
+private[sql] trait CompressibleColumnBuilder[T <: NativeType]
+ extends ColumnBuilder with Logging {
+
+ this: NativeColumnBuilder[T] with WithCompressionSchemes =>
+
+ import CompressionScheme._
+
+ val compressionEncoders = schemes.filter(_.supports(columnType)).map(_.encoder)
+
+ protected def isWorthCompressing(encoder: Encoder) = {
+ encoder.compressionRatio < 0.8
+ }
+
+ private def gatherCompressibilityStats(row: Row, ordinal: Int) {
+ val field = columnType.getField(row, ordinal)
+
+ var i = 0
+ while (i < compressionEncoders.length) {
+ compressionEncoders(i).gatherCompressibilityStats(field, columnType)
+ i += 1
+ }
+ }
+
+ abstract override def appendFrom(row: Row, ordinal: Int) {
+ super.appendFrom(row, ordinal)
+ gatherCompressibilityStats(row, ordinal)
+ }
+
+ abstract override def build() = {
+ val rawBuffer = super.build()
+ val encoder = {
+ val candidate = compressionEncoders.minBy(_.compressionRatio)
+ if (isWorthCompressing(candidate)) candidate else PassThrough.encoder
+ }
+
+ val headerSize = columnHeaderSize(rawBuffer)
+ val compressedSize = if (encoder.compressedSize == 0) {
+ rawBuffer.limit - headerSize
+ } else {
+ encoder.compressedSize
+ }
+
+ // Reserves 4 bytes for compression scheme ID
+ val compressedBuffer = ByteBuffer
+ .allocate(headerSize + 4 + compressedSize)
+ .order(ByteOrder.nativeOrder)
+
+ copyColumnHeader(rawBuffer, compressedBuffer)
+
+ logger.info(s"Compressor for [$columnName]: $encoder, ratio: ${encoder.compressionRatio}")
+ encoder.compress(rawBuffer, compressedBuffer, columnType)
+ }
+}
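
The interesting part of build() above is the selection rule: among the schemes that support the column type, pick the one with the smallest estimated output, but fall back to PassThrough unless the ratio beats the 0.8 threshold in isWorthCompressing. A standalone sketch of just that rule, with Candidate standing in for compression.Encoder:

// Candidate is an illustrative stand-in for an encoder with size estimates.
final case class Candidate(name: String, compressedSize: Int, uncompressedSize: Int) {
  def compressionRatio: Double =
    if (uncompressedSize > 0) compressedSize.toDouble / uncompressedSize else 1.0
}

object EncoderSelectionDemo extends App {
  val passThrough = Candidate("PassThrough", 100, 100)
  val candidates  = Seq(Candidate("RunLength", 90, 100), Candidate("Dictionary", 60, 100))

  val best   = candidates.minBy(_.compressionRatio)
  val chosen = if (best.compressionRatio < 0.8) best else passThrough

  println(s"chosen=${chosen.name}, ratio=${chosen.compressionRatio}")  // chosen=Dictionary, ratio=0.6
}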
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala
new file mode 100644
index 0000000000000..d3a4ac8df926b
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala
@@ -0,0 +1,94 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.columnar.compression
+
+import java.nio.ByteBuffer
+
+import org.apache.spark.sql.catalyst.types.NativeType
+import org.apache.spark.sql.columnar.{ColumnType, NativeColumnType}
+
+private[sql] trait Encoder {
+ def gatherCompressibilityStats[T <: NativeType](
+ value: T#JvmType,
+ columnType: ColumnType[T, T#JvmType]) {}
+
+ def compressedSize: Int
+
+ def uncompressedSize: Int
+
+ def compressionRatio: Double = {
+ if (uncompressedSize > 0) compressedSize.toDouble / uncompressedSize else 1.0
+ }
+
+ def compress[T <: NativeType](
+ from: ByteBuffer,
+ to: ByteBuffer,
+ columnType: ColumnType[T, T#JvmType]): ByteBuffer
+}
+
+private[sql] trait Decoder[T <: NativeType] extends Iterator[T#JvmType]
+
+private[sql] trait CompressionScheme {
+ def typeId: Int
+
+ def supports(columnType: ColumnType[_, _]): Boolean
+
+ def encoder: Encoder
+
+ def decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]): Decoder[T]
+}
+
+private[sql] trait WithCompressionSchemes {
+ def schemes: Seq[CompressionScheme]
+}
+
+private[sql] trait AllCompressionSchemes extends WithCompressionSchemes {
+ override val schemes: Seq[CompressionScheme] = {
+ Seq(PassThrough, RunLengthEncoding, DictionaryEncoding)
+ }
+}
+
+private[sql] object CompressionScheme {
+ def apply(typeId: Int): CompressionScheme = typeId match {
+ case PassThrough.typeId => PassThrough
+ case _ => throw new UnsupportedOperationException()
+ }
+
+ def copyColumnHeader(from: ByteBuffer, to: ByteBuffer) {
+ // Writes column type ID
+ to.putInt(from.getInt())
+
+ // Writes null count
+ val nullCount = from.getInt()
+ to.putInt(nullCount)
+
+ // Writes null positions
+ var i = 0
+ while (i < nullCount) {
+ to.putInt(from.getInt())
+ i += 1
+ }
+ }
+
+ def columnHeaderSize(columnBuffer: ByteBuffer): Int = {
+ val header = columnBuffer.duplicate()
+ val nullCount = header.getInt(4)
+ // Column type ID + null count + null positions
+ 4 + 4 + 4 * nullCount
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala
new file mode 100644
index 0000000000000..dc2c153faf8ad
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala
@@ -0,0 +1,288 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.columnar.compression
+
+import java.nio.ByteBuffer
+
+import scala.collection.mutable
+import scala.reflect.ClassTag
+import scala.reflect.runtime.universe.runtimeMirror
+
+import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
+import org.apache.spark.sql.catalyst.types.NativeType
+import org.apache.spark.sql.columnar._
+
+private[sql] case object PassThrough extends CompressionScheme {
+ override val typeId = 0
+
+ override def supports(columnType: ColumnType[_, _]) = true
+
+ override def encoder = new this.Encoder
+
+ override def decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]) = {
+ new this.Decoder(buffer, columnType)
+ }
+
+ class Encoder extends compression.Encoder {
+ override def uncompressedSize = 0
+
+ override def compressedSize = 0
+
+ override def compress[T <: NativeType](
+ from: ByteBuffer,
+ to: ByteBuffer,
+ columnType: ColumnType[T, T#JvmType]) = {
+
+ // Writes compression type ID and copies raw contents
+ to.putInt(PassThrough.typeId).put(from).rewind()
+ to
+ }
+ }
+
+ class Decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T])
+ extends compression.Decoder[T] {
+
+ override def next() = columnType.extract(buffer)
+
+ override def hasNext = buffer.hasRemaining
+ }
+}
+
+private[sql] case object RunLengthEncoding extends CompressionScheme {
+ override def typeId = 1
+
+ override def encoder = new this.Encoder
+
+ override def decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]) = {
+ new this.Decoder(buffer, columnType)
+ }
+
+ override def supports(columnType: ColumnType[_, _]) = columnType match {
+ case INT | LONG | SHORT | BYTE | STRING | BOOLEAN => true
+ case _ => false
+ }
+
+ class Encoder extends compression.Encoder {
+ private var _uncompressedSize = 0
+ private var _compressedSize = 0
+
+ // Using `MutableRow` to store the last value to avoid boxing/unboxing cost.
+ private val lastValue = new GenericMutableRow(1)
+ private var lastRun = 0
+
+ override def uncompressedSize = _uncompressedSize
+
+ override def compressedSize = _compressedSize
+
+ override def gatherCompressibilityStats[T <: NativeType](
+ value: T#JvmType,
+ columnType: ColumnType[T, T#JvmType]) {
+
+ val actualSize = columnType.actualSize(value)
+ _uncompressedSize += actualSize
+
+ if (lastValue.isNullAt(0)) {
+ columnType.setField(lastValue, 0, value)
+ lastRun = 1
+ _compressedSize += actualSize + 4
+ } else {
+ if (columnType.getField(lastValue, 0) == value) {
+ lastRun += 1
+ } else {
+ _compressedSize += actualSize + 4
+ columnType.setField(lastValue, 0, value)
+ lastRun = 1
+ }
+ }
+ }
+
+ override def compress[T <: NativeType](
+ from: ByteBuffer,
+ to: ByteBuffer,
+ columnType: ColumnType[T, T#JvmType]) = {
+
+ to.putInt(RunLengthEncoding.typeId)
+
+ if (from.hasRemaining) {
+ var currentValue = columnType.extract(from)
+ var currentRun = 1
+
+ while (from.hasRemaining) {
+ val value = columnType.extract(from)
+
+ if (value == currentValue) {
+ currentRun += 1
+ } else {
+ // Writes current run
+ columnType.append(currentValue, to)
+ to.putInt(currentRun)
+
+ // Resets current run
+ currentValue = value
+ currentRun = 1
+ }
+ }
+
+ // Writes the last run
+ columnType.append(currentValue, to)
+ to.putInt(currentRun)
+ }
+
+ to.rewind()
+ to
+ }
+ }
+
+ class Decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T])
+ extends compression.Decoder[T] {
+
+ private var run = 0
+ private var valueCount = 0
+ private var currentValue: T#JvmType = _
+
+ override def next() = {
+ if (valueCount == run) {
+ currentValue = columnType.extract(buffer)
+ run = buffer.getInt()
+ valueCount = 1
+ } else {
+ valueCount += 1
+ }
+
+ currentValue
+ }
+
+ override def hasNext = buffer.hasRemaining
+ }
+}
+
+private[sql] case object DictionaryEncoding extends CompressionScheme {
+ override def typeId: Int = 2
+
+ // 32K unique values allowed
+ private val MAX_DICT_SIZE = Short.MaxValue - 1
+
+ override def decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]) = {
+ new this.Decoder[T](buffer, columnType)
+ }
+
+ override def encoder = new this.Encoder
+
+ override def supports(columnType: ColumnType[_, _]) = columnType match {
+ case INT | LONG | STRING => true
+ case _ => false
+ }
+
+ class Encoder extends compression.Encoder {
+ // Size of the input, uncompressed, in bytes. Note that we only count until the dictionary
+ // overflows.
+ private var _uncompressedSize = 0
+
+ // If the number of distinct elements is too large, we discard the use of dictionary encoding
+ // and set the overflow flag to true.
+ private var overflow = false
+
+ // Total number of elements.
+ private var count = 0
+
+ // The reverse mapping of `dictionary`, i.e. mapping an encoded short integer back to the value itself.
+ private var values = new mutable.ArrayBuffer[Any](1024)
+
+ // The dictionary that maps a value to the encoded short integer.
+ private val dictionary = mutable.HashMap.empty[Any, Short]
+
+ // Size of the serialized dictionary in bytes. Initialized to 4 since we need at least an `Int`
+ // to store dictionary element count.
+ private var dictionarySize = 4
+
+ override def gatherCompressibilityStats[T <: NativeType](
+ value: T#JvmType,
+ columnType: ColumnType[T, T#JvmType]) {
+
+ if (!overflow) {
+ val actualSize = columnType.actualSize(value)
+ count += 1
+ _uncompressedSize += actualSize
+
+ if (!dictionary.contains(value)) {
+ if (dictionary.size < MAX_DICT_SIZE) {
+ val clone = columnType.clone(value)
+ values += clone
+ dictionarySize += actualSize
+ dictionary(clone) = dictionary.size.toShort
+ } else {
+ overflow = true
+ values.clear()
+ dictionary.clear()
+ }
+ }
+ }
+ }
+
+ override def compress[T <: NativeType](
+ from: ByteBuffer,
+ to: ByteBuffer,
+ columnType: ColumnType[T, T#JvmType]) = {
+
+ if (overflow) {
+ throw new IllegalStateException(
+ "Dictionary encoding should not be used because of dictionary overflow.")
+ }
+
+ to.putInt(DictionaryEncoding.typeId)
+ .putInt(dictionary.size)
+
+ var i = 0
+ while (i < values.length) {
+ columnType.append(values(i).asInstanceOf[T#JvmType], to)
+ i += 1
+ }
+
+ while (from.hasRemaining) {
+ to.putShort(dictionary(columnType.extract(from)))
+ }
+
+ to.rewind()
+ to
+ }
+
+ override def uncompressedSize = _uncompressedSize
+
+ override def compressedSize = if (overflow) Int.MaxValue else dictionarySize + count * 2
+ }
+
+ class Decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T])
+ extends compression.Decoder[T] {
+
+ private val dictionary = {
+ // TODO Can we clean up this mess? Maybe move this to `DataType`?
+ implicit val classTag = {
+ val mirror = runtimeMirror(getClass.getClassLoader)
+ ClassTag[T#JvmType](mirror.runtimeClass(columnType.scalaTag.tpe))
+ }
+
+ Array.fill(buffer.getInt()) {
+ columnType.extract(buffer)
+ }
+ }
+
+ override def next() = dictionary(buffer.getShort())
+
+ override def hasNext = buffer.hasRemaining
+ }
+}
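
RunLengthEncoding above writes each run as a (value, run length) pair into the destination buffer, and the decoder replays runs until the buffer is exhausted. The same idea over plain collections, as a self-contained round-trip sketch:

// Plain-collection round trip of the run-length idea; not the ByteBuffer-based scheme itself.
object RunLengthDemo extends App {
  def encode[A](xs: Seq[A]): Vector[(A, Int)] =
    xs.foldLeft(Vector.empty[(A, Int)]) { (acc, x) =>
      acc.lastOption match {
        case Some((v, n)) if v == x => acc.init :+ ((v, n + 1))  // extend the current run
        case _                      => acc :+ ((x, 1))           // start a new run
      }
    }

  def decode[A](runs: Seq[(A, Int)]): Seq[A] =
    runs.flatMap { case (v, n) => Seq.fill(n)(v) }

  val input = Seq("a", "a", "a", "b", "b", "c")
  val runs  = encode(input)
  println(runs)                     // Vector((a,3), (b,2), (c,1))
  assert(decode(runs) == input)
}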
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala
index 915f551fb2f01..d8e1b970c1d88 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala
@@ -32,7 +32,13 @@ class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(conf) {
kryo.setRegistrationRequired(false)
kryo.register(classOf[MutablePair[_, _]])
kryo.register(classOf[Array[Any]])
+ // This is kinda hacky...
kryo.register(classOf[scala.collection.immutable.Map$Map1], new MapSerializer)
+ kryo.register(classOf[scala.collection.immutable.Map$Map2], new MapSerializer)
+ kryo.register(classOf[scala.collection.immutable.Map$Map3], new MapSerializer)
+ kryo.register(classOf[scala.collection.immutable.Map$Map4], new MapSerializer)
+ kryo.register(classOf[scala.collection.immutable.Map[_,_]], new MapSerializer)
+ kryo.register(classOf[scala.collection.Map[_,_]], new MapSerializer)
kryo.register(classOf[org.apache.spark.sql.catalyst.expressions.GenericRow])
kryo.register(classOf[org.apache.spark.sql.catalyst.expressions.GenericMutableRow])
kryo.register(classOf[scala.collection.mutable.ArrayBuffer[_]])
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index e35ac0b6ca95a..b3e51fdf75270 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -158,10 +158,10 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case other => other
}
- object TopK extends Strategy {
+ object TakeOrdered extends Strategy {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
- case logical.StopAfter(IntegerLiteral(limit), logical.Sort(order, child)) =>
- execution.TopK(limit, order, planLater(child))(sparkContext) :: Nil
+ case logical.Limit(IntegerLiteral(limit), logical.Sort(order, child)) =>
+ execution.TakeOrdered(limit, order, planLater(child))(sparkContext) :: Nil
case _ => Nil
}
}
@@ -213,8 +213,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
sparkContext.parallelize(data.map(r =>
new GenericRow(r.productIterator.map(convertToCatalyst).toArray): Row))
execution.ExistingRdd(output, dataAsRdd) :: Nil
- case logical.StopAfter(IntegerLiteral(limit), child) =>
- execution.StopAfter(limit, planLater(child))(sparkContext) :: Nil
+ case logical.Limit(IntegerLiteral(limit), child) =>
+ execution.Limit(limit, planLater(child))(sparkContext) :: Nil
case Unions(unionChildren) =>
execution.Union(unionChildren.map(planLater))(sparkContext) :: Nil
case logical.Generate(generator, join, outer, _, child) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 65cb8f8becefa..524e5022ee14b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -19,27 +19,28 @@ package org.apache.spark.sql.execution
import scala.reflect.runtime.universe.TypeTag
-import org.apache.spark.rdd.RDD
-import org.apache.spark.SparkContext
-
+import org.apache.spark.{HashPartitioner, SparkConf, SparkContext}
+import org.apache.spark.rdd.{RDD, ShuffledRDD}
+import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.errors._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical.{OrderedDistribution, UnspecifiedDistribution}
-import org.apache.spark.sql.catalyst.ScalaReflection
+import org.apache.spark.util.MutablePair
+
case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends UnaryNode {
- def output = projectList.map(_.toAttribute)
+ override def output = projectList.map(_.toAttribute)
- def execute() = child.execute().mapPartitions { iter =>
+ override def execute() = child.execute().mapPartitions { iter =>
@transient val reusableProjection = new MutableProjection(projectList)
iter.map(reusableProjection)
}
}
case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode {
- def output = child.output
+ override def output = child.output
- def execute() = child.execute().mapPartitions { iter =>
+ override def execute() = child.execute().mapPartitions { iter =>
iter.filter(condition.apply(_).asInstanceOf[Boolean])
}
}
@@ -47,37 +48,59 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode {
case class Sample(fraction: Double, withReplacement: Boolean, seed: Int, child: SparkPlan)
extends UnaryNode {
- def output = child.output
+ override def output = child.output
// TODO: How to pick seed?
- def execute() = child.execute().sample(withReplacement, fraction, seed)
+ override def execute() = child.execute().sample(withReplacement, fraction, seed)
}
case class Union(children: Seq[SparkPlan])(@transient sc: SparkContext) extends SparkPlan {
// TODO: attributes output by union should be distinct for nullability purposes
- def output = children.head.output
- def execute() = sc.union(children.map(_.execute()))
+ override def output = children.head.output
+ override def execute() = sc.union(children.map(_.execute()))
override def otherCopyArgs = sc :: Nil
}
-case class StopAfter(limit: Int, child: SparkPlan)(@transient sc: SparkContext) extends UnaryNode {
+/**
+ * Take the first limit elements. Note that the implementation is different depending on whether
+ * this is a terminal operator or not. If it is terminal and is invoked using executeCollect,
+ * this operator uses Spark's take method on the Spark driver. If it is not terminal or is
+ * invoked using execute, we first take the limit on each partition, and then repartition all the
+ * data to a single partition to compute the global limit.
+ */
+case class Limit(limit: Int, child: SparkPlan)(@transient sc: SparkContext) extends UnaryNode {
+ // TODO: Implement a partition local limit, and use a strategy to generate the proper limit plan:
+ // partition local limit -> exchange into one partition -> partition local limit again
+
override def otherCopyArgs = sc :: Nil
- def output = child.output
+ override def output = child.output
override def executeCollect() = child.execute().map(_.copy()).take(limit)
- // TODO: Terminal split should be implemented differently from non-terminal split.
- // TODO: Pick num splits based on |limit|.
- def execute() = sc.makeRDD(executeCollect(), 1)
+ override def execute() = {
+ val rdd = child.execute().mapPartitions { iter =>
+ val mutablePair = new MutablePair[Boolean, Row]()
+ iter.take(limit).map(row => mutablePair.update(false, row))
+ }
+ val part = new HashPartitioner(1)
+ val shuffled = new ShuffledRDD[Boolean, Row, MutablePair[Boolean, Row]](rdd, part)
+ shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false)))
+ shuffled.mapPartitions(_.take(limit).map(_._2))
+ }
}
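The new Limit.execute is a two-phase limit: each partition keeps at most `limit` rows, everything is then shuffled into a single partition (rows keyed by a constant Boolean through a reused MutablePair), and the limit is applied once more on that partition. Below is a minimal sketch of the same pattern on a plain RDD, using coalesce with shuffle enabled in place of the ShuffledRDD/SparkSqlSerializer machinery; the object and value names are illustrative only.

import org.apache.spark.{SparkConf, SparkContext}

object TwoPhaseLimitSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("limit-sketch").setMaster("local[4]"))
    val rows = sc.parallelize(1 to 1000, numSlices = 8)
    val limit = 10

    // Phase 1: partition-local limit, so no partition ships more than `limit` rows.
    val local = rows.mapPartitions(_.take(limit))

    // Phase 2: collapse to a single partition and apply the limit again for the global result.
    val global = local.coalesce(1, shuffle = true).mapPartitions(_.take(limit))

    println(global.collect().mkString(", "))
    sc.stop()
  }
}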
-case class TopK(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan)
- (@transient sc: SparkContext) extends UnaryNode {
+/**
+ * Take the first limit elements as defined by the sortOrder. This is logically equivalent to
+ * having a [[Limit]] operator after a [[Sort]] operator. This could have been named TopK, but
+ * Spark's top operator does the opposite in ordering, so we name it TakeOrdered to avoid confusion.
+ */
+case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan)
+ (@transient sc: SparkContext) extends UnaryNode {
override def otherCopyArgs = sc :: Nil
- def output = child.output
+ override def output = child.output
@transient
lazy val ordering = new RowOrdering(sortOrder)
@@ -86,7 +109,7 @@ case class TopK(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan)
// TODO: Terminal split should be implemented differently from non-terminal split.
// TODO: Pick num splits based on |limit|.
- def execute() = sc.makeRDD(executeCollect(), 1)
+ override def execute() = sc.makeRDD(executeCollect(), 1)
}
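For the naming rationale above, core Spark already exposes both directions at the RDD level, and the operator keeps the take-smallest semantics. A standalone sketch assuming only RDD.takeOrdered and RDD.top:

import org.apache.spark.{SparkConf, SparkContext}

object TakeOrderedVsTop {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("take-ordered-sketch").setMaster("local[2]"))
    val data = sc.parallelize(Seq(5, 3, 9, 1, 7))

    // takeOrdered returns the first k elements under the implicit Ordering (ascending here)...
    println(data.takeOrdered(3).toSeq)   // 1, 3, 5
    // ...while top uses the reverse ordering, which is why the operator is not called TopK.
    println(data.top(3).toSeq)           // 9, 7, 5

    sc.stop()
  }
}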
@@ -101,7 +124,7 @@ case class Sort(
@transient
lazy val ordering = new RowOrdering(sortOrder)
- def execute() = attachTree(this, "sort") {
+ override def execute() = attachTree(this, "sort") {
// TODO: Optimize sorting operation?
child.execute()
.mapPartitions(
@@ -109,7 +132,7 @@ case class Sort(
preservesPartitioning = true)
}
- def output = child.output
+ override def output = child.output
}
object ExistingRdd {
@@ -130,6 +153,6 @@ object ExistingRdd {
}
case class ExistingRdd(output: Seq[Attribute], rdd: RDD[Row]) extends LeafNode {
- def execute() = rdd
+ override def execute() = rdd
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
new file mode 100644
index 0000000000000..e5902c3cae381
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.scalatest.FunSuite
+import org.apache.spark.sql.TestData._
+import org.apache.spark.sql.test.TestSQLContext
+import org.apache.spark.sql.execution.SparkLogicalPlan
+import org.apache.spark.sql.columnar.InMemoryColumnarTableScan
+
+class CachedTableSuite extends QueryTest {
+ TestData // Load test tables.
+
+ test("read from cached table and uncache") {
+ TestSQLContext.cacheTable("testData")
+
+ checkAnswer(
+ TestSQLContext.table("testData"),
+ testData.collect().toSeq
+ )
+
+ TestSQLContext.table("testData").queryExecution.analyzed match {
+ case SparkLogicalPlan(_ : InMemoryColumnarTableScan) => // Found evidence of caching
+ case noCache => fail(s"No cache node found in plan $noCache")
+ }
+
+ TestSQLContext.uncacheTable("testData")
+
+ checkAnswer(
+ TestSQLContext.table("testData"),
+ testData.collect().toSeq
+ )
+
+ TestSQLContext.table("testData").queryExecution.analyzed match {
+ case cachePlan @ SparkLogicalPlan(_ : InMemoryColumnarTableScan) =>
+ fail(s"Table still cached after uncache: $cachePlan")
+ case noCache => // Table uncached successfully
+ }
+ }
+
+ test("correct error on uncache of non-cached table") {
+ intercept[IllegalArgumentException] {
+ TestSQLContext.uncacheTable("testData")
+ }
+ }
+}
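Outside the test harness, the flow this suite exercises looks roughly like the sketch below. The Person case class and the "people" table name are invented for illustration, and the import of sqlContext._ is assumed to bring in the implicit RDD-to-SchemaRDD conversion that provides registerAsTable.

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

case class Person(name: String, age: Int)   // hypothetical schema, top level so reflection can see it

object CacheTableSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("cache-sketch").setMaster("local[2]"))
    val sqlContext = new SQLContext(sc)
    import sqlContext._   // implicit RDD -> SchemaRDD conversion and the sql(...) method

    sc.parallelize(Seq(Person("alice", 30), Person("bob", 25))).registerAsTable("people")

    sqlContext.cacheTable("people")      // later scans should hit the in-memory columnar form
    sql("SELECT name FROM people WHERE age > 26").collect().foreach(println)
    sqlContext.uncacheTable("people")    // IllegalArgumentException if the table was never cached
    sc.stop()
  }
}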
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala
new file mode 100644
index 0000000000000..70033a050c78c
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.sql.test.TestSQLContext._
+
+case class ReflectData(
+ stringField: String,
+ intField: Int,
+ longField: Long,
+ floatField: Float,
+ doubleField: Double,
+ shortField: Short,
+ byteField: Byte,
+ booleanField: Boolean,
+ decimalField: BigDecimal,
+ seqInt: Seq[Int])
+
+case class ReflectBinary(data: Array[Byte])
+
+class ScalaReflectionRelationSuite extends FunSuite {
+ test("query case class RDD") {
+ val data = ReflectData("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true,
+ BigDecimal(1), Seq(1,2,3))
+ val rdd = sparkContext.parallelize(data :: Nil)
+ rdd.registerAsTable("reflectData")
+
+ assert(sql("SELECT * FROM reflectData").collect().head === data.productIterator.toSeq)
+ }
+
+ // Equality is broken for Arrays, so we test that separately.
+ test("query binary data") {
+ val rdd = sparkContext.parallelize(ReflectBinary(Array[Byte](1)) :: Nil)
+ rdd.registerAsTable("reflectBinary")
+
+ val result = sql("SELECT data FROM reflectBinary").collect().head(0).asInstanceOf[Array[Byte]]
+ assert(result.toSeq === Seq[Byte](1))
+ }
+}
\ No newline at end of file
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala
new file mode 100644
index 0000000000000..78640b876d4aa
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.columnar
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.sql.catalyst.types._
+
+class ColumnStatsSuite extends FunSuite {
+ testColumnStats(classOf[BooleanColumnStats], BOOLEAN)
+ testColumnStats(classOf[ByteColumnStats], BYTE)
+ testColumnStats(classOf[ShortColumnStats], SHORT)
+ testColumnStats(classOf[IntColumnStats], INT)
+ testColumnStats(classOf[LongColumnStats], LONG)
+ testColumnStats(classOf[FloatColumnStats], FLOAT)
+ testColumnStats(classOf[DoubleColumnStats], DOUBLE)
+ testColumnStats(classOf[StringColumnStats], STRING)
+
+ def testColumnStats[T <: NativeType, U <: NativeColumnStats[T]](
+ columnStatsClass: Class[U],
+ columnType: NativeColumnType[T]) {
+
+ val columnStatsName = columnStatsClass.getSimpleName
+
+ test(s"$columnStatsName: empty") {
+ val columnStats = columnStatsClass.newInstance()
+ expectResult(columnStats.initialBounds, "Wrong initial bounds") {
+ (columnStats.lowerBound, columnStats.upperBound)
+ }
+ }
+
+ test(s"$columnStatsName: non-empty") {
+ import ColumnarTestUtils._
+
+ val columnStats = columnStatsClass.newInstance()
+ val rows = Seq.fill(10)(makeRandomRow(columnType))
+ rows.foreach(columnStats.gatherStats(_, 0))
+
+ val values = rows.map(_.head.asInstanceOf[T#JvmType])
+ val ordering = columnType.dataType.ordering.asInstanceOf[Ordering[T#JvmType]]
+
+ expectResult(values.min(ordering), "Wrong lower bound")(columnStats.lowerBound)
+ expectResult(values.max(ordering), "Wrong upper bound")(columnStats.upperBound)
+ }
+ }
+}
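The contract checked here is simply that after gatherStats has seen every row, lowerBound and upperBound equal the min and max of the observed values under the column type's ordering. A toy, dependency-free version of that contract follows; ToyIntStats is invented for illustration and is not part of the patch.

object ToyColumnStatsCheck extends App {
  // Invented stand-in for a NativeColumnStats implementation, tracking Int bounds only.
  class ToyIntStats {
    var lowerBound: Int = Int.MaxValue
    var upperBound: Int = Int.MinValue
    def gatherStats(value: Int): Unit = {
      if (value < lowerBound) lowerBound = value
      if (value > upperBound) upperBound = value
    }
  }

  val values = Seq(3, -1, 7, 0)
  val stats = new ToyIntStats
  values.foreach(stats.gatherStats)
  assert(stats.lowerBound == values.min, "Wrong lower bound")
  assert(stats.upperBound == values.max, "Wrong upper bound")
  println((stats.lowerBound, stats.upperBound))   // (-1,7)
}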
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala
index 2d431affbcfcc..1d3608ed2d9ab 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala
@@ -19,46 +19,56 @@ package org.apache.spark.sql.columnar
import java.nio.ByteBuffer
-import scala.util.Random
-
import org.scalatest.FunSuite
import org.apache.spark.sql.catalyst.types._
+import org.apache.spark.sql.columnar.ColumnarTestUtils._
import org.apache.spark.sql.execution.SparkSqlSerializer
class ColumnTypeSuite extends FunSuite {
- val columnTypes = Seq(INT, SHORT, LONG, BYTE, DOUBLE, FLOAT, STRING, BINARY, GENERIC)
+ val DEFAULT_BUFFER_SIZE = 512
test("defaultSize") {
- val defaultSize = Seq(4, 2, 8, 1, 8, 4, 8, 16, 16)
+ val checks = Map(
+ INT -> 4, SHORT -> 2, LONG -> 8, BYTE -> 1, DOUBLE -> 8, FLOAT -> 4,
+ BOOLEAN -> 1, STRING -> 8, BINARY -> 16, GENERIC -> 16)
- columnTypes.zip(defaultSize).foreach { case (columnType, size) =>
- assert(columnType.defaultSize === size)
+ checks.foreach { case (columnType, expectedSize) =>
+ expectResult(expectedSize, s"Wrong defaultSize for $columnType") {
+ columnType.defaultSize
+ }
}
}
test("actualSize") {
- val expectedSizes = Seq(4, 2, 8, 1, 8, 4, 4 + 5, 4 + 4, 4 + 11)
- val actualSizes = Seq(
- INT.actualSize(Int.MaxValue),
- SHORT.actualSize(Short.MaxValue),
- LONG.actualSize(Long.MaxValue),
- BYTE.actualSize(Byte.MaxValue),
- DOUBLE.actualSize(Double.MaxValue),
- FLOAT.actualSize(Float.MaxValue),
- STRING.actualSize("hello"),
- BINARY.actualSize(new Array[Byte](4)),
- GENERIC.actualSize(SparkSqlSerializer.serialize(Map(1 -> "a"))))
-
- expectedSizes.zip(actualSizes).foreach { case (expected, actual) =>
- assert(expected === actual)
+ def checkActualSize[T <: DataType, JvmType](
+ columnType: ColumnType[T, JvmType],
+ value: JvmType,
+ expected: Int) {
+
+ expectResult(expected, s"Wrong actualSize for $columnType") {
+ columnType.actualSize(value)
+ }
}
+
+ checkActualSize(INT, Int.MaxValue, 4)
+ checkActualSize(SHORT, Short.MaxValue, 2)
+ checkActualSize(LONG, Long.MaxValue, 8)
+ checkActualSize(BYTE, Byte.MaxValue, 1)
+ checkActualSize(DOUBLE, Double.MaxValue, 8)
+ checkActualSize(FLOAT, Float.MaxValue, 4)
+ checkActualSize(BOOLEAN, true, 1)
+ checkActualSize(STRING, "hello", 4 + 5)
+
+ val binary = Array.fill[Byte](4)(0: Byte)
+ checkActualSize(BINARY, binary, 4 + 4)
+
+ val generic = Map(1 -> "a")
+ checkActualSize(GENERIC, SparkSqlSerializer.serialize(generic), 4 + 11)
}
- testNumericColumnType[BooleanType.type, Boolean](
+ testNativeColumnType[BooleanType.type](
BOOLEAN,
- Array.fill(4)(Random.nextBoolean()),
- ByteBuffer.allocate(32),
(buffer: ByteBuffer, v: Boolean) => {
buffer.put((if (v) 1 else 0).toByte)
},
@@ -66,105 +76,42 @@ class ColumnTypeSuite extends FunSuite {
buffer.get() == 1
})
- testNumericColumnType[IntegerType.type, Int](
- INT,
- Array.fill(4)(Random.nextInt()),
- ByteBuffer.allocate(32),
- (_: ByteBuffer).putInt(_),
- (_: ByteBuffer).getInt)
-
- testNumericColumnType[ShortType.type, Short](
- SHORT,
- Array.fill(4)(Random.nextInt(Short.MaxValue).asInstanceOf[Short]),
- ByteBuffer.allocate(32),
- (_: ByteBuffer).putShort(_),
- (_: ByteBuffer).getShort)
-
- testNumericColumnType[LongType.type, Long](
- LONG,
- Array.fill(4)(Random.nextLong()),
- ByteBuffer.allocate(64),
- (_: ByteBuffer).putLong(_),
- (_: ByteBuffer).getLong)
-
- testNumericColumnType[ByteType.type, Byte](
- BYTE,
- Array.fill(4)(Random.nextInt(Byte.MaxValue).asInstanceOf[Byte]),
- ByteBuffer.allocate(64),
- (_: ByteBuffer).put(_),
- (_: ByteBuffer).get)
-
- testNumericColumnType[DoubleType.type, Double](
- DOUBLE,
- Array.fill(4)(Random.nextDouble()),
- ByteBuffer.allocate(64),
- (_: ByteBuffer).putDouble(_),
- (_: ByteBuffer).getDouble)
-
- testNumericColumnType[FloatType.type, Float](
- FLOAT,
- Array.fill(4)(Random.nextFloat()),
- ByteBuffer.allocate(64),
- (_: ByteBuffer).putFloat(_),
- (_: ByteBuffer).getFloat)
-
- test("STRING") {
- val buffer = ByteBuffer.allocate(128)
- val seq = Array("hello", "world", "spark", "sql")
-
- seq.map(_.getBytes).foreach { bytes: Array[Byte] =>
- buffer.putInt(bytes.length).put(bytes)
- }
+ testNativeColumnType[IntegerType.type](INT, _.putInt(_), _.getInt)
- buffer.rewind()
- seq.foreach { s =>
- assert(s === STRING.extract(buffer))
- }
+ testNativeColumnType[ShortType.type](SHORT, _.putShort(_), _.getShort)
- buffer.rewind()
- seq.foreach(STRING.append(_, buffer))
+ testNativeColumnType[LongType.type](LONG, _.putLong(_), _.getLong)
- buffer.rewind()
- seq.foreach { s =>
- val length = buffer.getInt
- assert(length === s.getBytes.length)
+ testNativeColumnType[ByteType.type](BYTE, _.put(_), _.get)
+
+ testNativeColumnType[DoubleType.type](DOUBLE, _.putDouble(_), _.getDouble)
+
+ testNativeColumnType[FloatType.type](FLOAT, _.putFloat(_), _.getFloat)
+ testNativeColumnType[StringType.type](
+ STRING,
+ (buffer: ByteBuffer, string: String) => {
+ val bytes = string.getBytes()
+ buffer.putInt(bytes.length).put(string.getBytes)
+ },
+ (buffer: ByteBuffer) => {
+ val length = buffer.getInt()
val bytes = new Array[Byte](length)
buffer.get(bytes, 0, length)
- assert(s === new String(bytes))
- }
- }
-
- test("BINARY") {
- val buffer = ByteBuffer.allocate(128)
- val seq = Array.fill(4) {
- val bytes = new Array[Byte](4)
- Random.nextBytes(bytes)
- bytes
- }
+ new String(bytes)
+ })
- seq.foreach { bytes =>
+ testColumnType[BinaryType.type, Array[Byte]](
+ BINARY,
+ (buffer: ByteBuffer, bytes: Array[Byte]) => {
buffer.putInt(bytes.length).put(bytes)
- }
-
- buffer.rewind()
- seq.foreach { b =>
- assert(b === BINARY.extract(buffer))
- }
-
- buffer.rewind()
- seq.foreach(BINARY.append(_, buffer))
-
- buffer.rewind()
- seq.foreach { b =>
- val length = buffer.getInt
- assert(length === b.length)
-
+ },
+ (buffer: ByteBuffer) => {
+ val length = buffer.getInt()
val bytes = new Array[Byte](length)
buffer.get(bytes, 0, length)
- assert(b === bytes)
- }
- }
+ bytes
+ })
test("GENERIC") {
val buffer = ByteBuffer.allocate(512)
@@ -177,43 +124,58 @@ class ColumnTypeSuite extends FunSuite {
val length = buffer.getInt()
assert(length === serializedObj.length)
- val bytes = new Array[Byte](length)
- buffer.get(bytes, 0, length)
- assert(obj === SparkSqlSerializer.deserialize(bytes))
+    expectResult(obj, "Deserialized object didn't equal the original object") {
+ val bytes = new Array[Byte](length)
+ buffer.get(bytes, 0, length)
+ SparkSqlSerializer.deserialize(bytes)
+ }
buffer.rewind()
buffer.putInt(serializedObj.length).put(serializedObj)
- buffer.rewind()
- assert(obj === SparkSqlSerializer.deserialize(GENERIC.extract(buffer)))
+    expectResult(obj, "Deserialized object didn't equal the original object") {
+ buffer.rewind()
+ SparkSqlSerializer.deserialize(GENERIC.extract(buffer))
+ }
+ }
+
+ def testNativeColumnType[T <: NativeType](
+ columnType: NativeColumnType[T],
+ putter: (ByteBuffer, T#JvmType) => Unit,
+ getter: (ByteBuffer) => T#JvmType) {
+
+ testColumnType[T, T#JvmType](columnType, putter, getter)
}
- def testNumericColumnType[T <: DataType, JvmType](
+ def testColumnType[T <: DataType, JvmType](
columnType: ColumnType[T, JvmType],
- seq: Seq[JvmType],
- buffer: ByteBuffer,
putter: (ByteBuffer, JvmType) => Unit,
getter: (ByteBuffer) => JvmType) {
- val columnTypeName = columnType.getClass.getSimpleName.stripSuffix("$")
+ val buffer = ByteBuffer.allocate(DEFAULT_BUFFER_SIZE)
+ val seq = (0 until 4).map(_ => makeRandomValue(columnType))
- test(s"$columnTypeName.extract") {
+ test(s"$columnType.extract") {
buffer.rewind()
seq.foreach(putter(buffer, _))
buffer.rewind()
- seq.foreach { i =>
- assert(i === columnType.extract(buffer))
+ seq.foreach { expected =>
+ assert(
+ expected === columnType.extract(buffer),
+          "Extracted value didn't equal the original one")
}
}
- test(s"$columnTypeName.append") {
+ test(s"$columnType.append") {
buffer.rewind()
seq.foreach(columnType.append(_, buffer))
buffer.rewind()
- seq.foreach { i =>
- assert(i === getter(buffer))
+ seq.foreach { expected =>
+ assert(
+ expected === getter(buffer),
+          "Extracted value didn't equal the original one")
}
}
}
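The STRING putter/getter pair passed to testNativeColumnType above encodes each string length-prefixed: a 4-byte length followed by the raw bytes. The same pair in isolation, runnable against a plain ByteBuffer; the helper names are illustrative.

import java.nio.ByteBuffer

object LengthPrefixedStringSketch extends App {
  val buffer = ByteBuffer.allocate(512)

  def put(buffer: ByteBuffer, string: String): Unit = {
    val bytes = string.getBytes
    buffer.putInt(bytes.length).put(bytes)
  }

  def get(buffer: ByteBuffer): String = {
    val length = buffer.getInt()
    val bytes = new Array[Byte](length)
    buffer.get(bytes, 0, length)
    new String(bytes)
  }

  Seq("hello", "world").foreach(put(buffer, _))
  buffer.rewind()
  println(Seq(get(buffer), get(buffer)))   // List(hello, world)
}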
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarQuerySuite.scala
index 928851a385d41..70b2e851737f8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarQuerySuite.scala
@@ -17,11 +17,11 @@
package org.apache.spark.sql.columnar
+import org.apache.spark.sql.{QueryTest, TestData}
import org.apache.spark.sql.execution.SparkLogicalPlan
import org.apache.spark.sql.test.TestSQLContext
-import org.apache.spark.sql.{TestData, DslQuerySuite}
-class ColumnarQuerySuite extends DslQuerySuite {
+class ColumnarQuerySuite extends QueryTest {
import TestData._
import TestSQLContext._
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestData.scala
deleted file mode 100644
index ddcdede8d1a4a..0000000000000
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestData.scala
+++ /dev/null
@@ -1,55 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.columnar
-
-import scala.util.Random
-
-import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
-
-// TODO Enrich test data
-object ColumnarTestData {
- object GenericMutableRow {
- def apply(values: Any*) = {
- val row = new GenericMutableRow(values.length)
- row.indices.foreach { i =>
- row(i) = values(i)
- }
- row
- }
- }
-
- def randomBytes(length: Int) = {
- val bytes = new Array[Byte](length)
- Random.nextBytes(bytes)
- bytes
- }
-
- val nonNullRandomRow = GenericMutableRow(
- Random.nextInt(),
- Random.nextLong(),
- Random.nextFloat(),
- Random.nextDouble(),
- Random.nextBoolean(),
- Random.nextInt(Byte.MaxValue).asInstanceOf[Byte],
- Random.nextInt(Short.MaxValue).asInstanceOf[Short],
- Random.nextString(Random.nextInt(64)),
- randomBytes(Random.nextInt(64)),
- Map(Random.nextInt() -> Random.nextString(4)))
-
- val nullRow = GenericMutableRow(Seq.fill(10)(null): _*)
-}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala
new file mode 100644
index 0000000000000..04bdc43d95328
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala
@@ -0,0 +1,100 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.columnar
+
+import scala.collection.immutable.HashSet
+import scala.util.Random
+
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
+import org.apache.spark.sql.catalyst.types.{DataType, NativeType}
+
+object ColumnarTestUtils {
+ def makeNullRow(length: Int) = {
+ val row = new GenericMutableRow(length)
+ (0 until length).foreach(row.setNullAt)
+ row
+ }
+
+ def makeRandomValue[T <: DataType, JvmType](columnType: ColumnType[T, JvmType]): JvmType = {
+ def randomBytes(length: Int) = {
+ val bytes = new Array[Byte](length)
+ Random.nextBytes(bytes)
+ bytes
+ }
+
+ (columnType match {
+ case BYTE => (Random.nextInt(Byte.MaxValue * 2) - Byte.MaxValue).toByte
+ case SHORT => (Random.nextInt(Short.MaxValue * 2) - Short.MaxValue).toShort
+ case INT => Random.nextInt()
+ case LONG => Random.nextLong()
+ case FLOAT => Random.nextFloat()
+ case DOUBLE => Random.nextDouble()
+ case STRING => Random.nextString(Random.nextInt(32))
+ case BOOLEAN => Random.nextBoolean()
+ case BINARY => randomBytes(Random.nextInt(32))
+ case _ =>
+ // Using a random one-element map instead of an arbitrary object
+ Map(Random.nextInt() -> Random.nextString(Random.nextInt(32)))
+ }).asInstanceOf[JvmType]
+ }
+
+ def makeRandomValues(
+ head: ColumnType[_ <: DataType, _],
+ tail: ColumnType[_ <: DataType, _]*): Seq[Any] = makeRandomValues(Seq(head) ++ tail)
+
+ def makeRandomValues(columnTypes: Seq[ColumnType[_ <: DataType, _]]): Seq[Any] = {
+ columnTypes.map(makeRandomValue(_))
+ }
+
+ def makeUniqueRandomValues[T <: DataType, JvmType](
+ columnType: ColumnType[T, JvmType],
+ count: Int): Seq[JvmType] = {
+
+ Iterator.iterate(HashSet.empty[JvmType]) { set =>
+ set + Iterator.continually(makeRandomValue(columnType)).filterNot(set.contains).next()
+ }.drop(count).next().toSeq
+ }
+
+ def makeRandomRow(
+ head: ColumnType[_ <: DataType, _],
+ tail: ColumnType[_ <: DataType, _]*): Row = makeRandomRow(Seq(head) ++ tail)
+
+ def makeRandomRow(columnTypes: Seq[ColumnType[_ <: DataType, _]]): Row = {
+ val row = new GenericMutableRow(columnTypes.length)
+ makeRandomValues(columnTypes).zipWithIndex.foreach { case (value, index) =>
+ row(index) = value
+ }
+ row
+ }
+
+ def makeUniqueValuesAndSingleValueRows[T <: NativeType](
+ columnType: NativeColumnType[T],
+ count: Int) = {
+
+ val values = makeUniqueRandomValues(columnType, count)
+ val rows = values.map { value =>
+ val row = new GenericMutableRow(1)
+ row(0) = value
+ row
+ }
+
+ (values, rows)
+ }
+
+}
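makeUniqueRandomValues grows an immutable HashSet one fresh element at a time via Iterator.iterate and takes the set produced after count steps. The same pattern specialised to Int, runnable on its own:

import scala.collection.immutable.HashSet
import scala.util.Random

object UniqueRandomSketch extends App {
  def uniqueRandomInts(count: Int): Seq[Int] =
    Iterator.iterate(HashSet.empty[Int]) { set =>
      // Keep drawing until we find a value not already in the set, then add it.
      set + Iterator.continually(Random.nextInt(100)).filterNot(set.contains).next()
    }.drop(count).next().toSeq

  println(uniqueRandomInts(5))   // five distinct values in [0, 100)
}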
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala
index d413d483f4e7e..4a21eb6201a69 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala
@@ -17,12 +17,29 @@
package org.apache.spark.sql.columnar
+import java.nio.ByteBuffer
+
import org.scalatest.FunSuite
-import org.apache.spark.sql.catalyst.types.DataType
+
import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
+import org.apache.spark.sql.catalyst.types.DataType
+
+class TestNullableColumnAccessor[T <: DataType, JvmType](
+ buffer: ByteBuffer,
+ columnType: ColumnType[T, JvmType])
+ extends BasicColumnAccessor(buffer, columnType)
+ with NullableColumnAccessor
+
+object TestNullableColumnAccessor {
+ def apply[T <: DataType, JvmType](buffer: ByteBuffer, columnType: ColumnType[T, JvmType]) = {
+ // Skips the column type ID
+ buffer.getInt()
+ new TestNullableColumnAccessor(buffer, columnType)
+ }
+}
class NullableColumnAccessorSuite extends FunSuite {
- import ColumnarTestData._
+ import ColumnarTestUtils._
Seq(INT, LONG, SHORT, BOOLEAN, BYTE, STRING, DOUBLE, FLOAT, BINARY, GENERIC).foreach {
testNullableColumnAccessor(_)
@@ -30,30 +47,32 @@ class NullableColumnAccessorSuite extends FunSuite {
def testNullableColumnAccessor[T <: DataType, JvmType](columnType: ColumnType[T, JvmType]) {
val typeName = columnType.getClass.getSimpleName.stripSuffix("$")
+ val nullRow = makeNullRow(1)
- test(s"$typeName accessor: empty column") {
- val builder = ColumnBuilder(columnType.typeId, 4)
- val accessor = ColumnAccessor(builder.build())
+ test(s"Nullable $typeName column accessor: empty column") {
+ val builder = TestNullableColumnBuilder(columnType)
+ val accessor = TestNullableColumnAccessor(builder.build(), columnType)
assert(!accessor.hasNext)
}
- test(s"$typeName accessor: access null values") {
- val builder = ColumnBuilder(columnType.typeId, 4)
+ test(s"Nullable $typeName column accessor: access null values") {
+ val builder = TestNullableColumnBuilder(columnType)
+ val randomRow = makeRandomRow(columnType)
(0 until 4).foreach { _ =>
- builder.appendFrom(nonNullRandomRow, columnType.typeId)
- builder.appendFrom(nullRow, columnType.typeId)
+ builder.appendFrom(randomRow, 0)
+ builder.appendFrom(nullRow, 0)
}
- val accessor = ColumnAccessor(builder.build())
+ val accessor = TestNullableColumnAccessor(builder.build(), columnType)
val row = new GenericMutableRow(1)
(0 until 4).foreach { _ =>
accessor.extractTo(row, 0)
- assert(row(0) === nonNullRandomRow(columnType.typeId))
+ assert(row(0) === randomRow(0))
accessor.extractTo(row, 0)
- assert(row(0) === null)
+ assert(row.isNullAt(0))
}
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala
index 5222a47e1ab87..d9d1e1bfddb75 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala
@@ -19,63 +19,71 @@ package org.apache.spark.sql.columnar
import org.scalatest.FunSuite
-import org.apache.spark.sql.catalyst.types.DataType
+import org.apache.spark.sql.catalyst.types._
import org.apache.spark.sql.execution.SparkSqlSerializer
+class TestNullableColumnBuilder[T <: DataType, JvmType](columnType: ColumnType[T, JvmType])
+ extends BasicColumnBuilder[T, JvmType](new NoopColumnStats[T, JvmType], columnType)
+ with NullableColumnBuilder
+
+object TestNullableColumnBuilder {
+ def apply[T <: DataType, JvmType](columnType: ColumnType[T, JvmType], initialSize: Int = 0) = {
+ val builder = new TestNullableColumnBuilder(columnType)
+ builder.initialize(initialSize)
+ builder
+ }
+}
+
class NullableColumnBuilderSuite extends FunSuite {
- import ColumnarTestData._
+ import ColumnarTestUtils._
Seq(INT, LONG, SHORT, BOOLEAN, BYTE, STRING, DOUBLE, FLOAT, BINARY, GENERIC).foreach {
testNullableColumnBuilder(_)
}
def testNullableColumnBuilder[T <: DataType, JvmType](columnType: ColumnType[T, JvmType]) {
- val columnBuilder = ColumnBuilder(columnType.typeId)
val typeName = columnType.getClass.getSimpleName.stripSuffix("$")
test(s"$typeName column builder: empty column") {
- columnBuilder.initialize(4)
-
+ val columnBuilder = TestNullableColumnBuilder(columnType)
val buffer = columnBuilder.build()
- // For column type ID
- assert(buffer.getInt() === columnType.typeId)
- // For null count
- assert(buffer.getInt === 0)
+ expectResult(columnType.typeId, "Wrong column type ID")(buffer.getInt())
+ expectResult(0, "Wrong null count")(buffer.getInt())
assert(!buffer.hasRemaining)
}
test(s"$typeName column builder: buffer size auto growth") {
- columnBuilder.initialize(4)
+ val columnBuilder = TestNullableColumnBuilder(columnType)
+ val randomRow = makeRandomRow(columnType)
- (0 until 4) foreach { _ =>
- columnBuilder.appendFrom(nonNullRandomRow, columnType.typeId)
+ (0 until 4).foreach { _ =>
+ columnBuilder.appendFrom(randomRow, 0)
}
val buffer = columnBuilder.build()
- // For column type ID
- assert(buffer.getInt() === columnType.typeId)
- // For null count
- assert(buffer.getInt() === 0)
+ expectResult(columnType.typeId, "Wrong column type ID")(buffer.getInt())
+ expectResult(0, "Wrong null count")(buffer.getInt())
}
test(s"$typeName column builder: null values") {
- columnBuilder.initialize(4)
+ val columnBuilder = TestNullableColumnBuilder(columnType)
+ val randomRow = makeRandomRow(columnType)
+ val nullRow = makeNullRow(1)
- (0 until 4) foreach { _ =>
- columnBuilder.appendFrom(nonNullRandomRow, columnType.typeId)
- columnBuilder.appendFrom(nullRow, columnType.typeId)
+ (0 until 4).foreach { _ =>
+ columnBuilder.appendFrom(randomRow, 0)
+ columnBuilder.appendFrom(nullRow, 0)
}
val buffer = columnBuilder.build()
- // For column type ID
- assert(buffer.getInt() === columnType.typeId)
- // For null count
- assert(buffer.getInt() === 4)
+ expectResult(columnType.typeId, "Wrong column type ID")(buffer.getInt())
+ expectResult(4, "Wrong null count")(buffer.getInt())
+
// For null positions
- (1 to 7 by 2).foreach(i => assert(buffer.getInt() === i))
+ (1 to 7 by 2).foreach(expectResult(_, "Wrong null position")(buffer.getInt()))
// For non-null values
(0 until 4).foreach { _ =>
@@ -84,7 +92,8 @@ class NullableColumnBuilderSuite extends FunSuite {
} else {
columnType.extract(buffer)
}
- assert(actual === nonNullRandomRow(columnType.typeId))
+
+        assert(actual === randomRow(0), "Extracted value didn't equal the original one")
}
assert(!buffer.hasRemaining)
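Read off the asserts in this suite, a serialized nullable column is laid out as a 4-byte column type ID, a 4-byte null count, the null positions as 4-byte ints, and then the non-null values in the column type's own encoding. A self-contained sketch of that layout for an INT column; the type ID value is made up.

import java.nio.ByteBuffer

object NullableColumnLayoutSketch extends App {
  val typeId = 0                           // hypothetical column type ID
  val nonNullValues = Seq(10, 20, 30, 40)  // values at even positions
  val nullPositions = Seq(1, 3, 5, 7)      // nulls interleaved at odd positions

  val buffer = ByteBuffer.allocate(128)
  buffer.putInt(typeId).putInt(nullPositions.length)
  nullPositions.foreach(p => buffer.putInt(p))
  nonNullValues.foreach(v => buffer.putInt(v))
  buffer.flip()

  assert(buffer.getInt() == typeId, "Wrong column type ID")
  assert(buffer.getInt() == 4, "Wrong null count")
  assert((1 to 7 by 2).forall(buffer.getInt() == _), "Wrong null position")
  assert(nonNullValues.forall(buffer.getInt() == _), "Wrong non-null value")
  assert(!buffer.hasRemaining)
}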
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala
new file mode 100644
index 0000000000000..184691ab5b46a
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.columnar.compression
+
+import java.nio.ByteBuffer
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.sql.catalyst.types.NativeType
+import org.apache.spark.sql.columnar._
+import org.apache.spark.sql.columnar.ColumnarTestUtils._
+import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
+
+class DictionaryEncodingSuite extends FunSuite {
+ testDictionaryEncoding(new IntColumnStats, INT)
+ testDictionaryEncoding(new LongColumnStats, LONG)
+ testDictionaryEncoding(new StringColumnStats, STRING)
+
+ def testDictionaryEncoding[T <: NativeType](
+ columnStats: NativeColumnStats[T],
+ columnType: NativeColumnType[T]) {
+
+ val typeName = columnType.getClass.getSimpleName.stripSuffix("$")
+
+ def buildDictionary(buffer: ByteBuffer) = {
+ (0 until buffer.getInt()).map(columnType.extract(buffer) -> _.toShort).toMap
+ }
+
+ test(s"$DictionaryEncoding with $typeName: simple case") {
+ // -------------
+ // Tests encoder
+ // -------------
+
+ val builder = TestCompressibleColumnBuilder(columnStats, columnType, DictionaryEncoding)
+ val (values, rows) = makeUniqueValuesAndSingleValueRows(columnType, 2)
+
+ builder.initialize(0)
+ builder.appendFrom(rows(0), 0)
+ builder.appendFrom(rows(1), 0)
+ builder.appendFrom(rows(0), 0)
+ builder.appendFrom(rows(1), 0)
+
+ val buffer = builder.build()
+ val headerSize = CompressionScheme.columnHeaderSize(buffer)
+ // 4 extra bytes for dictionary size
+ val dictionarySize = 4 + values.map(columnType.actualSize).sum
+ // 4 `Short`s, 2 bytes each
+ val compressedSize = dictionarySize + 2 * 4
+ // 4 extra bytes for compression scheme type ID
+ expectResult(headerSize + 4 + compressedSize, "Wrong buffer capacity")(buffer.capacity)
+
+ // Skips column header
+ buffer.position(headerSize)
+ expectResult(DictionaryEncoding.typeId, "Wrong compression scheme ID")(buffer.getInt())
+
+ val dictionary = buildDictionary(buffer)
+ Array[Short](0, 1).foreach { i =>
+ expectResult(i, "Wrong dictionary entry")(dictionary(values(i)))
+ }
+
+ Array[Short](0, 1, 0, 1).foreach {
+ expectResult(_, "Wrong column element value")(buffer.getShort())
+ }
+
+ // -------------
+ // Tests decoder
+ // -------------
+
+ // Rewinds, skips column header and 4 more bytes for compression scheme ID
+ buffer.rewind().position(headerSize + 4)
+
+ val decoder = new DictionaryEncoding.Decoder[T](buffer, columnType)
+
+ Array[Short](0, 1, 0, 1).foreach { i =>
+ expectResult(values(i), "Wrong decoded value")(decoder.next())
+ }
+
+ assert(!decoder.hasNext)
+ }
+ }
+
+ test(s"$DictionaryEncoding: overflow") {
+ val builder = TestCompressibleColumnBuilder(new IntColumnStats, INT, DictionaryEncoding)
+ builder.initialize(0)
+
+ (0 to Short.MaxValue).foreach { n =>
+ val row = new GenericMutableRow(1)
+ row.setInt(0, n)
+ builder.appendFrom(row, 0)
+ }
+
+ withClue("Dictionary overflowed, encoding should fail") {
+ intercept[Throwable] {
+ builder.build()
+ }
+ }
+ }
+}
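For the INT case, the size assertion in the simple test works out as: a 4-byte dictionary entry count plus two 4-byte entries (12 bytes), four 2-byte Short indices for the four appended rows (8 bytes), and 4 more bytes for the compression scheme ID, all following the column header. A one-liner check of that arithmetic, with the header size left symbolic:

object DictionarySizeCheck extends App {
  val intSize = 4
  val dictionarySize = 4 + 2 * intSize          // entry count + two distinct INT values = 12
  val compressedSize = dictionarySize + 4 * 2   // plus four 2-byte Short indices = 20
  println(s"expected capacity = headerSize + ${4 + compressedSize} bytes")   // headerSize + 24
}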
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala
new file mode 100644
index 0000000000000..2089ad120d4f2
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala
@@ -0,0 +1,130 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.columnar.compression
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.sql.catalyst.types.NativeType
+import org.apache.spark.sql.columnar._
+import org.apache.spark.sql.columnar.ColumnarTestUtils._
+
+class RunLengthEncodingSuite extends FunSuite {
+ testRunLengthEncoding(new BooleanColumnStats, BOOLEAN)
+ testRunLengthEncoding(new ByteColumnStats, BYTE)
+ testRunLengthEncoding(new ShortColumnStats, SHORT)
+ testRunLengthEncoding(new IntColumnStats, INT)
+ testRunLengthEncoding(new LongColumnStats, LONG)
+ testRunLengthEncoding(new StringColumnStats, STRING)
+
+ def testRunLengthEncoding[T <: NativeType](
+ columnStats: NativeColumnStats[T],
+ columnType: NativeColumnType[T]) {
+
+ val typeName = columnType.getClass.getSimpleName.stripSuffix("$")
+
+ test(s"$RunLengthEncoding with $typeName: simple case") {
+ // -------------
+ // Tests encoder
+ // -------------
+
+ val builder = TestCompressibleColumnBuilder(columnStats, columnType, RunLengthEncoding)
+ val (values, rows) = makeUniqueValuesAndSingleValueRows(columnType, 2)
+
+ builder.initialize(0)
+ builder.appendFrom(rows(0), 0)
+ builder.appendFrom(rows(0), 0)
+ builder.appendFrom(rows(1), 0)
+ builder.appendFrom(rows(1), 0)
+
+ val buffer = builder.build()
+ val headerSize = CompressionScheme.columnHeaderSize(buffer)
+ // 4 extra bytes each run for run length
+ val compressedSize = values.map(columnType.actualSize(_) + 4).sum
+ // 4 extra bytes for compression scheme type ID
+ expectResult(headerSize + 4 + compressedSize, "Wrong buffer capacity")(buffer.capacity)
+
+ // Skips column header
+ buffer.position(headerSize)
+ expectResult(RunLengthEncoding.typeId, "Wrong compression scheme ID")(buffer.getInt())
+
+ Array(0, 1).foreach { i =>
+ expectResult(values(i), "Wrong column element value")(columnType.extract(buffer))
+ expectResult(2, "Wrong run length")(buffer.getInt())
+ }
+
+ // -------------
+ // Tests decoder
+ // -------------
+
+ // Rewinds, skips column header and 4 more bytes for compression scheme ID
+ buffer.rewind().position(headerSize + 4)
+
+ val decoder = new RunLengthEncoding.Decoder[T](buffer, columnType)
+
+ Array(0, 0, 1, 1).foreach { i =>
+ expectResult(values(i), "Wrong decoded value")(decoder.next())
+ }
+
+ assert(!decoder.hasNext)
+ }
+
+ test(s"$RunLengthEncoding with $typeName: run length == 1") {
+ // -------------
+ // Tests encoder
+ // -------------
+
+ val builder = TestCompressibleColumnBuilder(columnStats, columnType, RunLengthEncoding)
+ val (values, rows) = makeUniqueValuesAndSingleValueRows(columnType, 2)
+
+ builder.initialize(0)
+ builder.appendFrom(rows(0), 0)
+ builder.appendFrom(rows(1), 0)
+
+ val buffer = builder.build()
+ val headerSize = CompressionScheme.columnHeaderSize(buffer)
+ // 4 bytes each run for run length
+ val compressedSize = values.map(columnType.actualSize(_) + 4).sum
+ // 4 bytes for compression scheme type ID
+ expectResult(headerSize + 4 + compressedSize, "Wrong buffer capacity")(buffer.capacity)
+
+ // Skips column header
+ buffer.position(headerSize)
+ expectResult(RunLengthEncoding.typeId, "Wrong compression scheme ID")(buffer.getInt())
+
+ Array(0, 1).foreach { i =>
+ expectResult(values(i), "Wrong column element value")(columnType.extract(buffer))
+ expectResult(1, "Wrong run length")(buffer.getInt())
+ }
+
+ // -------------
+ // Tests decoder
+ // -------------
+
+ // Rewinds, skips column header and 4 more bytes for compression scheme ID
+ buffer.rewind().position(headerSize + 4)
+
+ val decoder = new RunLengthEncoding.Decoder[T](buffer, columnType)
+
+ Array(0, 1).foreach { i =>
+ expectResult(values(i), "Wrong decoded value")(decoder.next())
+ }
+
+ assert(!decoder.hasNext)
+ }
+ }
+}
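The run-length layout decoded above is value-then-count pairs after the 4-byte compression scheme ID: each run stores the value in the column type's encoding followed by a 4-byte run length. A standalone INT round trip; the scheme ID value here is invented.

import java.nio.ByteBuffer

object RunLengthSketch extends App {
  val schemeId = 1                  // hypothetical RunLengthEncoding.typeId
  val runs = Seq((7, 2), (9, 2))    // encodes the column 7, 7, 9, 9

  val buffer = ByteBuffer.allocate(64)
  buffer.putInt(schemeId)
  runs.foreach { case (value, length) => buffer.putInt(value).putInt(length) }
  buffer.flip()

  buffer.getInt()                   // skip the scheme ID
  val decoded = runs.indices.flatMap { _ =>
    val value = buffer.getInt()
    val length = buffer.getInt()
    Seq.fill(length)(value)
  }
  println(decoded)                  // Vector(7, 7, 9, 9)
}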
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala
new file mode 100644
index 0000000000000..e0ec812863dcf
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.columnar.compression
+
+import org.apache.spark.sql.catalyst.types.NativeType
+import org.apache.spark.sql.columnar._
+
+class TestCompressibleColumnBuilder[T <: NativeType](
+ override val columnStats: NativeColumnStats[T],
+ override val columnType: NativeColumnType[T],
+ override val schemes: Seq[CompressionScheme])
+ extends NativeColumnBuilder(columnStats, columnType)
+ with NullableColumnBuilder
+ with CompressibleColumnBuilder[T] {
+
+ override protected def isWorthCompressing(encoder: Encoder) = true
+}
+
+object TestCompressibleColumnBuilder {
+ def apply[T <: NativeType](
+ columnStats: NativeColumnStats[T],
+ columnType: NativeColumnType[T],
+ scheme: CompressionScheme) = {
+
+ new TestCompressibleColumnBuilder(columnStats, columnType, Seq(scheme))
+ }
+}
+
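TestCompressibleColumnBuilder stacks NullableColumnBuilder and CompressibleColumnBuilder[T] over NativeColumnBuilder and overrides a single hook so every scheme is treated as worth compressing during tests. A tiny illustration of the stackable-trait linearization this relies on; the classes below are invented and unrelated to the real builders.

object StackedTraitsSketch extends App {
  abstract class BaseBuilder { def describe: String = "base builder" }
  trait Nullable extends BaseBuilder { override def describe = "nullable " + super.describe }
  trait Compressible extends BaseBuilder { override def describe = "compressible " + super.describe }

  val builder = new BaseBuilder with Nullable with Compressible
  // The right-most mixin runs first, then delegates left via super.
  println(builder.describe)   // compressible nullable base builder
}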
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
index 197b557cba5f4..46febbfad037d 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
@@ -188,7 +188,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
val hiveContext = self
override val strategies: Seq[Strategy] = Seq(
- TopK,
+ TakeOrdered,
ParquetOperations,
HiveTableScans,
DataSinks,
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
index 4f8353666a12b..29834a11f41dc 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
@@ -141,6 +141,13 @@ class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with Logging {
*/
override def registerTable(
databaseName: Option[String], tableName: String, plan: LogicalPlan): Unit = ???
+
+ /**
+ * UNIMPLEMENTED: It needs to be decided how we will persist in-memory tables to the metastore.
+   * For now, if this functionality is desired, mix in the in-memory [[OverrideCatalog]].
+ */
+ override def unregisterTable(
+ databaseName: Option[String], tableName: String): Unit = ???
}
object HiveMetastoreTypes extends RegexParsers {
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
index 490a592a588d0..b2b03bc790fcc 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -529,7 +529,7 @@ object HiveQl {
val withLimit =
limitClause.map(l => nodeToExpr(l.getChildren.head))
- .map(StopAfter(_, withSort))
+ .map(Limit(_, withSort))
.getOrElse(withSort)
// TOK_INSERT_INTO means to add files to the table.
@@ -602,7 +602,7 @@ object HiveQl {
case Token("TOK_TABLESPLITSAMPLE",
Token("TOK_ROWCOUNT", Nil) ::
Token(count, Nil) :: Nil) =>
- StopAfter(Literal(count.toInt), relation)
+ Limit(Literal(count.toInt), relation)
case Token("TOK_TABLESPLITSAMPLE",
Token("TOK_PERCENT", Nil) ::
Token(fraction, Nil) :: Nil) =>
diff --git a/sql/hive/src/test/resources/golden/alias.*-0-7bdb861d11e895aaea545810cdac316d b/sql/hive/src/test/resources/golden/alias.*-0-7bdb861d11e895aaea545810cdac316d
deleted file mode 100644
index 5f4de85940513..0000000000000
--- a/sql/hive/src/test/resources/golden/alias.*-0-7bdb861d11e895aaea545810cdac316d
+++ /dev/null
@@ -1 +0,0 @@
-0 val_0
\ No newline at end of file
diff --git a/sql/hive/src/test/resources/golden/alias.star-0-7bdb861d11e895aaea545810cdac316d b/sql/hive/src/test/resources/golden/alias.star-0-7bdb861d11e895aaea545810cdac316d
new file mode 100644
index 0000000000000..016f64cc26f2a
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/alias.star-0-7bdb861d11e895aaea545810cdac316d
@@ -0,0 +1 @@
+0 val_0
diff --git a/sql/hive/src/test/resources/golden/read from cached table-0-ce3797dc14a603cba2a5e58c8612de5b b/sql/hive/src/test/resources/golden/read from cached table-0-ce3797dc14a603cba2a5e58c8612de5b
new file mode 100644
index 0000000000000..60878ffb77064
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/read from cached table-0-ce3797dc14a603cba2a5e58c8612de5b
@@ -0,0 +1 @@
+238 val_238
diff --git a/sql/hive/src/test/resources/golden/read from uncached table-0-ce3797dc14a603cba2a5e58c8612de5b b/sql/hive/src/test/resources/golden/read from uncached table-0-ce3797dc14a603cba2a5e58c8612de5b
new file mode 100644
index 0000000000000..60878ffb77064
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/read from uncached table-0-ce3797dc14a603cba2a5e58c8612de5b
@@ -0,0 +1 @@
+238 val_238
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala
new file mode 100644
index 0000000000000..68d45e53cdf26
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive
+
+import org.apache.spark.sql.execution.SparkLogicalPlan
+import org.apache.spark.sql.columnar.InMemoryColumnarTableScan
+import org.apache.spark.sql.hive.execution.HiveComparisonTest
+
+class CachedTableSuite extends HiveComparisonTest {
+ TestHive.loadTestTable("src")
+
+ test("cache table") {
+ TestHive.cacheTable("src")
+ }
+
+ createQueryTest("read from cached table",
+ "SELECT * FROM src LIMIT 1")
+
+ test("check that table is cached and uncache") {
+ TestHive.table("src").queryExecution.analyzed match {
+ case SparkLogicalPlan(_ : InMemoryColumnarTableScan) => // Found evidence of caching
+ case noCache => fail(s"No cache node found in plan $noCache")
+ }
+ TestHive.uncacheTable("src")
+ }
+
+ createQueryTest("read from uncached table",
+ "SELECT * FROM src LIMIT 1")
+
+ test("make sure table is uncached") {
+ TestHive.table("src").queryExecution.analyzed match {
+ case cachePlan @ SparkLogicalPlan(_ : InMemoryColumnarTableScan) =>
+ fail(s"Table still cached after uncache: $cachePlan")
+ case noCache => // Table uncached successfully
+ }
+ }
+
+ test("correct error on uncache of non-cached table") {
+ intercept[IllegalArgumentException] {
+ TestHive.uncacheTable("src")
+ }
+ }
+}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala
index d77900ddc950c..40c4e23f90fb8 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala
@@ -48,7 +48,7 @@ class HiveResolutionSuite extends HiveComparisonTest {
createQueryTest("attr",
"SELECT key FROM src a ORDER BY key LIMIT 1")
- createQueryTest("alias.*",
+ createQueryTest("alias.star",
"SELECT a.* FROM src a ORDER BY key LIMIT 1")
test("case insensitivity with scala reflection") {