diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 33da482c4eea4..e39bf77cb3044 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -1004,7 +1004,7 @@ object EliminateSorts extends Rule[LogicalPlan] { private def isOrderIrrelevantAggs(aggs: Seq[NamedExpression]): Boolean = { def isOrderIrrelevantAggFunction(func: AggregateFunction): Boolean = func match { - case _: Min | _: Max | _: Count => true + case _: Min | _: Max | _: Count | _: MaxMinBy => true // Arithmetic operations for floating-point values are order-sensitive // (they are not associative). case _: Sum | _: Average | _: CentralMomentAgg => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala index e2b599a7c090c..4e10761751107 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.{MaxBy, MinBy} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ @@ -319,4 +320,15 @@ class EliminateSortsSuite extends PlanTest { val correctAnswer = PushDownOptimizer.execute(noOrderByPlan.analyze) comparePlans(optimized, correctAnswer) } + + test("SPARK-32360: Add MaxMinBy to support eliminate sorts") { + Seq(MaxBy(Symbol("a"), Symbol("b")), MinBy(Symbol("a"), Symbol("b"))).foreach { agg => + val projectPlan = testRelation.select(Symbol("a"), Symbol("b")) + val unnecessaryOrderByPlan = projectPlan.orderBy(Symbol("a").asc) + val groupByPlan = unnecessaryOrderByPlan.groupBy(Symbol("a"), Symbol("b"))(agg) + val optimized = Optimize.execute(groupByPlan.analyze) + val correctAnswer = projectPlan.groupBy(Symbol("a"), Symbol("b"))(agg).analyze + comparePlans(optimized, correctAnswer) + } + } }