From f9056ac510868cb3dc878b34e9259155c7aa9d88 Mon Sep 17 00:00:00 2001
From: azagrebin
Date: Wed, 18 Mar 2015 10:47:20 +0100
Subject: [PATCH] [SPARK-6117] [SQL] create one aggregation and split it
 locally into resulting DF, colocate test data with test case

---
 .../org/apache/spark/sql/DataFrame.scala      | 50 ++++++++++---------
 .../org/apache/spark/sql/DataFrameSuite.scala | 21 ++++++++
 .../scala/org/apache/spark/sql/TestData.scala | 19 -------
 3 files changed, 47 insertions(+), 43 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index 1746cf0b27179..10bcd7a3f1713 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -41,7 +41,7 @@ import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, LogicalRDD}
 import org.apache.spark.sql.jdbc.JDBCWriteDetails
 import org.apache.spark.sql.json.JsonRDD
-import org.apache.spark.sql.types.{NumericType, StructType}
+import org.apache.spark.sql.types.{NumericType, StructType, StructField, StringType}
 import org.apache.spark.sql.sources.{ResolvedDataSource, CreateTableUsingAsSelect}
 import org.apache.spark.util.Utils
 
@@ -772,32 +772,34 @@ class DataFrame private[sql](
   @scala.annotation.varargs
   def describe(cols: String*): DataFrame = {
 
-    def aggCol(name: String = "") = s"'$name' as summary"
+    def stddevExpr(expr: Expression) =
+      Sqrt(Subtract(Average(Multiply(expr, expr)), Multiply(Average(expr), Average(expr))))
+
     val statistics = List[(String, Expression => Expression)](
-      "count" -> (expr => Count(expr)),
-      "mean" -> (expr => Average(expr)),
-      "stddev" -> (expr => Sqrt(Subtract(Average(Multiply(expr, expr)),
-        Multiply(Average(expr), Average(expr))))),
-      "min" -> (expr => Min(expr)),
-      "max" -> (expr => Max(expr)))
-
-    val numCols = if (cols.isEmpty) numericColumns.map(_.prettyString) else cols
-
-    // union all statistics starting from empty one
-    var description = selectExpr(aggCol()::numCols.toList:_*).limit(0)
-    for ((name, colToAgg) <- statistics) {
-      // generate next statistic aggregation
-      val nextAgg = if (numCols.nonEmpty) {
-        val aggCols = numCols.map(c => Column(colToAgg(Column(c).expr)).as(c))
-        agg(aggCols.head, aggCols.tail:_*)
-      } else {
-        sqlContext.emptyDataFrame
+      "count" -> Count,
+      "mean" -> Average,
+      "stddev" -> stddevExpr,
+      "min" -> Min,
+      "max" -> Max)
+
+    val aggCols = (if (cols.isEmpty) numericColumns.map(_.prettyString) else cols).toList
+
+    val localAgg = if (aggCols.nonEmpty) {
+      val aggExprs = statistics.flatMap { case (_, colToAgg) =>
+        aggCols.map(c => Column(colToAgg(Column(c).expr)).as(c))
       }
-      // add statistic name column
-      val nextStat = nextAgg.selectExpr(aggCol(name)::numCols.toList:_*)
-      description = description.unionAll(nextStat)
+
+      agg(aggExprs.head, aggExprs.tail: _*).head().toSeq
+        .grouped(aggCols.size).toSeq.zip(statistics).map { case (aggregation, (statistic, _)) =>
+          Row(statistic :: aggregation.toList: _*)
+        }
+    } else {
+      statistics.map { case (name, _) => Row(name) }
     }
-    description
+
+    val schema = StructType(("summary" :: aggCols).map(StructField(_, StringType)))
+    val rowRdd = sqlContext.sparkContext.parallelize(localAgg)
+    sqlContext.createDataFrame(rowRdd, schema)
   }
 
   /**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 0f37664ce1b06..ab9d1b93d05dc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -437,6 +437,27 @@ class DataFrameSuite extends QueryTest {
   }
 
   test("describe") {
+
+    val describeTestData = Seq(
+      ("Bob", 16, 176),
+      ("Alice", 32, 164),
+      ("David", 60, 192),
+      ("Amy", 24, 180)).toDF("name", "age", "height")
+
+    val describeResult = Seq(
+      Row("count", 4, 4),
+      Row("mean", 33.0, 178.0),
+      Row("stddev", 16.583123951777, 10.0),
+      Row("min", 16, 164),
+      Row("max", 60, 192))
+
+    val emptyDescribeResult = Seq(
+      Row("count", 0, 0),
+      Row("mean", null, null),
+      Row("stddev", null, null),
+      Row("min", null, null),
+      Row("max", null, null))
+
     def getSchemaAsSeq(df: DataFrame) = df.schema.map(_.name).toSeq
 
     val describeTwoCols = describeTestData.describe("age", "height")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
index e4446cd5e0818..637f59b2e68ca 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
@@ -199,25 +199,6 @@ object TestData {
     Salary(1, 1000.0) :: Nil).toDF()
   salary.registerTempTable("salary")
 
-  case class PersonToDescribe(name: String, age: Int, height: Double)
-  val describeTestData = TestSQLContext.sparkContext.parallelize(
-    PersonToDescribe("Bob", 16, 176) ::
-    PersonToDescribe("Alice", 32, 164) ::
-    PersonToDescribe("David", 60, 192) ::
-    PersonToDescribe("Amy", 24, 180) :: Nil).toDF()
-  val describeResult =
-    Row("count", 4.0, 4.0) ::
-    Row("mean", 33.0, 178.0) ::
-    Row("stddev", 16.583123951777, 10.0) ::
-    Row("min", 16.0, 164.0) ::
-    Row("max", 60.0, 192.0) :: Nil
-  val emptyDescribeResult =
-    Row("count", 0, 0) ::
-    Row("mean", null, null) ::
-    Row("stddev", null, null) ::
-    Row("min", null, null) ::
-    Row("max", null, null) :: Nil
-
   case class ComplexData(m: Map[Int, String], s: TestData, a: Seq[Int], b: Boolean)
   val complexData = TestSQLContext.sparkContext.parallelize(
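A minimal caller-side sketch of what the reworked describe() produces, shown
with the same rows as the test data above. The implicits import is an
assumption mirroring DataFrameSuite's setup, and the commented output is the
expected content of describeResult, not captured console output:

  import org.apache.spark.sql.test.TestSQLContext.implicits._

  val df = Seq(
    ("Bob", 16, 176),
    ("Alice", 32, 164),
    ("David", 60, 192),
    ("Amy", 24, 180)).toDF("name", "age", "height")

  // After this patch, one aggregation job computes all five statistics for
  // every numeric column at once; the single result Row is then split
  // locally into the five summary rows (all StringType), instead of
  // unioning one DataFrame per statistic as before.
  df.describe("age", "height").show()
  // summary  age              height
  // count    4                4
  // mean     33.0             178.0
  // stddev   16.583123951777  10.0
  // min      16               164
  // max      60               192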