From e544ca3649ed6c31abdbd46eab9937adde1025b9 Mon Sep 17 00:00:00 2001 From: ulysses Date: Fri, 17 Jul 2020 13:41:33 +0800 Subject: [PATCH 1/4] Treat CountIf and MaxMinBy as order-irrelevant aggregates in EliminateSorts --- .../org/apache/spark/sql/catalyst/optimizer/Optimizer.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 33da482c4eea4..49f913dc8fee6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -1004,7 +1004,7 @@ object EliminateSorts extends Rule[LogicalPlan] { private def isOrderIrrelevantAggs(aggs: Seq[NamedExpression]): Boolean = { def isOrderIrrelevantAggFunction(func: AggregateFunction): Boolean = func match { - case _: Min | _: Max | _: Count => true + case _: Min | _: Max | _: Count | _: CountIf | _: MaxMinBy => true // Arithmetic operations for floating-point values are order-sensitive // (they are not associative). 
case _: Sum | _: Average | _: CentralMomentAgg => From bc454b67f8879a596b90c7d91e0ac3c70a8d0a04 Mon Sep 17 00:00:00 2001 From: ulysses Date: Sat, 18 Jul 2020 17:19:16 +0800 Subject: [PATCH 2/4] Stop treating CountIf as an order-irrelevant aggregate in EliminateSorts --- .../org/apache/spark/sql/catalyst/optimizer/Optimizer.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 49f913dc8fee6..e39bf77cb3044 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -1004,7 +1004,7 @@ object EliminateSorts extends Rule[LogicalPlan] { private def isOrderIrrelevantAggs(aggs: Seq[NamedExpression]): Boolean = { def isOrderIrrelevantAggFunction(func: AggregateFunction): Boolean = func match { - case _: Min | _: Max | _: Count | _: CountIf | _: MaxMinBy => true + case _: Min | _: Max | _: Count | _: MaxMinBy => true // Arithmetic operations for floating-point values are order-sensitive // (they are not associative). 
case _: Sum | _: Average | _: CentralMomentAgg => From 8aee8a58efc3498ce638843ede13c7bde8ee3092 Mon Sep 17 00:00:00 2001 From: ulysses Date: Mon, 20 Jul 2020 08:32:13 +0800 Subject: [PATCH 3/4] Add EliminateSortsSuite test for MaxBy/MinBy sort elimination --- .../sql/catalyst/optimizer/EliminateSortsSuite.scala | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala index e2b599a7c090c..cda7a2b0495bd 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.{MaxBy, MinBy} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ @@ -319,4 +320,15 @@ class EliminateSortsSuite extends PlanTest { val correctAnswer = PushDownOptimizer.execute(noOrderByPlan.analyze) comparePlans(optimized, correctAnswer) } + + test("SPARK-32360: Add MaxMinBy to support eliminate sorts") { + Seq(MaxBy(Symbol("a"), Symbol("b")), MinBy(Symbol("a"), Symbol("b"))).foreach { agg => + val projectPlan = testRelation.select(Symbol("a"), Symbol("b")) + val unnecessaryOrderByPlan = projectPlan.orderBy(Symbol("a").asc) + val groupByPlan = unnecessaryOrderByPlan.groupBy(Symbol("a"))(agg) + val optimized = Optimize.execute(groupByPlan.analyze) + val correctAnswer = projectPlan.groupBy(Symbol("a"))(agg).analyze + comparePlans(optimized, correctAnswer) + } + } } From cd93b707dfd9e033a0580d688a19fe044af379f9 Mon Sep 17 00:00:00 2001 From: 
ulysses Date: Mon, 20 Jul 2020 11:58:58 +0800 Subject: [PATCH 4/4] Group by both a and b in the MaxBy/MinBy EliminateSorts test --- .../spark/sql/catalyst/optimizer/EliminateSortsSuite.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala index cda7a2b0495bd..4e10761751107 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala @@ -325,9 +325,9 @@ class EliminateSortsSuite extends PlanTest { Seq(MaxBy(Symbol("a"), Symbol("b")), MinBy(Symbol("a"), Symbol("b"))).foreach { agg => val projectPlan = testRelation.select(Symbol("a"), Symbol("b")) val unnecessaryOrderByPlan = projectPlan.orderBy(Symbol("a").asc) - val groupByPlan = unnecessaryOrderByPlan.groupBy(Symbol("a"))(agg) + val groupByPlan = unnecessaryOrderByPlan.groupBy(Symbol("a"), Symbol("b"))(agg) val optimized = Optimize.execute(groupByPlan.analyze) - val correctAnswer = projectPlan.groupBy(Symbol("a"))(agg).analyze + val correctAnswer = projectPlan.groupBy(Symbol("a"), Symbol("b"))(agg).analyze comparePlans(optimized, correctAnswer) } }