From 698f5a13ed9293c01afe7a1d1638b6fcaa38f338 Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Wed, 17 Mar 2021 21:21:54 +0900
Subject: [PATCH] [SPARK-34749][SQL] Simplify ResolveCreateNamedStruct

### What changes were proposed in this pull request?

This is a follow-up of https://github.com/apache/spark/pull/31808 and
simplifies its fix to one line (excluding comments).

### Why are the changes needed?

code simplification

### Does this PR introduce _any_ user-facing change?

no

### How was this patch tested?

N/A

Closes #31843 from cloud-fan/simplify.

Authored-by: Wenchen Fan
Signed-off-by: Takeshi Yamamuro
---
 .../apache/spark/sql/catalyst/analysis/Analyzer.scala  |  2 --
 .../sql/catalyst/expressions/complexTypeCreator.scala  | 10 +++++++++-
 .../catalyst/expressions/complexTypeExtractors.scala   | 11 +----------
 .../sql/catalyst/parser/ExpressionParserSuite.scala    |  2 +-
 4 files changed, 11 insertions(+), 14 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index f98f33b02f0dc..f4cdeab063ce7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -3840,8 +3840,6 @@ object ResolveCreateNamedStruct extends Rule[LogicalPlan] {
       val children = e.children.grouped(2).flatMap {
         case Seq(NamePlaceholder, e: NamedExpression) if e.resolved =>
           Seq(Literal(e.name), e)
-        case Seq(NamePlaceholder, e: ExtractValue) if e.resolved && e.name.isDefined =>
-          Seq(Literal(e.name.get), e)
         case kv => kv
       }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
index cb59fbda2b3b9..1779d413e025d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
 import scala.collection.mutable.ArrayBuffer

 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.{Resolver, TypeCheckResult, TypeCoercion, UnresolvedExtractValue}
+import org.apache.spark.sql.catalyst.analysis.{Resolver, TypeCheckResult, TypeCoercion, UnresolvedAttribute, UnresolvedExtractValue}
 import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.{FUNC_ALIAS, FunctionBuilder}
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
@@ -336,6 +336,14 @@ object CreateStruct {
    */
   def apply(children: Seq[Expression]): CreateNamedStruct = {
     CreateNamedStruct(children.zipWithIndex.flatMap {
+      // For a multi-part column name like `struct(a.b.c)`, it may be resolved into:
+      //   1. Attribute if `a.b.c` is simply a qualified column name.
+      //   2. GetStructField if `a.b` refers to a struct-type column.
+      //   3. GetArrayStructFields if `a.b` refers to an array-of-struct-type column.
+      //   4. GetMapValue if `a.b` refers to a map-type column.
+      // We should always use the last part of the column name (`c` in the above example) as the
+      // alias name inside CreateNamedStruct.
+      case (u: UnresolvedAttribute, _) => Seq(Literal(u.nameParts.last), u)
       case (e: NamedExpression, _) if e.resolved => Seq(Literal(e.name), e)
       case (e: NamedExpression, _) => Seq(NamePlaceholder, e)
       case (e, index) => Seq(Literal(s"col${index + 1}"), e)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
index 9b8014035944c..ef247efbe1a04 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
@@ -94,10 +94,7 @@ object ExtractValue {
   }
 }

-trait ExtractValue extends Expression {
-  // The name that is used to extract the value.
-  def name: Option[String]
-}
+trait ExtractValue extends Expression

 /**
  * Returns the value of fields in the Struct `child`.
@@ -163,7 +160,6 @@ case class GetArrayStructFields(
   override def dataType: DataType = ArrayType(field.dataType, containsNull)
   override def toString: String = s"$child.${field.name}"
   override def sql: String = s"${child.sql}.${quoteIdentifier(field.name)}"
-  override def name: Option[String] = Some(field.name)

   protected override def nullSafeEval(input: Any): Any = {
     val array = input.asInstanceOf[ArrayData]
@@ -241,7 +237,6 @@ case class GetArrayItem(

   override def toString: String = s"$child[$ordinal]"
   override def sql: String = s"${child.sql}[${ordinal.sql}]"
-  override def name: Option[String] = None

   override def left: Expression = child
   override def right: Expression = ordinal
@@ -461,10 +456,6 @@ case class GetMapValue(

   override def toString: String = s"$child[$key]"
   override def sql: String = s"${child.sql}[${key.sql}]"
-  override def name: Option[String] = key match {
-    case NonNullLiteral(s, StringType) => Some(s.toString)
-    case _ => None
-  }

   override def left: Expression = child
   override def right: Expression = key
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala
index 9f6a76b9228c5..9711cdc559c5c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala
@@ -425,7 +425,7 @@ class ExpressionParserSuite extends AnalysisTest {
     assertEqual("(a + b).b", ('a + 'b).getField("b")) // This will fail analysis.
     assertEqual(
       "struct(a, b).b",
-      namedStruct(NamePlaceholder, 'a, NamePlaceholder, 'b).getField("b"))
+      namedStruct(Literal("a"), 'a, Literal("b"), 'b).getField("b"))
   }

   test("reference") {
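
Below is a minimal, reviewer-only sketch (not part of the patch) of what the new
UnresolvedAttribute case in CreateStruct.apply does: it builds struct(a.b.c) by hand
and checks that the last name part, `c`, becomes the struct field name before the
attribute is even resolved. It assumes the spark-catalyst module on the classpath;
the object name and the standalone main wrapper are invented for illustration, while
CreateStruct, UnresolvedAttribute and Literal are the real Catalyst classes touched here.

import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions.{CreateStruct, Literal}

// Hypothetical sketch for review, not shipped with this change.
object CreateStructNamingSketch {
  def main(args: Array[String]): Unit = {
    // struct(a.b.c) as the parser would hand it to CreateStruct.apply.
    val col = UnresolvedAttribute(Seq("a", "b", "c"))
    val named = CreateStruct(Seq(col))

    // With the new case, the field name is fixed up front to the last name part ("c"),
    // no matter whether a.b.c later resolves to an Attribute, GetStructField,
    // GetArrayStructFields or GetMapValue, so ResolveCreateNamedStruct no longer
    // needs the removed ExtractValue.name hook.
    assert(named.children.head == Literal("c"))
    println(named.children)  // prints something like List(c, 'a.b.c)
  }
}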