From 6c4b29571a9490b8667b7827776659d8e4c18866 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Thu, 31 May 2018 00:23:25 +0800 Subject: [PATCH 1/2] [SPARK-24369][SQL] Correct handling for multiple distinct aggregations having the same argument set ## What changes were proposed in this pull request? This pr fixed an issue when having multiple distinct aggregations having the same argument set, e.g., ``` scala>: paste val df = sql( s"""SELECT corr(DISTINCT x, y), corr(DISTINCT y, x), count(*) | FROM (VALUES (1, 1), (2, 2), (2, 2)) t(x, y) """.stripMargin) java.lang.RuntimeException You hit a query analyzer bug. Please report your query to Spark user mailing list. ``` The root cause is that `RewriteDistinctAggregates` can't detect multiple distinct aggregations if they have the same argument set. This pr modified code so that `RewriteDistinctAggregates` could count the number of aggregate expressions with `isDistinct=true`. ## How was this patch tested? Added tests in `DataFrameAggregateSuite`. Author: Takeshi Yamamuro Closes #21443 from maropu/SPARK-24369. --- .../optimizer/RewriteDistinctAggregates.scala | 7 ++++--- .../apache/spark/sql/execution/SparkStrategies.scala | 2 +- .../src/test/resources/sql-tests/inputs/group-by.sql | 6 +++++- .../test/resources/sql-tests/results/group-by.sql.out | 11 ++++++++++- 4 files changed, 20 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala index 4448ace7105a4..bc898ab0dc723 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala @@ -115,7 +115,8 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { } // Extract distinct aggregate expressions. - val distinctAggGroups = aggExpressions.filter(_.isDistinct).groupBy { e => + val distincgAggExpressions = aggExpressions.filter(_.isDistinct) + val distinctAggGroups = distincgAggExpressions.groupBy { e => val unfoldableChildren = e.aggregateFunction.children.filter(!_.foldable).toSet if (unfoldableChildren.nonEmpty) { // Only expand the unfoldable children @@ -132,7 +133,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { } // Aggregation strategy can handle queries with a single distinct group. - if (distinctAggGroups.size > 1) { + if (distincgAggExpressions.size > 1) { // Create the attributes for the grouping id and the group by clause. val gid = AttributeReference("gid", IntegerType, nullable = false)() val groupByMap = a.groupingExpressions.collect { @@ -151,7 +152,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { } // Setup unique distinct aggregate children. - val distinctAggChildren = distinctAggGroups.keySet.flatten.toSeq.distinct + val distinctAggChildren = distinctAggGroups.keySet.flatten.toSeq val distinctAggChildAttrMap = distinctAggChildren.map(expressionAttributePair) val distinctAggChildAttrs = distinctAggChildAttrMap.map(_._2) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index b97a87a122406..b9452b58657a4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -386,7 +386,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { aggregateExpressions.partition(_.isDistinct) if (functionsWithDistinct.map(_.aggregateFunction.children).distinct.length > 1) { // This is a sanity check. We should not reach here when we have multiple distinct - // column sets. Our MultipleDistinctRewriter should take care this case. + // column sets. Our `RewriteDistinctAggregates` should take care this case. sys.error("You hit a query analyzer bug. Please report your query to " + "Spark user mailing list.") } diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql index c5070b734d521..2c18d6aaabdba 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql @@ -68,4 +68,8 @@ SELECT 1 from ( FROM (select 1 as x) a WHERE false ) b -where b.z != b.z +where b.z != b.z; + +-- SPARK-24369 multiple distinct aggregations having the same argument set +SELECT corr(DISTINCT x, y), corr(DISTINCT y, x), count(*) + FROM (VALUES (1, 1), (2, 2), (2, 2)) t(x, y); diff --git a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out index c1abc6dff754b..581aa1754ce14 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 26 +-- Number of queries: 27 -- !query 0 @@ -241,3 +241,12 @@ where b.z != b.z struct<1:int> -- !query 25 output + + +-- !query 26 +SELECT corr(DISTINCT x, y), corr(DISTINCT y, x), count(*) + FROM (VALUES (1, 1), (2, 2), (2, 2)) t(x, y) +-- !query 26 schema +struct +-- !query 26 output +1.0 1.0 3 From 8386b4250d90eb369c85f02de7bbabe7a2ebbdaa Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Sat, 2 Jun 2018 18:41:11 -0700 Subject: [PATCH 2/2] another fix --- .../optimizer/RewriteDistinctAggregates.scala | 7 +++---- .../spark/sql/execution/SparkStrategies.scala | 2 +- .../spark/sql/execution/PlannerSuite.scala | 21 +++++++++++++++++++ 3 files changed, 25 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala index bc898ab0dc723..4448ace7105a4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala @@ -115,8 +115,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { } // Extract distinct aggregate expressions. - val distincgAggExpressions = aggExpressions.filter(_.isDistinct) - val distinctAggGroups = distincgAggExpressions.groupBy { e => + val distinctAggGroups = aggExpressions.filter(_.isDistinct).groupBy { e => val unfoldableChildren = e.aggregateFunction.children.filter(!_.foldable).toSet if (unfoldableChildren.nonEmpty) { // Only expand the unfoldable children @@ -133,7 +132,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { } // Aggregation strategy can handle queries with a single distinct group. - if (distincgAggExpressions.size > 1) { + if (distinctAggGroups.size > 1) { // Create the attributes for the grouping id and the group by clause. val gid = AttributeReference("gid", IntegerType, nullable = false)() val groupByMap = a.groupingExpressions.collect { @@ -152,7 +151,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { } // Setup unique distinct aggregate children. - val distinctAggChildren = distinctAggGroups.keySet.flatten.toSeq + val distinctAggChildren = distinctAggGroups.keySet.flatten.toSeq.distinct val distinctAggChildAttrMap = distinctAggChildren.map(expressionAttributePair) val distinctAggChildAttrs = distinctAggChildAttrMap.map(_._2) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index b9452b58657a4..be34387f6a874 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -384,7 +384,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { val (functionsWithDistinct, functionsWithoutDistinct) = aggregateExpressions.partition(_.isDistinct) - if (functionsWithDistinct.map(_.aggregateFunction.children).distinct.length > 1) { + if (functionsWithDistinct.map(_.aggregateFunction.children.toSet).distinct.length > 1) { // This is a sanity check. We should not reach here when we have multiple distinct // column sets. Our `RewriteDistinctAggregates` should take care this case. sys.error("You hit a query analyzer bug. Please report your query to " + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index b2aba8e72c5db..98a50fbd52b4d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -69,6 +69,27 @@ class PlannerSuite extends SharedSQLContext { testPartialAggregationPlan(query) } + test("mixed aggregates with same distinct columns") { + def assertNoExpand(plan: SparkPlan): Unit = { + assert(plan.collect { case e: ExpandExec => e }.isEmpty) + } + + withTempView("v") { + Seq((1, 1.0, 1.0), (1, 2.0, 2.0)).toDF("i", "j", "k").createTempView("v") + // one distinct column + val query1 = sql("SELECT sum(DISTINCT j), max(DISTINCT j) FROM v GROUP BY i") + assertNoExpand(query1.queryExecution.executedPlan) + + // 2 distinct columns + val query2 = sql("SELECT corr(DISTINCT j, k), count(DISTINCT j, k) FROM v GROUP BY i") + assertNoExpand(query2.queryExecution.executedPlan) + + // 2 distinct columns with different order + val query3 = sql("SELECT corr(DISTINCT j, k), count(DISTINCT k, j) FROM v GROUP BY i") + assertNoExpand(query3.queryExecution.executedPlan) + } + } + test("sizeInBytes estimation of limit operator for broadcast hash join optimization") { def checkPlan(fieldTypes: Seq[DataType]): Unit = { withTempView("testLimit") {