Skip to content

Commit

Permalink
[SPARK-6117] [SQL] create one aggregation and split it locally into resulting DF, colocate test data with test case
Browse files Browse the repository at this point in the history
  • Loading branch information
azagrebin committed Mar 18, 2015
1 parent ddb3950 commit f9056ac
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 43 deletions.
50 changes: 26 additions & 24 deletions sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, LogicalRDD}
import org.apache.spark.sql.jdbc.JDBCWriteDetails
import org.apache.spark.sql.json.JsonRDD
import org.apache.spark.sql.types.{NumericType, StructType}
import org.apache.spark.sql.types.{NumericType, StructType, StructField, StringType}
import org.apache.spark.sql.sources.{ResolvedDataSource, CreateTableUsingAsSelect}
import org.apache.spark.util.Utils

Expand Down Expand Up @@ -772,32 +772,34 @@ class DataFrame private[sql](
@scala.annotation.varargs
def describe(cols: String*): DataFrame = {

  // Population standard deviation via the single-pass identity
  // sqrt(E[X^2] - E[X]^2); numerically naive but expressible as one
  // aggregation pass over the data.
  def stddevExpr(expr: Expression) =
    Sqrt(Subtract(Average(Multiply(expr, expr)), Multiply(Average(expr), Average(expr))))

  // Statistic name -> aggregate-expression builder, in output row order.
  // Count/Average/Min/Max companions eta-expand to Expression => Expression.
  val statistics = List[(String, Expression => Expression)](
    "count" -> Count,
    "mean" -> Average,
    "stddev" -> stddevExpr,
    "min" -> Min,
    "max" -> Max)

  // Default to all numeric columns when the caller names none.
  val aggCols = (if (cols.isEmpty) numericColumns.map(_.prettyString) else cols).toList

  val localAgg = if (aggCols.nonEmpty) {
    // Build ONE wide aggregation computing every (statistic, column) pair,
    // then split the single result row locally into one Row per statistic.
    // Order matters: flatMap over statistics first, so the flat result row
    // chunks by aggCols.size back into per-statistic groups.
    val aggExprs = statistics.flatMap { case (_, colToAgg) =>
      aggCols.map(c => Column(colToAgg(Column(c).expr)).as(c))
    }

    agg(aggExprs.head, aggExprs.tail: _*).head().toSeq
      .grouped(aggCols.size).toSeq.zip(statistics).map { case (aggregation, (statistic, _)) =>
        Row(statistic :: aggregation.toList: _*)
      }
  } else {
    // No columns to aggregate: emit rows carrying only the statistic names.
    statistics.map { case (name, _) => Row(name) }
  }

  // NOTE(review): the schema declares every column StringType while the rows
  // hold the aggregates' native numeric values — confirm downstream rendering
  // tolerates this mismatch.
  val schema = StructType(("summary" :: aggCols).map(StructField(_, StringType)))
  val rowRdd = sqlContext.sparkContext.parallelize(localAgg)
  sqlContext.createDataFrame(rowRdd, schema)
}

/**
Expand Down
21 changes: 21 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,27 @@ class DataFrameSuite extends QueryTest {
}

test("describe") {

val describeTestData = Seq(
("Bob", 16, 176),
("Alice", 32, 164),
("David", 60, 192),
("Amy", 24, 180)).toDF("name", "age", "height")

val describeResult = Seq(
Row("count", 4, 4),
Row("mean", 33.0, 178.0),
Row("stddev", 16.583123951777, 10.0),
Row("min", 16, 164),
Row("max", 60, 192))

val emptyDescribeResult = Seq(
Row("count", 0, 0),
Row("mean", null, null),
Row("stddev", null, null),
Row("min", null, null),
Row("max", null, null))

def getSchemaAsSeq(df: DataFrame) = df.schema.map(_.name).toSeq

val describeTwoCols = describeTestData.describe("age", "height")
Expand Down
19 changes: 0 additions & 19 deletions sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
Original file line number Diff line number Diff line change
Expand Up @@ -199,25 +199,6 @@ object TestData {
Salary(1, 1000.0) :: Nil).toDF()
salary.registerTempTable("salary")

case class PersonToDescribe(name: String, age: Int, height: Double)
val describeTestData = TestSQLContext.sparkContext.parallelize(
PersonToDescribe("Bob", 16, 176) ::
PersonToDescribe("Alice", 32, 164) ::
PersonToDescribe("David", 60, 192) ::
PersonToDescribe("Amy", 24, 180) :: Nil).toDF()
val describeResult =
Row("count", 4.0, 4.0) ::
Row("mean", 33.0, 178.0) ::
Row("stddev", 16.583123951777, 10.0) ::
Row("min", 16.0, 164.0) ::
Row("max", 60.0, 192.0) :: Nil
val emptyDescribeResult =
Row("count", 0, 0) ::
Row("mean", null, null) ::
Row("stddev", null, null) ::
Row("min", null, null) ::
Row("max", null, null) :: Nil

case class ComplexData(m: Map[Int, String], s: TestData, a: Seq[Int], b: Boolean)
val complexData =
TestSQLContext.sparkContext.parallelize(
Expand Down

0 comments on commit f9056ac

Please sign in to comment.