From bf4570b43dc10a1c7b66802f1e067bae877b3355 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 17 Mar 2021 21:21:54 +0900 Subject: [PATCH] [SPARK-34749][SQL] Simplify ResolveCreateNamedStruct ### What changes were proposed in this pull request? This is a follow-up of https://github.com/apache/spark/pull/31808 and simplifies its fix to one line (excluding comments). ### Why are the changes needed? code simplification ### Does this PR introduce _any_ user-facing change? no ### How was this patch tested? N/A Closes #31843 from cloud-fan/simplify. Authored-by: Wenchen Fan Signed-off-by: Takeshi Yamamuro --- .../spark/sql/catalyst/analysis/Analyzer.scala | 2 -- .../catalyst/expressions/complexTypeCreator.scala | 10 +++++++++- .../expressions/complexTypeExtractors.scala | 14 +------------- .../catalyst/parser/ExpressionParserSuite.scala | 2 +- 4 files changed, 11 insertions(+), 17 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index bda9a836709f8..29d6115b9dcb8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -3973,8 +3973,6 @@ object ResolveCreateNamedStruct extends Rule[LogicalPlan] { val children = e.children.grouped(2).flatMap { case Seq(NamePlaceholder, e: NamedExpression) if e.resolved => Seq(Literal(e.name), e) - case Seq(NamePlaceholder, e: ExtractValue) if e.resolved && e.name.isDefined => - Seq(Literal(e.name.get), e) case kv => kv } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index d29da3ad2a4e4..152ed04b013e4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.{Resolver, TypeCheckResult, TypeCoercion, UnresolvedExtractValue} +import org.apache.spark.sql.catalyst.analysis.{Resolver, TypeCheckResult, TypeCoercion, UnresolvedAttribute, UnresolvedExtractValue} import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.{FUNC_ALIAS, FunctionBuilder} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ @@ -339,6 +339,14 @@ object CreateStruct { */ def apply(children: Seq[Expression]): CreateNamedStruct = { CreateNamedStruct(children.zipWithIndex.flatMap { + // For multi-part column name like `struct(a.b.c)`, it may be resolved into: + // 1. Attribute if `a.b.c` is simply a qualified column name. + // 2. GetStructField if `a.b` refers to a struct-type column. + // 3. GetArrayStructFields if `a.b` refers to a array-of-struct-type column. + // 4. GetMapValue if `a.b` refers to a map-type column. + // We should always use the last part of the column name (`c` in the above example) as the + // alias name inside CreateNamedStruct. + case (u: UnresolvedAttribute, _) => Seq(Literal(u.nameParts.last), u) case (e: NamedExpression, _) if e.resolved => Seq(Literal(e.name), e) case (e: NamedExpression, _) => Seq(NamePlaceholder, e) case (e, index) => Seq(Literal(s"col${index + 1}"), e) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index fd24100802e79..139d9a584ccbe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -90,10 +90,7 @@ object ExtractValue { } } -trait ExtractValue extends Expression { - // The name that is used to extract the value. - def name: Option[String] -} +trait ExtractValue extends Expression /** * Returns the value of fields in the Struct `child`. @@ -159,7 +156,6 @@ case class GetArrayStructFields( override def dataType: DataType = ArrayType(field.dataType, containsNull) override def toString: String = s"$child.${field.name}" override def sql: String = s"${child.sql}.${quoteIdentifier(field.name)}" - override def name: Option[String] = Some(field.name) protected override def nullSafeEval(input: Any): Any = { val array = input.asInstanceOf[ArrayData] @@ -237,7 +233,6 @@ case class GetArrayItem( override def toString: String = s"$child[$ordinal]" override def sql: String = s"${child.sql}[${ordinal.sql}]" - override def name: Option[String] = None override def left: Expression = child override def right: Expression = ordinal @@ -453,13 +448,6 @@ case class GetMapValue( override def toString: String = s"$child[$key]" override def sql: String = s"${child.sql}[${key.sql}]" - override def name: Option[String] = key match { - case NonNullLiteral(s, StringType) => Some(s.toString) - // For GetMapValue(Attr("a"), "b") that is resolved from `a.b`, the string "b" may be casted to - // the map key type by type coercion rules. We can still get the name "b". - case Cast(NonNullLiteral(s, StringType), _, _) => Some(s.toString) - case _ => None - } override def left: Expression = child override def right: Expression = key diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala index 0b304a799cdc5..a105298b22428 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala @@ -425,7 +425,7 @@ class ExpressionParserSuite extends AnalysisTest { assertEqual("(a + b).b", ('a + 'b).getField("b")) // This will fail analysis. assertEqual( "struct(a, b).b", - namedStruct(NamePlaceholder, 'a, NamePlaceholder, 'b).getField("b")) + namedStruct(Literal("a"), 'a, Literal("b"), 'b).getField("b")) } test("reference") {