diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index b323fec9d2470..455123db0ecb9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1491,65 +1491,6 @@ class Analyzer(override val catalogManager: CatalogManager) } } - /** - * Resolves the attribute and extract value expressions(s) by traversing the - * input expression in top down manner. The traversal is done in top-down manner as - * we need to skip over unbound lambda function expression. The lambda expressions are - * resolved in a different rule [[ResolveLambdaVariables]] - * - * Example : - * SELECT transform(array(1, 2, 3), (x, i) -> x + i)" - * - * In the case above, x and i are resolved as lambda variables in [[ResolveLambdaVariables]] - * - * Note : In this routine, the unresolved attributes are resolved from the input plan's - * children attributes. - * - * @param e The expression need to be resolved. - * @param q The LogicalPlan whose children are used to resolve expression's attribute. - * @param trimAlias When true, trim unnecessary alias of `GetStructField`. Note that, - * we cannot trim the alias of top-level `GetStructField`, as we should - * resolve `UnresolvedAttribute` to a named expression. The caller side - * can trim the alias of top-level `GetStructField` if it's safe to do so. - * @return resolved Expression. - */ - private def resolveExpressionTopDown( - e: Expression, - q: LogicalPlan, - trimAlias: Boolean = false): Expression = { - - def innerResolve(e: Expression, isTopLevel: Boolean): Expression = { - if (e.resolved) return e - e match { - case f: LambdaFunction if !f.bound => f - case u @ UnresolvedAttribute(nameParts) => - // Leave unchanged if resolution fails. Hopefully will be resolved next round. - val resolved = - withPosition(u) { - q.resolveChildren(nameParts, resolver) - .orElse(resolveLiteralFunction(nameParts, u, q)) - .getOrElse(u) - } - val result = resolved match { - // As the comment of method `resolveExpressionTopDown`'s param `trimAlias` said, - // when trimAlias = true, we will trim unnecessary alias of `GetStructField` and - // we won't trim the alias of top-level `GetStructField`. Since we will call - // CleanupAliases later in Analyzer, trim non top-level unnecessary alias of - // `GetStructField` here is safe. - case Alias(s: GetStructField, _) if trimAlias && !isTopLevel => s - case others => others - } - logDebug(s"Resolving $u to $result") - result - case UnresolvedExtractValue(child, fieldExpr) if child.resolved => - ExtractValue(child, fieldExpr, resolver) - case _ => e.mapChildren(innerResolve(_, isTopLevel = false)) - } - } - - innerResolve(e, isTopLevel = true) - } - def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { case p: LogicalPlan if !p.childrenResolved => p @@ -1580,9 +1521,9 @@ class Analyzer(override val catalogManager: CatalogManager) j.copy(right = dedupRight(left, right)) case f @ FlatMapCoGroupsInPandas(leftAttributes, rightAttributes, _, _, left, right) => val leftRes = leftAttributes - .map(x => resolveExpressionBottomUp(x, left).asInstanceOf[Attribute]) + .map(x => resolveExpressionByPlanOutput(x, left).asInstanceOf[Attribute]) val rightRes = rightAttributes - .map(x => resolveExpressionBottomUp(x, right).asInstanceOf[Attribute]) + .map(x => resolveExpressionByPlanOutput(x, right).asInstanceOf[Attribute]) f.copy(leftAttributes = leftRes, rightAttributes = rightRes) // intersect/except will be rewritten to join at the begininng of optimizer. Here we need to // deduplicate the right side plan, so that we won't produce an invalid self-join later. @@ -1611,7 +1552,7 @@ class Analyzer(override val catalogManager: CatalogManager) // we still have chance to resolve it based on its descendants case s @ Sort(ordering, global, child) if child.resolved && !s.resolved => val newOrdering = - ordering.map(order => resolveExpressionBottomUp(order, child).asInstanceOf[SortOrder]) + ordering.map(order => resolveExpressionByPlanOutput(order, child).asInstanceOf[SortOrder]) Sort(newOrdering, global, child) // A special case for Generate, because the output of Generate should not be resolved by @@ -1619,7 +1560,7 @@ class Analyzer(override val catalogManager: CatalogManager) case g @ Generate(generator, _, _, _, _, _) if generator.resolved => g case g @ Generate(generator, join, outer, qualifier, output, child) => - val newG = resolveExpressionBottomUp(generator, child, throws = true) + val newG = resolveExpressionByPlanOutput(generator, child, throws = true) if (newG.fastEquals(generator)) { g } else { @@ -1645,11 +1586,11 @@ class Analyzer(override val catalogManager: CatalogManager) } val resolvedGroupingExprs = a.groupingExpressions - .map(resolveExpressionTopDown(_, planForResolve, trimAlias = true)) + .map(resolveExpressionByPlanChildren(_, planForResolve, trimAlias = true)) .map(trimTopLevelGetStructFieldAlias) val resolvedAggExprs = a.aggregateExpressions - .map(resolveExpressionTopDown(_, planForResolve, trimAlias = true)) + .map(resolveExpressionByPlanChildren(_, planForResolve, trimAlias = true)) .map(_.asInstanceOf[NamedExpression]) a.copy(resolvedGroupingExprs, resolvedAggExprs, a.child) @@ -1661,15 +1602,15 @@ class Analyzer(override val catalogManager: CatalogManager) // of GetStructField here. case g: GroupingSets => val resolvedSelectedExprs = g.selectedGroupByExprs - .map(_.map(resolveExpressionTopDown(_, g, trimAlias = true)) + .map(_.map(resolveExpressionByPlanChildren(_, g, trimAlias = true)) .map(trimTopLevelGetStructFieldAlias)) val resolvedGroupingExprs = g.groupByExprs - .map(resolveExpressionTopDown(_, g, trimAlias = true)) + .map(resolveExpressionByPlanChildren(_, g, trimAlias = true)) .map(trimTopLevelGetStructFieldAlias) val resolvedAggExprs = g.aggregations - .map(resolveExpressionTopDown(_, g, trimAlias = true)) + .map(resolveExpressionByPlanChildren(_, g, trimAlias = true)) .map(_.asInstanceOf[NamedExpression]) g.copy(resolvedSelectedExprs, resolvedGroupingExprs, g.child, resolvedAggExprs) @@ -1677,10 +1618,10 @@ class Analyzer(override val catalogManager: CatalogManager) case o: OverwriteByExpression if o.table.resolved => // The delete condition of `OverwriteByExpression` will be passed to the table // implementation and should be resolved based on the table schema. - o.copy(deleteExpr = resolveExpressionBottomUp(o.deleteExpr, o.table)) + o.copy(deleteExpr = resolveExpressionByPlanOutput(o.deleteExpr, o.table)) case o: OptimizeTable if o.table.resolved => - o.copy(predicate = resolveExpressionBottomUp(o.predicate, o.table)) + o.copy(predicate = resolveExpressionByPlanOutput(o.predicate, o.table)) case m @ MergeIntoTable(targetTable, sourceTable, _, _, _) if !m.duplicateResolved => m.copy(sourceTable = dedupRight(targetTable, sourceTable)) @@ -1698,10 +1639,12 @@ class Analyzer(override val catalogManager: CatalogManager) case _ => val newMatchedActions = m.matchedActions.map { case DeleteAction(deleteCondition) => - val resolvedDeleteCondition = deleteCondition.map(resolveExpressionTopDown(_, m)) + val resolvedDeleteCondition = deleteCondition.map( + resolveExpressionByPlanChildren(_, m)) DeleteAction(resolvedDeleteCondition) case UpdateAction(updateCondition, assignments) => - val resolvedUpdateCondition = updateCondition.map(resolveExpressionTopDown(_, m)) + val resolvedUpdateCondition = updateCondition.map( + resolveExpressionByPlanChildren(_, m)) // The update value can access columns from both target and source tables. UpdateAction( resolvedUpdateCondition, @@ -1712,14 +1655,14 @@ class Analyzer(override val catalogManager: CatalogManager) case InsertAction(insertCondition, assignments) => // The insert action is used when not matched, so its condition and value can only // access columns from the source table. - val resolvedInsertCondition = - insertCondition.map(resolveExpressionTopDown(_, Project(Nil, m.sourceTable))) + val resolvedInsertCondition = insertCondition.map( + resolveExpressionByPlanChildren(_, Project(Nil, m.sourceTable))) InsertAction( resolvedInsertCondition, resolveAssignments(assignments, m, resolveValuesWithSourceOnly = true)) case o => o } - val resolvedMergeCondition = resolveExpressionTopDown(m.mergeCondition, m) + val resolvedMergeCondition = resolveExpressionByPlanChildren(m.mergeCondition, m) m.copy(mergeCondition = resolvedMergeCondition, matchedActions = newMatchedActions, notMatchedActions = newNotMatchedActions) @@ -1730,7 +1673,7 @@ class Analyzer(override val catalogManager: CatalogManager) case q: LogicalPlan => logTrace(s"Attempting to resolve ${q.simpleString(SQLConf.get.maxToStringFields)}") - q.mapExpressions(resolveExpressionTopDown(_, q)) + q.mapExpressions(resolveExpressionByPlanChildren(_, q)) } def resolveAssignments( @@ -1745,16 +1688,16 @@ class Analyzer(override val catalogManager: CatalogManager) assignments.map { assign => val resolvedKey = assign.key match { case c if !c.resolved => - resolveExpressionTopDown(c, Project(Nil, mergeInto.targetTable)) + resolveExpressionByPlanChildren(c, Project(Nil, mergeInto.targetTable)) case o => o } val resolvedValue = assign.value match { // The update values may contain target and/or source references. case c if !c.resolved => if (resolveValuesWithSourceOnly) { - resolveExpressionTopDown(c, Project(Nil, mergeInto.sourceTable)) + resolveExpressionByPlanChildren(c, Project(Nil, mergeInto.sourceTable)) } else { - resolveExpressionTopDown(c, mergeInto) + resolveExpressionByPlanChildren(c, mergeInto) } case o => o } @@ -1873,49 +1816,123 @@ class Analyzer(override val catalogManager: CatalogManager) } /** - * Resolves the attribute, column value and extract value expressions(s) by traversing the - * input expression in bottom-up manner. In order to resolve the nested complex type fields - * correctly, this function makes use of `throws` parameter to control when to raise an - * AnalysisException. + * Resolves `UnresolvedAttribute`, `GetColumnByOrdinal` and extract value expressions(s) by + * traversing the input expression in top-down manner. It must be top-down because we need to + * skip over unbound lambda function expression. The lambda expressions are resolved in a + * different place [[ResolveLambdaVariables]]. * * Example : - * SELECT a.b FROM t ORDER BY b[0].d + * SELECT transform(array(1, 2, 3), (x, i) -> x + i)" * - * In the above example, in b needs to be resolved before d can be resolved. Given we are - * doing a bottom up traversal, it will first attempt to resolve d and fail as b has not - * been resolved yet. If `throws` is false, this function will handle the exception by - * returning the original attribute. In this case `d` will be resolved in subsequent passes - * after `b` is resolved. + * In the case above, x and i are resolved as lambda variables in [[ResolveLambdaVariables]]. */ - protected[sql] def resolveExpressionBottomUp( + private def resolveExpression( expr: Expression, plan: LogicalPlan, - throws: Boolean = false): Expression = { - if (expr.resolved) return expr - // Resolve expression in one round. - // If throws == false or the desired attribute doesn't exist - // (like try to resolve `a.b` but `a` doesn't exist), fail and return the origin one. - // Else, throw exception. - try { - expr transformUp { - case GetColumnByOrdinal(ordinal, _) => plan.output(ordinal) + resolveColumnByName: Seq[String] => Option[Expression], + resolveColumnByOrdinal: Int => Attribute, + trimAlias: Boolean, + throws: Boolean): Expression = { + def innerResolve(e: Expression, isTopLevel: Boolean): Expression = { + if (e.resolved) return e + e match { + case f: LambdaFunction if !f.bound => f + case GetColumnByOrdinal(ordinal, _) => resolveColumnByOrdinal(ordinal) case u @ UnresolvedAttribute(nameParts) => - val result = - withPosition(u) { - plan.resolve(nameParts, resolver) - .orElse(resolveLiteralFunction(nameParts, u, plan)) - .getOrElse(u) - } + val resolved = withPosition(u) { + resolveColumnByName(nameParts) + .orElse(resolveLiteralFunction(nameParts, u, plan)) + .getOrElse(u) + } + val result = resolved match { + // When trimAlias = true, we will trim unnecessary alias of `GetStructField` and + // we won't trim the alias of top-level `GetStructField`. Since we will call + // CleanupAliases later in Analyzer, trim non top-level unnecessary alias of + // `GetStructField` here is safe. + case Alias(s: GetStructField, _) if trimAlias && !isTopLevel => s + case others => others + } logDebug(s"Resolving $u to $result") result - case UnresolvedExtractValue(child, fieldName) if child.resolved => - ExtractValue(child, fieldName, resolver) + case u @ UnresolvedExtractValue(child, fieldName) => + val newChild = innerResolve(child, isTopLevel = false) + if (newChild.resolved) { + ExtractValue(newChild, fieldName, resolver) + } else { + u.copy(child = newChild) + } + case _ => e.mapChildren(innerResolve(_, isTopLevel = false)) } + } + + try { + innerResolve(expr, isTopLevel = true) } catch { - case a: AnalysisException if !throws => expr + case _: AnalysisException if !throws => expr } } + /** + * Resolves `UnresolvedAttribute`, `GetColumnByOrdinal` and extract value expressions(s) by the + * input plan's output attributes. In order to resolve the nested fields correctly, this function + * makes use of `throws` parameter to control when to raise an AnalysisException. + * + * Example : + * SELECT * FROM t ORDER BY a.b + * + * In the above example, after `a` is resolved to a struct-type column, we may fail to resolve `b` + * if there is no such nested field named "b". We should not fail and wait for other rules to + * resolve it if possible. + */ + def resolveExpressionByPlanOutput( + expr: Expression, + plan: LogicalPlan, + throws: Boolean = false): Expression = { + resolveExpression( + expr, + plan, + resolveColumnByName = nameParts => { + plan.resolve(nameParts, resolver) + }, + resolveColumnByOrdinal = ordinal => { + assert(ordinal >= 0 && ordinal < plan.output.length) + plan.output(ordinal) + }, + trimAlias = false, + throws = throws) + } + + /** + * Resolves `UnresolvedAttribute`, `GetColumnByOrdinal` and extract value expressions(s) by the + * input plan's children output attributes. + * + * @param e The expression need to be resolved. + * @param q The LogicalPlan whose children are used to resolve expression's attribute. + * @param trimAlias When true, trim unnecessary alias of GetStructField`. Note that, + * we cannot trim the alias of top-level `GetStructField`, as we should + * resolve `UnresolvedAttribute` to a named expression. The caller side + * can trim the alias of top-level `GetStructField` if it's safe to do so. + * @return resolved Expression. + */ + def resolveExpressionByPlanChildren( + e: Expression, + q: LogicalPlan, + trimAlias: Boolean = false): Expression = { + resolveExpression( + e, + q, + resolveColumnByName = nameParts => { + q.resolveChildren(nameParts, resolver) + }, + resolveColumnByOrdinal = ordinal => { + val candidates = q.children.flatMap(_.output) + assert(ordinal >= 0 && ordinal < candidates.length) + candidates.apply(ordinal) + }, + trimAlias = trimAlias, + throws = true) + } + /** * In many dialects of SQL it is valid to use ordinal positions in order/sort by and group by * clauses. This rule is to convert ordinal positions to the corresponding expressions in the @@ -2053,7 +2070,7 @@ class Analyzer(override val catalogManager: CatalogManager) plan match { case p: Project => // Resolving expressions against current plan. - val maybeResolvedExprs = exprs.map(resolveExpressionBottomUp(_, p)) + val maybeResolvedExprs = exprs.map(resolveExpressionByPlanOutput(_, p)) // Recursively resolving expressions on the child of current plan. val (newExprs, newChild) = resolveExprsAndAddMissingAttrs(maybeResolvedExprs, p.child) // If some attributes used by expressions are resolvable only on the rewritten child @@ -2062,7 +2079,7 @@ class Analyzer(override val catalogManager: CatalogManager) (newExprs, Project(p.projectList ++ missingAttrs, newChild)) case a @ Aggregate(groupExprs, aggExprs, child) => - val maybeResolvedExprs = exprs.map(resolveExpressionBottomUp(_, a)) + val maybeResolvedExprs = exprs.map(resolveExpressionByPlanOutput(_, a)) val (newExprs, newChild) = resolveExprsAndAddMissingAttrs(maybeResolvedExprs, child) val missingAttrs = (AttributeSet(newExprs) -- a.outputSet).intersect(newChild.outputSet) if (missingAttrs.forall(attr => groupExprs.exists(_.semanticEquals(attr)))) { @@ -2074,20 +2091,20 @@ class Analyzer(override val catalogManager: CatalogManager) } case g: Generate => - val maybeResolvedExprs = exprs.map(resolveExpressionBottomUp(_, g)) + val maybeResolvedExprs = exprs.map(resolveExpressionByPlanOutput(_, g)) val (newExprs, newChild) = resolveExprsAndAddMissingAttrs(maybeResolvedExprs, g.child) (newExprs, g.copy(unrequiredChildIndex = Nil, child = newChild)) // For `Distinct` and `SubqueryAlias`, we can't recursively resolve and add attributes // via its children. case u: UnaryNode if !u.isInstanceOf[Distinct] && !u.isInstanceOf[SubqueryAlias] => - val maybeResolvedExprs = exprs.map(resolveExpressionBottomUp(_, u)) + val maybeResolvedExprs = exprs.map(resolveExpressionByPlanOutput(_, u)) val (newExprs, newChild) = resolveExprsAndAddMissingAttrs(maybeResolvedExprs, u.child) (newExprs, u.withNewChildren(Seq(newChild))) // For other operators, we can't recursively resolve and add attributes via its children. case other => - (exprs.map(resolveExpressionBottomUp(_, other)), other) + (exprs.map(resolveExpressionByPlanOutput(_, other)), other) } } } @@ -3364,7 +3381,7 @@ class Analyzer(override val catalogManager: CatalogManager) } validateTopLevelTupleFields(deserializer, inputs) - val resolved = resolveExpressionBottomUp( + val resolved = resolveExpressionByPlanOutput( deserializer, LocalRelation(inputs), throws = true) val result = resolved transformDown { case UnresolvedMapObjects(func, inputData, cls) if inputData.resolved =>