From ee9c432a1bde3cb4e8be090c6fd1bae3e41e9e98 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Mon, 29 Feb 2016 02:52:43 +0800 Subject: [PATCH] Migrates basic inspection and typed relational DF operations to DS --- .../sql/catalyst/analysis/Analyzer.scala | 8 +- .../scala/org/apache/spark/sql/Dataset.scala | 210 +++++++++++++++++- .../org/apache/spark/sql/DataFrameSuite.scala | 4 +- .../org/apache/spark/sql/DatasetSuite.scala | 199 +++++++++++++++++ 4 files changed, 410 insertions(+), 11 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 23e4709bbd882..c64f7ff6e5ad2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.analysis +import java.lang.reflect.Modifier + import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.AnalysisException @@ -559,7 +561,11 @@ class Analyzer( } resolveExpression(unbound, LocalRelation(attributes), throws = true) transform { - case n: NewInstance if n.outerPointer.isEmpty && n.cls.isMemberClass => + case n: NewInstance + if n.outerPointer.isEmpty && + n.cls.isMemberClass && + !Modifier.isStatic(n.cls.getModifiers) => + n.cls.getEnclosingClass val outer = OuterScopes.outerScopes.get(n.cls.getDeclaringClass.getName) if (outer == null) { throw new AnalysisException( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 233229150ece9..2d7de78f09dd2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -19,9 +19,9 @@ package org.apache.spark.sql import scala.collection.JavaConverters._ +import org.apache.spark.Logging import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.function._ -import org.apache.spark.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.analysis.{ResolvedStar, UnresolvedAlias} import org.apache.spark.sql.catalyst.encoders._ @@ -184,13 +184,6 @@ class Dataset[T] private[sql]( new Dataset(sqlContext, queryExecution, encoderFor[U]) } - /** - * Applies a logical alias to this [[Dataset]] that can be used to disambiguate columns that have - * the same name after two Datasets have been joined. - * @since 1.6.0 - */ - def as(alias: String): Dataset[T] = withPlan(SubqueryAlias(alias, _)) - /** * Converts this strongly typed collection of data to generic Dataframe. In contrast to the * strongly typed objects that Dataset operations work on, a Dataframe returns generic [[Row]] @@ -320,6 +313,32 @@ class Dataset[T] private[sql]( Repartition(numPartitions, shuffle = true, _) } + /** + * Returns a new [[Dataset]] partitioned by the given partitioning expressions into + * `numPartitions`. The resulting DataFrame is hash partitioned. + * + * This is the same operation as "DISTRIBUTE BY" in SQL (Hive QL). + * + * @since 2.0.0 + */ + @scala.annotation.varargs + def repartition(numPartitions: Int, partitionExprs: Column*): Dataset[T] = withPlan { + RepartitionByExpression(partitionExprs.map(_.expr), _, Some(numPartitions)) + } + + /** + * Returns a new [[Dataset]] partitioned by the given partitioning expressions preserving + * the existing number of partitions. 
The resulting DataFrame is hash partitioned. + * + * This is the same operation as "DISTRIBUTE BY" in SQL (Hive QL). + * + * @since 2.0.0 + */ + @scala.annotation.varargs + def repartition(partitionExprs: Column*): Dataset[T] = withPlan { + RepartitionByExpression(partitionExprs.map(_.expr), _, numPartitions = None) + } + /** * Returns a new [[Dataset]] that has exactly `numPartitions` partitions. * Similar to coalesce defined on an [[RDD]], this operation results in a narrow dependency, e.g. @@ -522,6 +541,13 @@ class Dataset[T] private[sql]( * Typed Relational * * ****************** */ + /** + * Applies a logical alias to this [[Dataset]] that can be used to disambiguate columns that have + * the same name after two Datasets have been joined. + * @since 1.6.0 + */ + def as(alias: String): Dataset[T] = withPlan(SubqueryAlias(alias, _)) + /** * Returns a new [[DataFrame]] by selecting a set of column based expressions. * {{{ @@ -621,6 +647,162 @@ class Dataset[T] private[sql]( sample(withReplacement, fraction, Utils.random.nextLong) } + /** + * Filters rows using the given condition. + * {{{ + * // The following are equivalent: + * peopleDs.filter($"age" > 15) + * peopleDs.where($"age" > 15) + * }}} + * @since 2.0.0 + */ + def filter(condition: Column): Dataset[T] = withPlan(Filter(condition.expr, _)) + + /** + * Filters rows using the given SQL expression. + * {{{ + * peopleDs.filter("age > 15") + * }}} + * @since 2.0.0 + */ + def filter(conditionExpr: String): Dataset[T] = { + filter(Column(sqlContext.sqlParser.parseExpression(conditionExpr))) + } + + /** + * Filters rows using the given condition. This is an alias for `filter`. + * {{{ + * // The following are equivalent: + * peopleDs.filter($"age" > 15) + * peopleDs.where($"age" > 15) + * }}} + * @since 2.0.0 + */ + def where(condition: Column): Dataset[T] = filter(condition) + + /** + * Filters rows using the given SQL expression. + * {{{ + * peopleDs.where("age > 15") + * }}} + * @since 2.0.0 + */ + def where(conditionExpr: String): Dataset[T] = filter(conditionExpr) + + /** + * Returns a new [[Dataset]] by taking the first `n` rows. The difference between this function + * and `head` is that `head` returns an array while `limit` returns a new [[Dataset]]. + * @since 2.0.0 + */ + def limit(n: Int): Dataset[T] = withPlan(Limit(Literal(n), _)) + + /** + * Returns a new [[Dataset]] with each partition sorted by the given expressions. + * + * This is the same operation as "SORT BY" in SQL (Hive QL). + * + * @since 2.0.0 + */ + @scala.annotation.varargs + def sortWithinPartitions(sortCol: String, sortCols: String*): Dataset[T] = { + sortWithinPartitions((sortCol +: sortCols).map(Column(_)) : _*) + } + + /** + * Returns a new [[Dataset]] with each partition sorted by the given expressions. + * + * This is the same operation as "SORT BY" in SQL (Hive QL). + * + * @since 2.0.0 + */ + @scala.annotation.varargs + def sortWithinPartitions(sortExprs: Column*): Dataset[T] = { + sortInternal(global = false, sortExprs) + } + + /** + * Returns a new [[Dataset]] sorted by the specified column, all in ascending order. + * {{{ + * // The following 3 are equivalent + * ds.sort("sortcol") + * ds.sort($"sortcol") + * ds.sort($"sortcol".asc) + * }}} + * @since 2.0.0 + */ + @scala.annotation.varargs + def sort(sortCol: String, sortCols: String*): Dataset[T] = { + sort((sortCol +: sortCols).map(apply) : _*) + } + + /** + * Returns a new [[Dataset]] sorted by the given expressions. 
For example: + * {{{ + * ds.sort($"col1", $"col2".desc) + * }}} + * @since 2.0.0 + */ + @scala.annotation.varargs + def sort(sortExprs: Column*): Dataset[T] = { + sortInternal(global = true, sortExprs) + } + + /** + * Returns a new [[Dataset]] sorted by the given expressions. + * This is an alias of the `sort` function. + * @since 2.0.0 + */ + @scala.annotation.varargs + def orderBy(sortCol: String, sortCols: String*): Dataset[T] = sort(sortCol, sortCols : _*) + + /** + * Returns a new [[Dataset]] sorted by the given expressions. + * This is an alias of the `sort` function. + * @since 2.0.0 + */ + @scala.annotation.varargs + def orderBy(sortExprs: Column*): Dataset[T] = sort(sortExprs : _*) + + /** + * Randomly splits this [[Dataset]] with the provided weights. + * + * @param weights weights for splits, will be normalized if they don't sum to 1. + * @param seed Seed for sampling. + * @since 2.0.0 + */ + def randomSplit(weights: Array[Double], seed: Long): Array[Dataset[T]] = { + // It is possible that the underlying Dataset doesn't guarantee the ordering of rows in its + // constituent partitions each time a split is materialized which could result in + // overlapping splits. To prevent this, we explicitly sort each input partition to make the + // ordering deterministic. + val sorted = Sort(logicalPlan.output.map(SortOrder(_, Ascending)), global = false, logicalPlan) + val sum = weights.sum + val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _) + normalizedCumWeights.sliding(2).map { x => + new Dataset(sqlContext, Sample(x(0), x(1), withReplacement = false, seed, sorted)()) + }.toArray + } + + /** + * Randomly splits this [[Dataset]] with the provided weights. + * + * @param weights weights for splits, will be normalized if they don't sum to 1. + * @since 2.0.0 + */ + def randomSplit(weights: Array[Double]): Array[Dataset[T]] = { + randomSplit(weights, Utils.random.nextLong) + } + + /** + * Randomly splits this [[Dataset]] with the provided weights. Provided for the Python Api. + * + * @param weights weights for splits, will be normalized if they don't sum to 1. + * @param seed Seed for sampling. 
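+ * For example, weights `List(1.0, 2.0, 1.0)` are normalized to the cumulative bounds
+ * `[0.0, 0.25, 0.75, 1.0]`, and each split samples the rows whose random draw falls between
+ * two consecutive bounds (see the `Array[Double]` overload above).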
+ */ + private[spark] def randomSplit(weights: List[Double], seed: Long): Array[Dataset[T]] = { + randomSplit(weights.toArray, seed) + } + /* **************** * * Set operations * * **************** */ @@ -845,4 +1027,16 @@ class Dataset[T] private[sql]( s"""Cannot resolve column name "$colName" among (${schema.fieldNames.mkString(", ")})""") } } + + private def sortInternal(global: Boolean, sortExprs: Seq[Column]): Dataset[T] = { + val sortOrder: Seq[SortOrder] = sortExprs.map { col => + col.expr match { + case expr: SortOrder => + expr + case expr: Expression => + SortOrder(expr, Ascending) + } + } + withPlan(Sort(sortOrder, global = global, _)) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 84f30c0aaf862..2a0135c2ea502 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -1154,11 +1154,11 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(original.rdd.partitions.length == 1) val df = original.repartition(5, $"key") assert(df.rdd.partitions.length == 5) - checkAnswer(original.select(), df.select()) + checkAnswer(original, df) val df2 = original.repartition(10, $"key") assert(df2.rdd.partitions.length == 10) - checkAnswer(original.select(), df2.select()) + checkAnswer(original, df2) // Group by the column we are distributed by. This should generate a plan with no exchange // between the aggregates diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index fca4b9c3eac2a..f59810c5388b7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -23,8 +23,11 @@ import java.sql.{Date, Timestamp} import scala.language.postfixOps import org.apache.spark.sql.catalyst.encoders.OuterScopes +import org.apache.spark.sql.execution.aggregate.TungstenAggregate +import org.apache.spark.sql.execution.exchange.ShuffleExchange import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.test.SQLTestData.{ComplexData, TestData, TestData2, UpperCaseData} import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} case class OtherTuple(_1: String, _2: Int) @@ -32,6 +35,38 @@ case class OtherTuple(_1: String, _2: Int) class DatasetSuite extends QueryTest with SharedSQLContext { import testImplicits._ + /** + * Verifies that there is no Exchange between the Aggregations for `df` + */ + private def verifyNonExchangingAgg[T: Encoder](ds: Dataset[T]) = { + var atFirstAgg: Boolean = false + ds.queryExecution.executedPlan.foreach { + case agg: TungstenAggregate => + atFirstAgg = !atFirstAgg + case _ => + if (atFirstAgg) { + fail("Should not have operators between the two aggregations") + } + } + } + + /** + * Verifies that there is an Exchange between the Aggregations for `df` + */ + private def verifyExchangingAgg[T: Encoder](ds: Dataset[T]) = { + var atFirstAgg: Boolean = false + ds.queryExecution.executedPlan.foreach { + case agg: TungstenAggregate => { + if (atFirstAgg) { + fail("Should not have back to back Aggregates") + } + atFirstAgg = true + } + case e: ShuffleExchange => atFirstAgg = false + case _ => + } + } + test("toDS") { val data = Seq(("a", 1), ("b", 2), ("c", 3)) checkAnswer( @@ -630,6 +665,170 @@ class 
DatasetSuite extends QueryTest with SharedSQLContext { df.col("_1") df.col("t.`_1`") } + + protected lazy val complexDataDS: Dataset[ComplexData] = + sqlContext.sparkContext.parallelize( + ComplexData(Map("1" -> 1), TestData(1, "1"), Seq(1, 1, 1), true) :: + ComplexData(Map("2" -> 2), TestData(2, "2"), Seq(2, 2, 2), false) :: + Nil).toDS() + + protected lazy val upperCaseDataDS: Dataset[UpperCaseData] = + sqlContext.sparkContext.parallelize( + UpperCaseData(1, "A") :: + UpperCaseData(2, "B") :: + UpperCaseData(3, "C") :: + UpperCaseData(4, "D") :: + UpperCaseData(5, "E") :: + UpperCaseData(6, "F") :: Nil).toDS() + + protected lazy val testDataDS: Dataset[TestData] = + sqlContext.sparkContext.parallelize( + (1 to 100).map(i => TestData(i, i.toString))).toDS() + + test("df-to-ds: access complex data") { + assert(complexDataDS.filter(complexDataDS("a").getItem(0) === 2).count() == 1) + assert(complexDataDS.filter(complexDataDS("m").getItem("1") === 1).count() == 1) + assert(complexDataDS.filter(complexDataDS("s").getField("key") === 1).count() == 1) + } + + test("df-to-ds: SPARK-7133: Implement struct, array, and map field accessor") { + assert(complexDataDS.filter(complexDataDS("a")(0) === 2).count() == 1) + assert(complexDataDS.filter(complexDataDS("m")("1") === 1).count() == 1) + assert(complexDataDS.filter(complexDataDS("s")("key") === 1).count() == 1) + assert(complexDataDS.filter(complexDataDS("m")(complexDataDS("s")("value")) === 1).count() == 1) + assert(complexDataDS.filter(complexDataDS("a")(complexDataDS("s")("key")) === 1).count() == 1) + } + + test("df-to-ds: Sorting columns are not in Filter and Project") { + checkAnswer( + upperCaseDataDS.filter('N > 1).select('N.as[Int]).filter('N < 6).orderBy('L.asc), + 2, 3, 4, 5) + } + + test("df-to-ds: filterExpr") { + val res = testDataDS.collect().filter(_.key > 90).toSeq + checkAnswer(testDataDS.filter("key > 90"), res: _*) + checkAnswer(testDataDS.filter("key > 9.0e1"), res: _*) + checkAnswer(testDataDS.filter("key > .9e+2"), res: _*) + checkAnswer(testDataDS.filter("key > 0.9e+2"), res: _*) + checkAnswer(testDataDS.filter("key > 900e-1"), res: _*) + checkAnswer(testDataDS.filter("key > 900.0E-1"), res: _*) + checkAnswer(testDataDS.filter("key > 9.e+1"), res: _*) + } + + test("df-to-ds: filterExpr using where") { + checkAnswer( + testDataDS.where("key > 50"), + testDataDS.collect().filter(_.key > 50).toSeq: _*) + } + + test("df-to-ds: limit") { + checkAnswer( + testDataDS.limit(10), + testDataDS.take(10).toSeq: _*) + + checkAnswer( + arrayData.toDS().limit(1), + arrayData.take(1): _*) + + checkAnswer( + mapData.toDS().limit(1), + mapData.take(1): _*) + + // SPARK-12340: overstep the bounds of Int in SparkPlan.executeTake + checkAnswer( + (0 until 2).toDS().limit(2147483638), + 0, 1 + ) + } + + test("df-to-ds: repartition by key") { + val ds1 = testDataDS.repartition(1) + assert(ds1.rdd.partitions.length == 1) + + val ds2 = ds1.repartition(5, $"key") + assert(ds2.rdd.partitions.length == 5) + + checkAnswer(ds2, ds1.collect(): _*) + } + + test("df-to-ds: group by after distributed by") { + // Group by the column we are distributed by. This should generate a plan with no exchange + // between the aggregates + verifyNonExchangingAgg( + testDataDS + .repartition($"key") + .groupBy(_.key) + .count()) + + verifyNonExchangingAgg( + testDataDS + .repartition($"key", $"value") + .groupBy(data => data.key -> data.value) + .count()) + + // Grouping by just the first distributeBy expr, need to exchange. 
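+ // Hash partitioning on (key, value) does not co-locate all rows that share only the key,
+ // so the planner has to add a shuffle in front of the aggregate here.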
+ verifyExchangingAgg( + testDataDS + .repartition($"key", $"value") + .groupBy(_.key) + .count()) + } + + test("df-to-ds: distribute and order by") { + val ds1 = (1 to 100).map(i => TestData2(i % 10, i)).toDS() + val ds2 = ds1.repartition(2, $"a").sortWithinPartitions($"b".desc) + + assert(ds2.rdd.partitions.length == 2) + + ds2.rdd.collectPartitions().foreach { data => + assert(data.sliding(2).forall { + case Array(TestData2(_, b1), TestData2(_, b2)) => b1 > b2 + }, "Partition is not ordered") + } + + assert(ds2.collect().sliding(2).exists { + case Array(TestData2(_, b1), TestData2(_, b2)) => b1 < b2 + }, "Dataset should not be globally ordered") + } + + test("df-to-ds: distribute and order by with multiple sort orders") { + // For comparing tuples + import scala.math.Ordering.Implicits._ + + val ds1 = (1 to 100).map(i => TestData2(i % 10, i)).toDS() + val ds2 = ds1.repartition(2, $"a").sortWithinPartitions($"b".desc, $"a".asc) + + assert(ds2.rdd.partitions.length == 2) + + ds2.rdd.collectPartitions().foreach { data => + assert(data.sliding(2).forall { + case Array(TestData2(a1, b1), TestData2(a2, b2)) => + (b1, a1) > (b2, a2) + }, "Partition is not ordered") + } + + assert(ds2.collect().sliding(2).exists { + case Array(TestData2(a1, b1), TestData2(a2, b2)) => (b1, a1) < (b2, a2) + }, "Dataset should not be globally ordered") + } + + test("df-to-ds: distribute into one partition and order by") { + val ds1 = (1 to 100).map(i => TestData2(i % 10, i)).toDS() + val ds2 = ds1.repartition(1, $"a").sortWithinPartitions("b") + + assert(ds2.rdd.partitions.length == 1) + + ds2.rdd.collectPartitions().foreach { data => + assert(data.sliding(2).forall { + case Array(TestData2(a1, b1), TestData2(a2, b2)) => b1 < b2 + }, "Partition is not ordered") + } + + assert(ds2.collect().sliding(2).forall { + case Array(TestData2(a1, b1), TestData2(a2, b2)) => b1 < b2 + }, "Dataset should be globally ordered") + } } class OuterClass extends Serializable {