Skip to content

Commit

Permalink
[SPARK-34749][SQL] Simplify ResolveCreateNamedStruct
Browse files Browse the repository at this point in the history
What changes were proposed in this pull request? This is a follow-up to apache#31808 that simplifies its fix to one line (excluding comments).

Why are the changes needed? Code simplification.

Does this PR introduce any user-facing change? No.

How was this patch tested? N/A.

Closes apache#31843 from cloud-fan/simplify.

Authored-by: Wenchen Fan <[email protected]>
Signed-off-by: Takeshi Yamamuro <[email protected]>
  • Loading branch information
cloud-fan committed Mar 17, 2021
1 parent 0922380 commit 698f5a1
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3840,8 +3840,6 @@ object ResolveCreateNamedStruct extends Rule[LogicalPlan] {
val children = e.children.grouped(2).flatMap {
case Seq(NamePlaceholder, e: NamedExpression) if e.resolved =>
Seq(Literal(e.name), e)
case Seq(NamePlaceholder, e: ExtractValue) if e.resolved && e.name.isDefined =>
Seq(Literal(e.name.get), e)
case kv =>
kv
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.{Resolver, TypeCheckResult, TypeCoercion, UnresolvedExtractValue}
import org.apache.spark.sql.catalyst.analysis.{Resolver, TypeCheckResult, TypeCoercion, UnresolvedAttribute, UnresolvedExtractValue}
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.{FUNC_ALIAS, FunctionBuilder}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
Expand Down Expand Up @@ -336,6 +336,14 @@ object CreateStruct {
*/
def apply(children: Seq[Expression]): CreateNamedStruct = {
CreateNamedStruct(children.zipWithIndex.flatMap {
// For multi-part column name like `struct(a.b.c)`, it may be resolved into:
// 1. Attribute if `a.b.c` is simply a qualified column name.
// 2. GetStructField if `a.b` refers to a struct-type column.
// 3. GetArrayStructFields if `a.b` refers to an array-of-struct-type column.
// 4. GetMapValue if `a.b` refers to a map-type column.
// We should always use the last part of the column name (`c` in the above example) as the
// alias name inside CreateNamedStruct.
case (u: UnresolvedAttribute, _) => Seq(Literal(u.nameParts.last), u)
case (e: NamedExpression, _) if e.resolved => Seq(Literal(e.name), e)
case (e: NamedExpression, _) => Seq(NamePlaceholder, e)
case (e, index) => Seq(Literal(s"col${index + 1}"), e)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,10 +94,7 @@ object ExtractValue {
}
}

trait ExtractValue extends Expression {
// The name that is used to extract the value.
def name: Option[String]
}
trait ExtractValue extends Expression

/**
* Returns the value of fields in the Struct `child`.
Expand Down Expand Up @@ -163,7 +160,6 @@ case class GetArrayStructFields(
override def dataType: DataType = ArrayType(field.dataType, containsNull)
override def toString: String = s"$child.${field.name}"
override def sql: String = s"${child.sql}.${quoteIdentifier(field.name)}"
override def name: Option[String] = Some(field.name)

protected override def nullSafeEval(input: Any): Any = {
val array = input.asInstanceOf[ArrayData]
Expand Down Expand Up @@ -241,7 +237,6 @@ case class GetArrayItem(

override def toString: String = s"$child[$ordinal]"
override def sql: String = s"${child.sql}[${ordinal.sql}]"
override def name: Option[String] = None

override def left: Expression = child
override def right: Expression = ordinal
Expand Down Expand Up @@ -461,10 +456,6 @@ case class GetMapValue(

override def toString: String = s"$child[$key]"
override def sql: String = s"${child.sql}[${key.sql}]"
override def name: Option[String] = key match {
case NonNullLiteral(s, StringType) => Some(s.toString)
case _ => None
}

override def left: Expression = child
override def right: Expression = key
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -425,7 +425,7 @@ class ExpressionParserSuite extends AnalysisTest {
assertEqual("(a + b).b", ('a + 'b).getField("b")) // This will fail analysis.
assertEqual(
"struct(a, b).b",
namedStruct(NamePlaceholder, 'a, NamePlaceholder, 'b).getField("b"))
namedStruct(Literal("a"), 'a, Literal("b"), 'b).getField("b"))
}

test("reference") {
Expand Down

0 comments on commit 698f5a1

Please sign in to comment.