Skip to content

Commit

Permalink
[SPARK-34639][SQL] Always remove unnecessary Alias in Analyzer.resolveExpression
Browse files Browse the repository at this point in the history

### What changes were proposed in this pull request?

In `Analyzer.resolveExpression`, we have a parameter that decides whether unnecessary `Alias` nodes should be removed. This is overcomplicated: we can always remove an unnecessary `Alias`.

This PR simplifies this part and removes the parameter.

### Why are the changes needed?

code cleanup

### Does this PR introduce _any_ user-facing change?

no

### How was this patch tested?

existing tests

Closes #31758 from cloud-fan/resolve.

Authored-by: Wenchen Fan <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
  • Loading branch information
cloud-fan committed Mar 15, 2021
1 parent a0f3b72 commit be888b2
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1632,11 +1632,11 @@ class Analyzer(override val catalogManager: CatalogManager)
}

val resolvedGroupingExprs = a.groupingExpressions
.map(resolveExpressionByPlanChildren(_, planForResolve, trimAlias = true))
.map(resolveExpressionByPlanChildren(_, planForResolve))
.map(trimTopLevelGetStructFieldAlias)

val resolvedAggExprs = a.aggregateExpressions
.map(resolveExpressionByPlanChildren(_, planForResolve, trimAlias = true))
.map(resolveExpressionByPlanChildren(_, planForResolve))
.map(_.asInstanceOf[NamedExpression])

a.copy(resolvedGroupingExprs, resolvedAggExprs, a.child)
Expand All @@ -1648,15 +1648,15 @@ class Analyzer(override val catalogManager: CatalogManager)
// of GetStructField here.
case g: GroupingSets =>
val resolvedSelectedExprs = g.selectedGroupByExprs
.map(_.map(resolveExpressionByPlanChildren(_, g, trimAlias = true))
.map(_.map(resolveExpressionByPlanChildren(_, g))
.map(trimTopLevelGetStructFieldAlias))

val resolvedGroupingExprs = g.groupByExprs
.map(resolveExpressionByPlanChildren(_, g, trimAlias = true))
.map(resolveExpressionByPlanChildren(_, g))
.map(trimTopLevelGetStructFieldAlias)

val resolvedAggExprs = g.aggregations
.map(resolveExpressionByPlanChildren(_, g, trimAlias = true))
.map(resolveExpressionByPlanChildren(_, g))
.map(_.asInstanceOf[NamedExpression])

g.copy(resolvedSelectedExprs, resolvedGroupingExprs, g.child, resolvedAggExprs)
Expand Down Expand Up @@ -1895,26 +1895,22 @@ class Analyzer(override val catalogManager: CatalogManager)
plan: LogicalPlan,
resolveColumnByName: Seq[String] => Option[Expression],
resolveColumnByOrdinal: Int => Attribute,
trimAlias: Boolean,
throws: Boolean): Expression = {
def innerResolve(e: Expression, isTopLevel: Boolean): Expression = {
if (e.resolved) return e
e match {
case f: LambdaFunction if !f.bound => f
case GetColumnByOrdinal(ordinal, _) => resolveColumnByOrdinal(ordinal)
case u @ UnresolvedAttribute(nameParts) =>
val resolved = withPosition(u) {
resolveColumnByName(nameParts)
.orElse(resolveLiteralFunction(nameParts, u, plan))
.getOrElse(u)
}
val result = resolved match {
// When trimAlias = true, we will trim unnecessary alias of `GetStructField` and
// we won't trim the alias of top-level `GetStructField`. Since we will call
// CleanupAliases later in Analyzer, trim non top-level unnecessary alias of
// `GetStructField` here is safe.
case Alias(s: GetStructField, _) if trimAlias && !isTopLevel => s
case others => others
val result = withPosition(u) {
resolveColumnByName(nameParts).map {
// We trim unnecessary aliases here. Note that we cannot trim the alias at the top level,
// as we should resolve `UnresolvedAttribute` to a named expression. The caller side
// can trim the top-level alias if it's safe to do so. Since we will call
// CleanupAliases later in Analyzer, trimming non-top-level unnecessary aliases is safe.
case Alias(child, _) if !isTopLevel => child
case other => other
}.orElse(resolveLiteralFunction(nameParts, u, plan)).getOrElse(u)
}
logDebug(s"Resolving $u to $result")
result
Expand Down Expand Up @@ -1962,7 +1958,6 @@ class Analyzer(override val catalogManager: CatalogManager)
assert(ordinal >= 0 && ordinal < plan.output.length)
plan.output(ordinal)
},
trimAlias = false,
throws = throws)
}

Expand All @@ -1972,28 +1967,22 @@ class Analyzer(override val catalogManager: CatalogManager)
*
* @param e The expression need to be resolved.
* @param q The LogicalPlan whose children are used to resolve expression's attribute.
* @param trimAlias When true, trim unnecessary alias of GetStructField`. Note that,
* we cannot trim the alias of top-level `GetStructField`, as we should
* resolve `UnresolvedAttribute` to a named expression. The caller side
* can trim the alias of top-level `GetStructField` if it's safe to do so.
* @return resolved Expression.
*/
def resolveExpressionByPlanChildren(
e: Expression,
q: LogicalPlan,
trimAlias: Boolean = false): Expression = {
q: LogicalPlan): Expression = {
resolveExpression(
e,
q,
resolveColumnByName = nameParts => {
q.resolveChildren(nameParts, resolver)
},
resolveColumnByOrdinal = ordinal => {
val candidates = q.children.flatMap(_.output)
assert(ordinal >= 0 && ordinal < candidates.length)
candidates.apply(ordinal)
assert(q.children.length == 1)
assert(ordinal >= 0 && ordinal < q.children.head.output.length)
q.children.head.output(ordinal)
},
trimAlias = trimAlias,
throws = true)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -455,6 +455,9 @@ case class GetMapValue(
// SQL text for this expression: the child's SQL followed by the key's SQL in square brackets.
override def sql: String = s"${child.sql}[${key.sql}]"
// Optional name for this expression, derived from the key: defined only when the key is a
// non-null string literal, either directly or wrapped in a Cast; None for any other key.
override def name: Option[String] = key match {
case NonNullLiteral(s, StringType) => Some(s.toString)
// For GetMapValue(Attr("a"), "b") that is resolved from `a.b`, the string "b" may be cast to
// the map key type by type coercion rules. We can still get the name "b".
case Cast(NonNullLiteral(s, StringType), _, _) => Some(s.toString)
case _ => None
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import scala.collection.JavaConverters._
import org.apache.spark.annotation.Stable
import org.apache.spark.api.python.PythonEvalType
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction}
import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedAlias, UnresolvedFunction}
import org.apache.spark.sql.catalyst.encoders.encoderFor
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
Expand Down Expand Up @@ -80,11 +80,7 @@ class RelationalGroupedDataset protected[sql](
}
}

// Wrap UnresolvedAttribute with UnresolvedAlias, as when we resolve UnresolvedAttribute, we
// will remove intermediate Alias for ExtractValue chain, and we need to alias it again to
// make it a NamedExpression.
private[this] def alias(expr: Expression): NamedExpression = expr match {
case u: UnresolvedAttribute => UnresolvedAlias(u)
case expr: NamedExpression => expr
case a: AggregateExpression if a.aggregateFunction.isInstanceOf[TypedAggregateExpression] =>
UnresolvedAlias(a, Some(Column.generateAlias))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ struct<ID:int,NST:string>
-- !query
SELECT ID, STRUCT(ST.C as STC, ST.D as STD).STD FROM tbl_x
-- !query schema
struct<ID:int,struct(ST.C AS C AS STC, ST.D AS D AS STD).STD:string>
struct<ID:int,struct(ST.C AS STC, ST.D AS STD).STD:string>
-- !query output
1 delta
2 eta
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ class UDFSuite extends QueryTest with SharedSparkSession {
Console.withOut(outputStream) {
spark.sql("SELECT f(a._1) FROM x").show
}
assert(outputStream.toString.contains("f(a._1 AS _1)"))
assert(outputStream.toString.contains("f(a._1)"))
}
}

Expand Down

0 comments on commit be888b2

Please sign in to comment.