From 236bb1849d0947a04d3e2fe0017aa2f293592cc0 Mon Sep 17 00:00:00 2001 From: Kent Yao Date: Mon, 11 Jan 2021 17:54:55 +0800 Subject: [PATCH] [SPARK-34003][SQL][FOLLOWUP] Fix Rule conflicts between PaddingAndLengthCheckForCharVarchar and ResolveAggregateFunctions --- .../sql/catalyst/analysis/Analyzer.scala | 53 +++++++++++-------- .../spark/sql/CharVarcharTestSuite.scala | 9 ++++ 2 files changed, 41 insertions(+), 21 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index bf5dbb8200e87..37f6465f8817c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -2381,13 +2381,16 @@ class Analyzer(override val catalogManager: CatalogManager) val unresolvedSortOrders = sortOrder.filter { s => !s.resolved || !s.references.subsetOf(aggregate.outputSet) || containsAggregate(s) } - val aliasedOrdering = - unresolvedSortOrders.map(o => Alias(o.child, "aggOrder")()) - val aggregatedOrdering = aggregate.copy(aggregateExpressions = aliasedOrdering) + val aliasedOrdering = unresolvedSortOrders.map(o => Alias(o.child, "aggOrder")()) + + val aggregateWithExtraOrdering = aggregate.copy( + aggregateExpressions = aggregate.aggregateExpressions ++ aliasedOrdering) + val resolvedAggregate: Aggregate = - executeSameContext(aggregatedOrdering).asInstanceOf[Aggregate] - val resolvedAliasedOrdering: Seq[Alias] = - resolvedAggregate.aggregateExpressions.asInstanceOf[Seq[Alias]] + executeSameContext(aggregateWithExtraOrdering).asInstanceOf[Aggregate] + + val (reResolvedAggExprs, resolvedAliasedOrdering) = + resolvedAggregate.aggregateExpressions.splitAt(aggregate.aggregateExpressions.length) // If we pass the analysis check, then the ordering expressions should only reference to // aggregate expressions or grouping expressions, and it's safe to push them down to @@ -2401,24 +2404,25 @@ class Analyzer(override val catalogManager: CatalogManager) // expression instead. val needsPushDown = ArrayBuffer.empty[NamedExpression] val orderToAlias = unresolvedSortOrders.zip(aliasedOrdering) - val evaluatedOrderings = resolvedAliasedOrdering.zip(orderToAlias).map { - case (evaluated, (order, aliasOrder)) => - val index = originalAggExprs.indexWhere { - case Alias(child, _) => child semanticEquals evaluated.child - case other => other semanticEquals evaluated.child - } + val evaluatedOrderings = + resolvedAliasedOrdering.asInstanceOf[Seq[Alias]].zip(orderToAlias).map { + case (evaluated, (order, aliasOrder)) => + val index = reResolvedAggExprs.indexWhere { + case Alias(child, _) => child semanticEquals evaluated.child + case other => other semanticEquals evaluated.child + } - if (index == -1) { - if (CharVarcharUtils.getRawType(evaluated.metadata).nonEmpty) { - needsPushDown += aliasOrder - order.copy(child = aliasOrder) + if (index == -1) { + if (hasCharVarchar(evaluated)) { + needsPushDown += aliasOrder + order.copy(child = aliasOrder) + } else { + needsPushDown += evaluated + order.copy(child = evaluated.toAttribute) + } } else { - needsPushDown += evaluated - order.copy(child = evaluated.toAttribute) + order.copy(child = originalAggExprs(index).toAttribute) } - } else { - order.copy(child = originalAggExprs(index).toAttribute) - } } val sortOrdersMap = unresolvedSortOrders @@ -2443,6 +2447,13 @@ class Analyzer(override val catalogManager: CatalogManager) } } + def hasCharVarchar(expr: Alias): Boolean = { + expr.find { + case ne: NamedExpression => CharVarcharUtils.getRawType(ne.metadata).nonEmpty + case _ => false + }.nonEmpty + } + def containsAggregate(condition: Expression): Boolean = { condition.find(_.isInstanceOf[AggregateExpression]).isDefined } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala index fb35d6cf8dacb..04ae8f71526fb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala @@ -474,6 +474,15 @@ trait CharVarcharTestSuite extends QueryTest with SQLTestUtils { checkAnswer(sql("SELECT v, sum(i) FROM t GROUP BY v ORDER BY v"), Row("c", 1)) } } + + test("SPARK-34003: fix char/varchar fails w/ order by functions") { + withTable("t") { + sql(s"CREATE TABLE t(v VARCHAR(3), i INT) USING $format") + sql("INSERT INTO t VALUES ('c', 1)") + checkAnswer(sql("SELECT substr(v, 1, 2), sum(i) FROM t GROUP BY v ORDER BY substr(v, 1, 2)"), + Row("c", 1)) + } + } } // Some basic char/varchar tests which doesn't rely on table implementation.