From 2e5942dbe87285724547e0320ae4d03e82f6d30c Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 5 Mar 2021 17:24:10 +0800 Subject: [PATCH 1/2] always remove unnecessary Alias in Analyzer.resolveExpression --- .../sql/catalyst/analysis/Analyzer.scala | 47 +++++++------------ .../spark/sql/RelationalGroupedDataset.scala | 6 +-- .../sql-tests/results/struct.sql.out | 2 +- .../scala/org/apache/spark/sql/UDFSuite.scala | 2 +- 4 files changed, 21 insertions(+), 36 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 2c1fade0bafab..cca369f532df2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1628,11 +1628,11 @@ class Analyzer(override val catalogManager: CatalogManager) } val resolvedGroupingExprs = a.groupingExpressions - .map(resolveExpressionByPlanChildren(_, planForResolve, trimAlias = true)) + .map(resolveExpressionByPlanChildren(_, planForResolve)) .map(trimTopLevelGetStructFieldAlias) val resolvedAggExprs = a.aggregateExpressions - .map(resolveExpressionByPlanChildren(_, planForResolve, trimAlias = true)) + .map(resolveExpressionByPlanChildren(_, planForResolve)) .map(_.asInstanceOf[NamedExpression]) a.copy(resolvedGroupingExprs, resolvedAggExprs, a.child) @@ -1644,15 +1644,15 @@ class Analyzer(override val catalogManager: CatalogManager) // of GetStructField here. case g: GroupingSets => val resolvedSelectedExprs = g.selectedGroupByExprs - .map(_.map(resolveExpressionByPlanChildren(_, g, trimAlias = true)) + .map(_.map(resolveExpressionByPlanChildren(_, g)) .map(trimTopLevelGetStructFieldAlias)) val resolvedGroupingExprs = g.groupByExprs - .map(resolveExpressionByPlanChildren(_, g, trimAlias = true)) + .map(resolveExpressionByPlanChildren(_, g)) .map(trimTopLevelGetStructFieldAlias) val resolvedAggExprs = g.aggregations - .map(resolveExpressionByPlanChildren(_, g, trimAlias = true)) + .map(resolveExpressionByPlanChildren(_, g)) .map(_.asInstanceOf[NamedExpression]) g.copy(resolvedSelectedExprs, resolvedGroupingExprs, g.child, resolvedAggExprs) @@ -1891,7 +1891,6 @@ class Analyzer(override val catalogManager: CatalogManager) plan: LogicalPlan, resolveColumnByName: Seq[String] => Option[Expression], resolveColumnByOrdinal: Int => Attribute, - trimAlias: Boolean, throws: Boolean): Expression = { def innerResolve(e: Expression, isTopLevel: Boolean): Expression = { if (e.resolved) return e @@ -1899,18 +1898,15 @@ class Analyzer(override val catalogManager: CatalogManager) case f: LambdaFunction if !f.bound => f case GetColumnByOrdinal(ordinal, _) => resolveColumnByOrdinal(ordinal) case u @ UnresolvedAttribute(nameParts) => - val resolved = withPosition(u) { - resolveColumnByName(nameParts) - .orElse(resolveLiteralFunction(nameParts, u, plan)) - .getOrElse(u) - } - val result = resolved match { - // When trimAlias = true, we will trim unnecessary alias of `GetStructField` and - // we won't trim the alias of top-level `GetStructField`. Since we will call - // CleanupAliases later in Analyzer, trim non top-level unnecessary alias of - // `GetStructField` here is safe. - case Alias(s: GetStructField, _) if trimAlias && !isTopLevel => s - case others => others + val result = withPosition(u) { + resolveColumnByName(nameParts).map { + // We trim unnecessary alias here. Note that, we cannot trim the alias at top-level, + // as we should resolve `UnresolvedAttribute` to a named expression. The caller side + // can trim the top-level alias if it's safe to do so. Since we will call + // CleanupAliases later in Analyzer, trim non top-level unnecessary alias is safe. + case Alias(child, _) if !isTopLevel => child + case other => other + }.orElse(resolveLiteralFunction(nameParts, u, plan)).getOrElse(u) } logDebug(s"Resolving $u to $result") result @@ -1958,7 +1954,6 @@ class Analyzer(override val catalogManager: CatalogManager) assert(ordinal >= 0 && ordinal < plan.output.length) plan.output(ordinal) }, - trimAlias = false, throws = throws) } @@ -1968,16 +1963,11 @@ class Analyzer(override val catalogManager: CatalogManager) * * @param e The expression need to be resolved. * @param q The LogicalPlan whose children are used to resolve expression's attribute. - * @param trimAlias When true, trim unnecessary alias of GetStructField`. Note that, - * we cannot trim the alias of top-level `GetStructField`, as we should - * resolve `UnresolvedAttribute` to a named expression. The caller side - * can trim the alias of top-level `GetStructField` if it's safe to do so. * @return resolved Expression. */ def resolveExpressionByPlanChildren( e: Expression, - q: LogicalPlan, - trimAlias: Boolean = false): Expression = { + q: LogicalPlan): Expression = { resolveExpression( e, q, @@ -1985,11 +1975,10 @@ class Analyzer(override val catalogManager: CatalogManager) q.resolveChildren(nameParts, resolver) }, resolveColumnByOrdinal = ordinal => { - val candidates = q.children.flatMap(_.output) - assert(ordinal >= 0 && ordinal < candidates.length) - candidates.apply(ordinal) + assert(q.children.length == 1) + assert(ordinal >= 0 && ordinal < q.children.head.output.length) + q.children.head.output(ordinal) }, - trimAlias = trimAlias, throws = true) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index c40ce0f4777c6..7e735eecbac3e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -24,7 +24,7 @@ import scala.collection.JavaConverters._ import org.apache.spark.annotation.Stable import org.apache.spark.api.python.PythonEvalType import org.apache.spark.broadcast.Broadcast -import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction} +import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedAlias, UnresolvedFunction} import org.apache.spark.sql.catalyst.encoders.encoderFor import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ @@ -80,11 +80,7 @@ class RelationalGroupedDataset protected[sql]( } } - // Wrap UnresolvedAttribute with UnresolvedAlias, as when we resolve UnresolvedAttribute, we - // will remove intermediate Alias for ExtractValue chain, and we need to alias it again to - // make it a NamedExpression. private[this] def alias(expr: Expression): NamedExpression = expr match { - case u: UnresolvedAttribute => UnresolvedAlias(u) case expr: NamedExpression => expr case a: AggregateExpression if a.aggregateFunction.isInstanceOf[TypedAggregateExpression] => UnresolvedAlias(a, Some(Column.generateAlias)) diff --git a/sql/core/src/test/resources/sql-tests/results/struct.sql.out b/sql/core/src/test/resources/sql-tests/results/struct.sql.out index b3b0a6f343f57..3b2da6c85882b 100644 --- a/sql/core/src/test/resources/sql-tests/results/struct.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/struct.sql.out @@ -83,7 +83,7 @@ struct -- !query SELECT ID, STRUCT(ST.C as STC, ST.D as STD).STD FROM tbl_x -- !query schema -struct +struct -- !query output 1 delta 2 eta diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index b436dd3b6bd10..116e1ef1f9445 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -344,7 +344,7 @@ class UDFSuite extends QueryTest with SharedSparkSession { Console.withOut(outputStream) { spark.sql("SELECT f(a._1) FROM x").show } - assert(outputStream.toString.contains("f(a._1 AS _1)")) + assert(outputStream.toString.contains("f(a._1)")) } } From 505476540dc20c480a7d7e540391c01fcf1912f7 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 12 Mar 2021 23:08:59 +0800 Subject: [PATCH 2/2] fix a corner case --- .../spark/sql/catalyst/expressions/complexTypeExtractors.scala | 3 +++ 1 file changed, 3 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 4413a3deaa641..fd24100802e79 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -455,6 +455,9 @@ case class GetMapValue( override def sql: String = s"${child.sql}[${key.sql}]" override def name: Option[String] = key match { case NonNullLiteral(s, StringType) => Some(s.toString) + // For GetMapValue(Attr("a"), "b") that is resolved from `a.b`, the string "b" may be casted to + // the map key type by type coercion rules. We can still get the name "b". + case Cast(NonNullLiteral(s, StringType), _, _) => Some(s.toString) case _ => None }