diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 4f474f4987dcf..8e8f8e3e7eda5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -102,11 +102,11 @@ class Analyzer( this(catalog, conf, conf.optimizerMaxIterations) } - def executeAndCheck(plan: LogicalPlan): LogicalPlan = { + def executeAndCheck(plan: LogicalPlan): LogicalPlan = AnalysisHelper.markInAnalyzer { val analyzed = execute(plan) try { checkAnalysis(analyzed) - EliminateBarriers(analyzed) + analyzed } catch { case e: AnalysisException => val ae = new AnalysisException(e.message, e.line, e.startPosition, Option(analyzed)) @@ -203,7 +203,7 @@ class Analyzer( * Analyze cte definitions and substitute child plan with analyzed cte definitions. */ object CTESubstitution extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case With(child, relations) => substituteCTE(child, relations.foldLeft(Seq.empty[(String, LogicalPlan)]) { case (resolved, (name, relation)) => @@ -213,7 +213,7 @@ class Analyzer( } def substituteCTE(plan: LogicalPlan, cteRelations: Seq[(String, LogicalPlan)]): LogicalPlan = { - plan transformDown { + plan resolveOperatorsDown { case u : UnresolvedRelation => cteRelations.find(x => resolver(x._1, u.tableIdentifier.table)) .map(_._2).getOrElse(u) @@ -231,10 +231,10 @@ class Analyzer( * Substitute child plan with WindowSpecDefinitions. */ object WindowsSubstitution extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { // Lookup WindowSpecDefinitions. This rule works with unresolved children. case WithWindowDefinition(windowDefinitions, child) => - child.transform { + child.resolveOperators { case p => p.transformExpressions { case UnresolvedWindowExpression(c, WindowSpecReference(windowName)) => val errorMessage = @@ -271,7 +271,7 @@ class Analyzer( private def hasUnresolvedAlias(exprs: Seq[NamedExpression]) = exprs.exists(_.find(_.isInstanceOf[UnresolvedAlias]).isDefined) - def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case Aggregate(groups, aggs, child) if child.resolved && hasUnresolvedAlias(aggs) => Aggregate(groups, assignAliases(aggs), child) @@ -491,7 +491,7 @@ class Analyzer( } // This require transformUp to replace grouping()/grouping_id() in resolved Filter/Sort - def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case a if !a.childrenResolved => a // be sure all of the children are resolved. // Ensure group by expressions and aggregate expressions have been resolved. 
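The hunks above establish the pattern used throughout the rest of this patch: each analyzer rule's top-level transformUp/transformDown/transformAllExpressions call is replaced with resolveOperators/resolveOperatorsDown/resolveExpressions, which skip sub-plans whose analyzed flag is already set. As a rough sketch only (not part of the patch; the rule name is hypothetical and the usual catalyst imports for Rule, LogicalPlan and Project are assumed), a converted rule looks like this:

    // Hypothetical rule, shown only to illustrate the resolveOperators calling convention.
    object ExampleNoOpRule extends Rule[LogicalPlan] {
      def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
        // Same pattern-matching body a transformUp-based rule would have; only the traversal
        // method changes. Sub-trees already marked as analyzed are returned untouched.
        case p: Project if p.childrenResolved => p
      }
    }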
@@ -524,7 +524,7 @@ class Analyzer( } object ResolvePivot extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transform { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsDown { case p: Pivot if !p.childrenResolved || !p.aggregates.forall(_.resolved) || (p.groupByExprsOpt.isDefined && !p.groupByExprsOpt.get.forall(_.resolved)) || !p.pivotColumn.resolved || !p.pivotValues.forall(_.resolved) => p @@ -694,7 +694,7 @@ class Analyzer( case _ => plan } - def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case i @ InsertIntoTable(u: UnresolvedRelation, parts, child, _, _) if child.resolved => EliminateSubqueryAliases(lookupTableFromCatalog(u)) match { case v: View => @@ -755,12 +755,6 @@ class Analyzer( s"between $left and $right") right.collect { - // For `AnalysisBarrier`, recursively de-duplicate its child. - case oldVersion: AnalysisBarrier - if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty => - val newVersion = dedupRight(left, oldVersion.child) - (oldVersion, AnalysisBarrier(newVersion)) - // Handle base relations that might appear more than once. case oldVersion: MultiInstanceRelation if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty => @@ -805,6 +799,7 @@ class Analyzer( right case Some((oldRelation, newRelation)) => val attributeRewrites = AttributeMap(oldRelation.output.zip(newRelation.output)) + // TODO(rxin): Why do we need transformUp here? right transformUp { case r if r == oldRelation => newRelation } transformUp { @@ -865,7 +860,7 @@ class Analyzer( private def dedupOuterReferencesInSubquery( plan: LogicalPlan, attrMap: AttributeMap[Attribute]): LogicalPlan = { - plan transformDown { case currentFragment => + plan resolveOperatorsDown { case currentFragment => currentFragment transformExpressions { case OuterReference(a: Attribute) => OuterReference(dedupAttr(a, attrMap)) @@ -891,7 +886,7 @@ class Analyzer( case _ => e.mapChildren(resolve(_, q)) } - def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case p: LogicalPlan if !p.childrenResolved => p // If the projection list contains Stars, expand it. @@ -1086,7 +1081,7 @@ class Analyzer( * have no effect on the results. */ object ResolveOrdinalInOrderByAndGroupBy extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case p if !p.childrenResolved => p // Replace the index with the related attribute for ORDER BY, // which is a 1-base position of the projection list. @@ -1142,7 +1137,7 @@ class Analyzer( }} } - override def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case agg @ Aggregate(groups, aggs, child) if conf.groupByAliases && child.resolved && aggs.forall(_.resolved) && groups.exists(!_.resolved) => @@ -1166,9 +1161,8 @@ class Analyzer( * The HAVING clause could also used a grouping columns that is not presented in the SELECT. */ object ResolveMissingReferences extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { // Skip sort with aggregate. 
This will be handled in ResolveAggregateFunctions - case sa @ Sort(_, _, AnalysisBarrier(child: Aggregate)) => sa case sa @ Sort(_, _, child: Aggregate) => sa case s @ Sort(order, _, child) @@ -1208,12 +1202,6 @@ class Analyzer( (exprs, plan) } else { plan match { - // For `AnalysisBarrier`, recursively resolve expressions and add missing attributes via - // its child. - case barrier: AnalysisBarrier => - val (newExprs, newChild) = resolveExprsAndAddMissingAttrs(exprs, barrier.child) - (newExprs, AnalysisBarrier(newChild)) - case p: Project => // Resolving expressions against current plan. val maybeResolvedExprs = exprs.map(resolveExpression(_, p)) @@ -1270,7 +1258,7 @@ class Analyzer( object LookupFunctions extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = { val externalFunctionNameSet = new mutable.HashSet[FunctionIdentifier]() - plan.transformAllExpressions { + plan.resolveExpressions { case f: UnresolvedFunction if externalFunctionNameSet.contains(normalizeFuncName(f.name)) => f case f: UnresolvedFunction if catalog.isRegisteredFunction(f.name) => f @@ -1309,7 +1297,7 @@ class Analyzer( * Replaces [[UnresolvedFunction]]s with concrete [[Expression]]s. */ object ResolveFunctions extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case q: LogicalPlan => q transformExpressions { case u if !u.childrenResolved => u // Skip until children are resolved. @@ -1364,7 +1352,7 @@ class Analyzer( * resolved outer references are wrapped in an [[OuterReference]] */ private def resolveOuterReferences(plan: LogicalPlan, outer: LogicalPlan): LogicalPlan = { - plan transformDown { + plan resolveOperatorsDown { case q: LogicalPlan if q.childrenResolved && !q.resolved => q transformExpressions { case u @ UnresolvedAttribute(nameParts) => @@ -1446,7 +1434,7 @@ class Analyzer( /** * Resolve and rewrite all subqueries in an operator tree.. */ - def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { // In case of HAVING (a filter after an aggregate) we use both the aggregate and // its child for resolution. case f @ Filter(_, a: Aggregate) if f.childrenResolved => @@ -1462,7 +1450,7 @@ class Analyzer( */ object ResolveSubqueryColumnAliases extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case u @ UnresolvedSubqueryColumnAliases(columnNames, child) if child.resolved => // Resolves output attributes if a query has alias names in its subquery: // e.g., SELECT * FROM (SELECT 1 AS a, 1 AS b) t(col1, col2) @@ -1485,7 +1473,7 @@ class Analyzer( * Turns projections that contain aggregate expressions into aggregations. */ object GlobalAggregates extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case Project(projectList, child) if containsAggregates(projectList) => Aggregate(Nil, projectList, child) } @@ -1511,9 +1499,7 @@ class Analyzer( * underlying aggregate operator and then projected away after the original operator. 
*/ object ResolveAggregateFunctions extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { - case Filter(cond, AnalysisBarrier(agg: Aggregate)) => - apply(Filter(cond, agg)).mapChildren(AnalysisBarrier) + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case f @ Filter(cond, agg @ Aggregate(grouping, originalAggExprs, child)) if agg.resolved => // Try resolving the condition of the filter as though it is in the aggregate clause @@ -1571,8 +1557,6 @@ class Analyzer( case ae: AnalysisException => f } - case Sort(sortOrder, global, AnalysisBarrier(aggregate: Aggregate)) => - apply(Sort(sortOrder, global, aggregate)).mapChildren(AnalysisBarrier) case sort @ Sort(sortOrder, global, aggregate: Aggregate) if aggregate.resolved => // Try resolving the ordering as though it is in the aggregate clause. @@ -1692,7 +1676,7 @@ class Analyzer( } } - def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case Project(projectList, _) if projectList.exists(hasNestedGenerator) => val nestedGenerator = projectList.find(hasNestedGenerator).get throw new AnalysisException("Generators are not supported when it's nested in " + @@ -1752,7 +1736,7 @@ class Analyzer( * that wrap the [[Generator]]. */ object ResolveGenerate extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case g: Generate if !g.child.resolved || !g.generator.resolved => g case g: Generate if !g.resolved => g.copy(generatorOutput = makeGeneratorOutput(g.generator, g.generatorOutput.map(_.name))) @@ -1793,7 +1777,7 @@ class Analyzer( */ object FixNullability extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case p if !p.resolved => p // Skip unresolved nodes. case p: LogicalPlan if p.resolved => val childrenOutput = p.children.flatMap(c => c.output).groupBy(_.exprId).flatMap { @@ -2017,7 +2001,7 @@ class Analyzer( // We have to use transformDown at here to make sure the rule of // "Aggregate with Having clause" will be triggered. - def apply(plan: LogicalPlan): LogicalPlan = plan transformDown { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsDown { case Filter(condition, _) if hasWindowFunction(condition) => failAnalysis("It is not allowed to use window functions inside WHERE and HAVING clauses") @@ -2077,7 +2061,7 @@ class Analyzer( * put them into an inner Project and finally project them away at the outer Project. */ object PullOutNondeterministic extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case p if !p.resolved => p // Skip unresolved nodes. case p: Project => p case f: Filter => f @@ -2121,7 +2105,7 @@ class Analyzer( object ResolvedUuidExpressions extends Rule[LogicalPlan] { private lazy val random = new Random() - override def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case p if p.resolved => p case p => p transformExpressionsUp { case Uuid(None) => Uuid(Some(random.nextLong())) @@ -2136,7 +2120,7 @@ class Analyzer( * and we should return null if the input is null. 
*/ object HandleNullInputsForUDF extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case p if !p.resolved => p // Skip unresolved nodes. case p => p transformExpressionsUp { @@ -2171,7 +2155,7 @@ class Analyzer( * Check and add proper window frames for all window functions. */ object ResolveWindowFrame extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transform { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsDown { case logical: LogicalPlan => logical transformExpressions { case WindowExpression(wf: WindowFunction, WindowSpecDefinition(_, _, f: SpecifiedWindowFrame)) @@ -2197,7 +2181,7 @@ class Analyzer( * Check and add order to [[AggregateWindowFunction]]s. */ object ResolveWindowOrder extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transform { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsDown { case logical: LogicalPlan => logical transformExpressions { case WindowExpression(wf: WindowFunction, spec) if spec.orderSpec.isEmpty => failAnalysis(s"Window function $wf requires window to be ordered, please add ORDER BY " + @@ -2215,7 +2199,7 @@ class Analyzer( * Then apply a Project on a normal Join to eliminate natural or using join. */ object ResolveNaturalAndUsingJoin extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case j @ Join(left, right, UsingJoin(joinType, usingCols), condition) if left.resolved && right.resolved && j.duplicateResolved => commonNaturalJoinProcessing(left, right, joinType, usingCols, None) @@ -2280,7 +2264,7 @@ class Analyzer( * to the given input attributes. */ object ResolveDeserializer extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case p if !p.childrenResolved => p case p if p.resolved => p @@ -2366,7 +2350,7 @@ class Analyzer( * constructed is an inner class. */ object ResolveNewInstance extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case p if !p.childrenResolved => p case p if p.resolved => p @@ -2400,7 +2384,7 @@ class Analyzer( "type of the field in the target object") } - def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case p if !p.childrenResolved => p case p if p.resolved => p @@ -2422,8 +2406,13 @@ class Analyzer( * scoping information for attributes and can be removed once analysis is complete. */ object EliminateSubqueryAliases extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { - case SubqueryAlias(_, child) => child + // This is actually called at the beginning of the optimization phase, and as a result + // uses transformUp rather than resolveOperators. It is also invoked from within other + // analyzer rules, so the transform is wrapped in allowInvokingTransformsInAnalyzer. + def apply(plan: LogicalPlan): LogicalPlan = AnalysisHelper.allowInvokingTransformsInAnalyzer { + plan transformUp { + case SubqueryAlias(_, child) => child + } } } @@ -2431,7 +2420,7 @@ object EliminateSubqueryAliases extends Rule[LogicalPlan] { * Removes [[Union]] operators from the plan if it just has one child.
*/ object EliminateUnions extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transform { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsDown { case Union(children) if children.size == 1 => children.head } } @@ -2462,7 +2451,7 @@ object CleanupAliases extends Rule[LogicalPlan] { case other => trimAliases(other) } - override def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case Project(projectList, child) => val cleanedProjectList = projectList.map(trimNonTopLevelAliases(_).asInstanceOf[NamedExpression]) @@ -2491,19 +2480,12 @@ object CleanupAliases extends Rule[LogicalPlan] { } } -/** Remove the barrier nodes of analysis */ -object EliminateBarriers extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown { - case AnalysisBarrier(child) => child - } -} - /** * Ignore event time watermark in batch query, which is only supported in Structured Streaming. * TODO: add this rule into analyzer rule list. */ object EliminateEventTimeWatermark extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan transform { + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsDown { case EventTimeWatermark(_, _, child) if !child.isStreaming => child } } @@ -2548,7 +2530,7 @@ object TimeWindowing extends Rule[LogicalPlan] { * @return the logical plan that will generate the time windows using the Expand operator, with * the Filter operator for correctness and Project for usability. */ - def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case p: LogicalPlan if p.children.size == 1 => val child = p.children.head val windowExpressions = @@ -2636,7 +2618,7 @@ object TimeWindowing extends Rule[LogicalPlan] { * Resolve a [[CreateNamedStruct]] if it contains [[NamePlaceholder]]s. 
*/ object ResolveCreateNamedStruct extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan.transformAllExpressions { + override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveExpressions { case e: CreateNamedStruct if !e.resolved => val children = e.children.grouped(2).flatMap { case Seq(NamePlaceholder, e: NamedExpression) if e.resolved => @@ -2688,7 +2670,7 @@ object UpdateOuterReferences extends Rule[LogicalPlan] { private def updateOuterReferenceInSubquery( plan: LogicalPlan, refExprs: Seq[Expression]): LogicalPlan = { - plan transformAllExpressions { case e => + plan resolveExpressions { case e => val outerAlias = refExprs.find(stripAlias(_).semanticEquals(stripOuterReference(e))) outerAlias match { @@ -2699,7 +2681,7 @@ object UpdateOuterReferences extends Rule[LogicalPlan] { } def apply(plan: LogicalPlan): LogicalPlan = { - plan transform { + plan resolveOperatorsDown { case f @ Filter(_, a: Aggregate) if f.resolved => f transformExpressions { case s: SubqueryExpression if s.children.nonEmpty => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 49fe625b8fc6c..f9478a1c3cf4b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -79,6 +79,9 @@ trait CheckAnalysis extends PredicateHelper { // We transform up and order the rules so as to catch the first possible failure instead // of the result of cascading resolution failures. plan.foreachUp { + + case p if p.analyzed => // Skip already analyzed sub-plans + case u: UnresolvedRelation => u.failAnalysis(s"Table or view not found: ${u.tableIdentifier}") @@ -364,10 +367,11 @@ trait CheckAnalysis extends PredicateHelper { } extendedCheckRules.foreach(_(plan)) plan.foreachUp { - case AnalysisBarrier(child) if !child.resolved => checkAnalysis(child) case o if !o.resolved => failAnalysis(s"unresolved operator ${o.simpleString}") case _ => } + + plan.setAnalyzed() } /** @@ -531,9 +535,8 @@ trait CheckAnalysis extends PredicateHelper { var foundNonEqualCorrelatedPred: Boolean = false - // Simplify the predicates before validating any unsupported correlation patterns - // in the plan. - BooleanSimplification(sub).foreachUp { + // Simplify the predicates before validating any unsupported correlation patterns in the plan. + AnalysisHelper.allowInvokingTransformsInAnalyzer { BooleanSimplification(sub).foreachUp { // Whitelist operators allowed in a correlated subquery // There are 4 categories: // 1. Operators that are allowed anywhere in a correlated subquery, and, @@ -635,6 +638,6 @@ trait CheckAnalysis extends PredicateHelper { // are not allowed to have any correlated expressions. 
case p => failOnOuterReferenceInSubTree(p) - } + }} } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala index ab63131b07573..65a5888222f2e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala @@ -82,7 +82,7 @@ object DecimalPrecision extends TypeCoercionRule { PromotePrecision(Cast(e, dataType)) } - override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan transformUp { + override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan resolveOperators { // fix decimal precision for expressions case q => q.transformExpressionsUp( decimalAndDecimal.orElse(integralAndDecimalLiteral).orElse(nondecimalAndDecimal)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala index f068bce3e9b69..bfe5169c25900 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala @@ -85,7 +85,7 @@ object ResolveHints { } } - def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case h: UnresolvedHint if BROADCAST_HINT_NAMES.contains(h.name.toUpperCase(Locale.ROOT)) => if (h.parameters.isEmpty) { // If there is no table alias specified, turn the entire subtree into a BroadcastHint. @@ -107,7 +107,7 @@ object ResolveHints { * This must be executed after all the other hint rules are executed. */ object RemoveAllHints extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case h: UnresolvedHint => h.child } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala index 71ed75454cd4d..4edfe507a7580 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.types.{StructField, StructType} * An analyzer rule that replaces [[UnresolvedInlineTable]] with [[LocalRelation]]. 
*/ case class ResolveInlineTables(conf: SQLConf) extends Rule[LogicalPlan] with CastSupport { - override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case table: UnresolvedInlineTable if table.expressionsResolved => validateInputDimension(table) validateInputEvaluable(table) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala index a214e59302cd9..7358f9ee36921 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala @@ -103,7 +103,7 @@ object ResolveTableValuedFunctions extends Rule[LogicalPlan] { }) ) - override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case u: UnresolvedTableValuedFunction if u.functionArgs.forall(_.resolved) => val resolvedFunc = builtinFunctions.get(u.functionName.toLowerCase(Locale.ROOT)) match { case Some(tvf) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinals.scala index f9fd0df9e4010..860d20f897690 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinals.scala @@ -33,7 +33,7 @@ class SubstituteUnresolvedOrdinals(conf: SQLConf) extends Rule[LogicalPlan] { case _ => false } - def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case s: Sort if conf.orderByOrdinal && s.order.exists(o => isIntLiteral(o.child)) => val newOrders = s.order.map { case order @ SortOrder(ordinal @ Literal(index: Int, IntegerType), _, _, _) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 316aebdeaffa1..6bdb639011a17 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -318,7 +318,7 @@ object TypeCoercion { */ object WidenSetOperationTypes extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case s @ SetOperation(left, right) if s.childrenResolved && left.output.length == right.output.length && !s.resolved => val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(left :: right :: Nil) @@ -391,7 +391,7 @@ object TypeCoercion { } override protected def coerceTypes( - plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + plan: LogicalPlan): LogicalPlan = plan resolveExpressions { // Skip nodes who's children have not been resolved yet. 
case e if !e.childrenResolved => e @@ -453,7 +453,7 @@ object TypeCoercion { } override protected def coerceTypes( - plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + plan: LogicalPlan): LogicalPlan = plan resolveExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e @@ -512,7 +512,7 @@ object TypeCoercion { private val trueValues = Seq(1.toByte, 1.toShort, 1, 1L, Decimal.ONE) private val falseValues = Seq(0.toByte, 0.toShort, 0, 0L, Decimal.ZERO) - def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e @@ -555,7 +555,7 @@ object TypeCoercion { object FunctionArgumentConversion extends TypeCoercionRule { override protected def coerceTypes( - plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + plan: LogicalPlan): LogicalPlan = plan resolveExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e @@ -670,7 +670,7 @@ object TypeCoercion { */ object Division extends TypeCoercionRule { override protected def coerceTypes( - plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + plan: LogicalPlan): LogicalPlan = plan resolveExpressions { // Skip nodes who has not been resolved yet, // as this is an extra rule which should be applied at last. case e if !e.childrenResolved => e @@ -693,7 +693,7 @@ object TypeCoercion { */ object CaseWhenCoercion extends TypeCoercionRule { override protected def coerceTypes( - plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + plan: LogicalPlan): LogicalPlan = plan resolveExpressions { case c: CaseWhen if c.childrenResolved && !haveSameType(c.inputTypesForMerging) => val maybeCommonType = findWiderCommonType(c.inputTypesForMerging) maybeCommonType.map { commonType => @@ -711,7 +711,7 @@ object TypeCoercion { */ object IfCoercion extends TypeCoercionRule { override protected def coerceTypes( - plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + plan: LogicalPlan): LogicalPlan = plan resolveExpressions { case e if !e.childrenResolved => e // Find tightest common type for If, if the true value and false value have different types. case i @ If(pred, left, right) if !haveSameType(i.inputTypesForMerging) => @@ -731,7 +731,7 @@ object TypeCoercion { * Coerces NullTypes in the Stack expression to the column types of the corresponding positions. */ object StackCoercion extends TypeCoercionRule { - override def coerceTypes(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + override def coerceTypes(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { case s @ Stack(children) if s.childrenResolved && s.hasFoldableNumRows => Stack(children.zipWithIndex.map { // The first child is the number of rows for stack. 
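The TypeCoercion changes follow the same pattern at the expression level: transformAllExpressions becomes resolveExpressions, which rewrites the expressions of every operator not yet marked as analyzed. A minimal sketch under assumed imports (the rule and the coercion itself are made up for illustration, not taken from TypeCoercion):

    // Hypothetical expression-level rule using resolveExpressions.
    object ExampleStringToDouble extends Rule[LogicalPlan] {
      def apply(plan: LogicalPlan): LogicalPlan = plan.resolveExpressions {
        // Skip expressions whose children are not resolved yet, as the real coercion rules do.
        case e if !e.childrenResolved => e
        // Illustrative coercion: widen a string operand of an addition to double.
        case Add(left, right) if left.dataType == StringType =>
          Add(Cast(left, DoubleType), right)
      }
    }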
@@ -751,7 +751,8 @@ object TypeCoercion { */ case class ConcatCoercion(conf: SQLConf) extends TypeCoercionRule { - override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan transform { case p => + override protected def coerceTypes( + plan: LogicalPlan): LogicalPlan = plan resolveOperatorsDown { case p => p transformExpressionsUp { // Skip nodes if unresolved or empty children case c @ Concat(children) if !c.childrenResolved || children.isEmpty => c @@ -773,7 +774,8 @@ object TypeCoercion { */ case class EltCoercion(conf: SQLConf) extends TypeCoercionRule { - override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan transform { case p => + override protected def coerceTypes( + plan: LogicalPlan): LogicalPlan = plan resolveOperatorsDown { case p => p transformExpressionsUp { // Skip nodes if unresolved or not enough children case c @ Elt(children) if !c.childrenResolved || children.size < 2 => c @@ -801,7 +803,7 @@ object TypeCoercion { private val acceptedTypes = Seq(DateType, TimestampType, StringType) - def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e @@ -822,7 +824,7 @@ object TypeCoercion { private def rejectTzInString = conf.getConf(SQLConf.REJECT_TIMEZONE_IN_STRING) override protected def coerceTypes( - plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + plan: LogicalPlan): LogicalPlan = plan resolveExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e @@ -961,7 +963,7 @@ object TypeCoercion { */ object WindowFrameCoercion extends TypeCoercionRule { override protected def coerceTypes( - plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + plan: LogicalPlan): LogicalPlan = plan resolveExpressions { case s @ WindowSpecDefinition(_, Seq(order), SpecifiedWindowFrame(RangeFrame, lower, upper)) if order.resolved => s.copy(frameSpecification = SpecifiedWindowFrame( @@ -999,7 +1001,7 @@ trait TypeCoercionRule extends Rule[LogicalPlan] with Logging { protected def coerceTypes(plan: LogicalPlan): LogicalPlan - private def propagateTypes(plan: LogicalPlan): LogicalPlan = plan transformUp { + private def propagateTypes(plan: LogicalPlan): LogicalPlan = plan resolveOperators { // No propagation required for leaf nodes. 
case q: LogicalPlan if q.children.isEmpty => q diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/timeZoneAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/timeZoneAnalysis.scala index af1f9165b0044..a27aa845bf0ae 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/timeZoneAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/timeZoneAnalysis.scala @@ -38,7 +38,7 @@ case class ResolveTimeZone(conf: SQLConf) extends Rule[LogicalPlan] { } override def apply(plan: LogicalPlan): LogicalPlan = - plan.transformAllExpressions(transformTimeZoneExprs) + plan.resolveExpressions(transformTimeZoneExprs) def resolveTimeZones(e: Expression): Expression = e.transform(transformTimeZoneExprs) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala index 23eb78f914656..feeb6553d1066 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala @@ -48,7 +48,7 @@ import org.apache.spark.sql.internal.SQLConf * completely resolved during the batch of Resolution. */ case class AliasViewChild(conf: SQLConf) extends Rule[LogicalPlan] with CastSupport { - override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case v @ View(desc, output, child) if child.resolved && output != child.output => val resolver = conf.resolver val queryColumnNames = desc.viewQueryColumnNames diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/AnalysisHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/AnalysisHelper.scala new file mode 100644 index 0000000000000..039acc1ea4fa8 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/AnalysisHelper.scala @@ -0,0 +1,192 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical + +import org.apache.spark.sql.catalyst.analysis.CheckAnalysis +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.plans.QueryPlan +import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, TreeNode} +import org.apache.spark.util.Utils + + +/** + * [[AnalysisHelper]] defines some infrastructure for the query analyzer. In particular, in query + * analysis we don't want to repeatedly re-analyze sub-plans that have previously been analyzed. 
+ * + * This trait defines a flag `analyzed` that can be set to true once analysis is done on the tree. + * This also provides a set of resolve methods that do not recurse down to sub-plans that have the + * analyzed flag set to true. + * + * The analyzer rules should use the various resolve methods, in lieu of the various transform + * methods defined in [[TreeNode]] and [[QueryPlan]]. + * + * To prevent accidental use of the transform methods, this trait also overrides the transform + * methods to throw exceptions in test mode, if they are used in the analyzer. + */ +trait AnalysisHelper extends QueryPlan[LogicalPlan] { self: LogicalPlan => + + private var _analyzed: Boolean = false + + /** + * Recursively marks all nodes in this plan tree as analyzed. + * This should only be called by [[CheckAnalysis]]. + */ + private[catalyst] def setAnalyzed(): Unit = { + if (!_analyzed) { + _analyzed = true + children.foreach(_.setAnalyzed()) + } + } + + /** + * Returns true if this node and its children have already been gone through analysis and + * verification. Note that this is only an optimization used to avoid analyzing trees that + * have already been analyzed, and can be reset by transformations. + */ + def analyzed: Boolean = _analyzed + + /** + * Returns a copy of this node where `rule` has been recursively applied first to all of its + * children and then itself (post-order, bottom-up). When `rule` does not apply to a given node, + * it is left unchanged. This function is similar to `transformUp`, but skips sub-trees that + * have already been marked as analyzed. + * + * @param rule the function use to transform this nodes children + */ + def resolveOperators(rule: PartialFunction[LogicalPlan, LogicalPlan]): LogicalPlan = { + if (!analyzed) { + AnalysisHelper.allowInvokingTransformsInAnalyzer { + val afterRuleOnChildren = mapChildren(_.resolveOperators(rule)) + if (self fastEquals afterRuleOnChildren) { + CurrentOrigin.withOrigin(origin) { + rule.applyOrElse(self, identity[LogicalPlan]) + } + } else { + CurrentOrigin.withOrigin(origin) { + rule.applyOrElse(afterRuleOnChildren, identity[LogicalPlan]) + } + } + } + } else { + self + } + } + + /** Similar to [[resolveOperators]], but does it top-down. */ + def resolveOperatorsDown(rule: PartialFunction[LogicalPlan, LogicalPlan]): LogicalPlan = { + if (!analyzed) { + AnalysisHelper.allowInvokingTransformsInAnalyzer { + val afterRule = CurrentOrigin.withOrigin(origin) { + rule.applyOrElse(self, identity[LogicalPlan]) + } + + // Check if unchanged and then possibly return old copy to avoid gc churn. + if (self fastEquals afterRule) { + mapChildren(_.resolveOperatorsDown(rule)) + } else { + afterRule.mapChildren(_.resolveOperatorsDown(rule)) + } + } + } else { + self + } + } + + /** + * Recursively transforms the expressions of a tree, skipping nodes that have already + * been analyzed. + */ + def resolveExpressions(r: PartialFunction[Expression, Expression]): LogicalPlan = { + resolveOperators { + case p => p.transformExpressions(r) + } + } + + protected def assertNotAnalysisRule(): Unit = { + if (Utils.isTesting && + AnalysisHelper.inAnalyzer.get > 0 && + AnalysisHelper.resolveOperatorDepth.get == 0) { + throw new RuntimeException("This method should not be called in the analyzer") + } + } + + /** + * In analyzer, use [[resolveOperatorsDown()]] instead. If this is used in the analyzer, + * an exception will be thrown in test mode. It is however OK to call this function within + * the scope of a [[resolveOperatorsDown()]] call. 
+ * @see [[TreeNode.transformDown()]]. + */ + override def transformDown(rule: PartialFunction[LogicalPlan, LogicalPlan]): LogicalPlan = { + assertNotAnalysisRule() + super.transformDown(rule) + } + + /** + * Use [[resolveOperators()]] in the analyzer. + * @see [[TreeNode.transformUp()]] + */ + override def transformUp(rule: PartialFunction[LogicalPlan, LogicalPlan]): LogicalPlan = { + assertNotAnalysisRule() + super.transformUp(rule) + } + + /** + * Use [[resolveExpressions()]] in the analyzer. + * @see [[QueryPlan.transformAllExpressions()]] + */ + override def transformAllExpressions(rule: PartialFunction[Expression, Expression]): this.type = { + assertNotAnalysisRule() + super.transformAllExpressions(rule) + } + +} + + +object AnalysisHelper { + + /** + * A thread local to track whether we are in a resolveOperator call (for the purpose of analysis). + * This is an int because resolve* calls might be be nested (e.g. a rule might trigger another + * query compilation within the rule itself), so we are tracking the depth here. + */ + private val resolveOperatorDepth: ThreadLocal[Int] = new ThreadLocal[Int] { + override def initialValue(): Int = 0 + } + + /** + * A thread local to track whether we are in the analysis phase of query compilation. This is an + * int rather than a boolean in case our analyzer recursively calls itself. + */ + private val inAnalyzer: ThreadLocal[Int] = new ThreadLocal[Int] { + override def initialValue(): Int = 0 + } + + def allowInvokingTransformsInAnalyzer[T](f: => T): T = { + resolveOperatorDepth.set(resolveOperatorDepth.get + 1) + try f finally { + resolveOperatorDepth.set(resolveOperatorDepth.get - 1) + } + } + + def markInAnalyzer[T](f: => T): T = { + inAnalyzer.set(inAnalyzer.get + 1) + try f finally { + inAnalyzer.set(inAnalyzer.get - 1) + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index c486ad700f362..0e4456ac0e6a9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -23,12 +23,12 @@ import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.LogicalPlanStats -import org.apache.spark.sql.catalyst.trees.CurrentOrigin import org.apache.spark.sql.types.StructType abstract class LogicalPlan extends QueryPlan[LogicalPlan] + with AnalysisHelper with LogicalPlanStats with QueryPlanConstraints with Logging { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 31f703d018aed..9fb50a5e565e0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -555,20 +555,6 @@ class AnalysisSuite extends AnalysisTest with Matchers { } } - test("SPARK-20392: analysis barrier") { - // [[AnalysisBarrier]] will be removed after analysis - checkAnalysis( - Project(Seq(UnresolvedAttribute("tbl.a")), - AnalysisBarrier(SubqueryAlias("tbl", testRelation))), - Project(testRelation.output, SubqueryAlias("tbl", testRelation))) - - // Verify we won't go 
through a plan wrapped in a barrier. - // Since we wrap an unresolved plan and analyzer won't go through it. It remains unresolved. - val barrier = AnalysisBarrier(Project(Seq(UnresolvedAttribute("tbl.b")), - SubqueryAlias("tbl", testRelation))) - assertAnalysisError(barrier, Seq("cannot resolve '`tbl.b`'")) - } - test("SPARK-24208: analysis fails on self-join with FlatMapGroupsInPandas") { val pythonUdf = PythonUDF("pyUDF", null, StructType(Seq(StructField("a", LongType))), diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala index bf569cb869428..aaab3ff1bf128 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala @@ -18,13 +18,12 @@ package org.apache.spark.sql.catalyst.plans import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Coalesce, Literal, NamedExpression} +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Literal, NamedExpression} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types.IntegerType /** - * This suite is used to test [[LogicalPlan]]'s `transformUp/transformDown` plus analysis barrier - * and make sure it can correctly skip sub-trees that have already been analyzed. + * This suite is used to test [[LogicalPlan]]'s `transformUp/transformDown`. */ class LogicalPlanSuite extends SparkFunSuite { private var invocationCount = 0 @@ -60,31 +59,6 @@ class LogicalPlanSuite extends SparkFunSuite { assert(invocationCount === 2) } - test("transformUp skips all ready resolved plans wrapped in analysis barrier") { - invocationCount = 0 - val plan = AnalysisBarrier(Project(Nil, Project(Nil, testRelation))) - plan transformUp function - - assert(invocationCount === 0) - - invocationCount = 0 - plan transformDown function - assert(invocationCount === 0) - } - - test("transformUp skips partially resolved plans wrapped in analysis barrier") { - invocationCount = 0 - val plan1 = AnalysisBarrier(Project(Nil, testRelation)) - val plan2 = Project(Nil, plan1) - plan2 transformUp function - - assert(invocationCount === 1) - - invocationCount = 0 - plan2 transformDown function - assert(invocationCount === 1) - } - test("isStreaming") { val relation = LocalRelation(AttributeReference("a", IntegerType, nullable = true)()) val incrementalRelation = LocalRelation( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/logical/AnalysisHelperSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/logical/AnalysisHelperSuite.scala new file mode 100644 index 0000000000000..9100e10ca0c09 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/logical/AnalysisHelperSuite.scala @@ -0,0 +1,159 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.expressions.{Alias, Expression, Literal, NamedExpression} + + +class AnalysisHelperSuite extends SparkFunSuite { + + private var invocationCount = 0 + private val function: PartialFunction[LogicalPlan, LogicalPlan] = { + case p: Project => + invocationCount += 1 + p + } + + private val exprFunction: PartialFunction[Expression, Expression] = { + case e: Literal => + invocationCount += 1 + Literal.TrueLiteral + } + + private def projectExprs: Seq[NamedExpression] = Alias(Literal.TrueLiteral, "A")() :: Nil + + test("setAnalyzed is recursive") { + val plan = Project(Nil, LocalRelation()) + plan.setAnalyzed() + assert(plan.find(!_.analyzed).isEmpty) + } + + test("resolveOperator runs on operators recursively") { + invocationCount = 0 + val plan = Project(Nil, Project(Nil, LocalRelation())) + plan.resolveOperators(function) + assert(invocationCount === 2) + } + + test("resolveOperatorsDown runs on operators recursively") { + invocationCount = 0 + val plan = Project(Nil, Project(Nil, LocalRelation())) + plan.resolveOperatorsDown(function) + assert(invocationCount === 2) + } + + test("resolveExpressions runs on operators recursively") { + invocationCount = 0 + val plan = Project(projectExprs, Project(projectExprs, LocalRelation())) + plan.resolveExpressions(exprFunction) + assert(invocationCount === 2) + } + + test("resolveOperator skips already resolved plans") { + invocationCount = 0 + val plan = Project(Nil, Project(Nil, LocalRelation())) + plan.setAnalyzed() + plan.resolveOperators(function) + assert(invocationCount === 0) + } + + test("resolveOperatorsDown skips already resolved plans") { + invocationCount = 0 + val plan = Project(Nil, Project(Nil, LocalRelation())) + plan.setAnalyzed() + plan.resolveOperatorsDown(function) + assert(invocationCount === 0) + } + + test("resolveExpressions skips already resolved plans") { + invocationCount = 0 + val plan = Project(projectExprs, Project(projectExprs, LocalRelation())) + plan.setAnalyzed() + plan.resolveExpressions(exprFunction) + assert(invocationCount === 0) + } + + test("resolveOperator skips partially resolved plans") { + invocationCount = 0 + val plan1 = Project(Nil, LocalRelation()) + val plan2 = Project(Nil, plan1) + plan1.setAnalyzed() + plan2.resolveOperators(function) + assert(invocationCount === 1) + } + + test("resolveOperatorsDown skips partially resolved plans") { + invocationCount = 0 + val plan1 = Project(Nil, LocalRelation()) + val plan2 = Project(Nil, plan1) + plan1.setAnalyzed() + plan2.resolveOperatorsDown(function) + assert(invocationCount === 1) + } + + test("resolveExpressions skips partially resolved plans") { + invocationCount = 0 + val plan1 = Project(projectExprs, LocalRelation()) + val plan2 = Project(projectExprs, plan1) + plan1.setAnalyzed() + plan2.resolveExpressions(exprFunction) + assert(invocationCount === 1) + } + + test("do not allow transform in analyzer") { + val plan = Project(Nil, LocalRelation()) + // These should be OK since we are not in the analyzer
+ plan.transform { case p: Project => p } + plan.transformUp { case p: Project => p } + plan.transformDown { case p: Project => p } + plan.transformAllExpressions { case lit: Literal => lit } + + // The following should fail in the analyzer scope + AnalysisHelper.markInAnalyzer { + intercept[RuntimeException] { plan.transform { case p: Project => p } } + intercept[RuntimeException] { plan.transformUp { case p: Project => p } } + intercept[RuntimeException] { plan.transformDown { case p: Project => p } } + intercept[RuntimeException] { plan.transformAllExpressions { case lit: Literal => lit } } + } + } + + test("allow transform in resolveOperators in the analyzer") { + val plan = Project(Nil, LocalRelation()) + AnalysisHelper.markInAnalyzer { + plan.resolveOperators { case p: Project => p.transform { case p: Project => p } } + plan.resolveOperatorsDown { case p: Project => p.transform { case p: Project => p } } + plan.resolveExpressions { case lit: Literal => + Project(Nil, LocalRelation()).transform { case p: Project => p } + lit + } + } + } + + test("allow transform with allowInvokingTransformsInAnalyzer in the analyzer") { + val plan = Project(Nil, LocalRelation()) + AnalysisHelper.markInAnalyzer { + AnalysisHelper.allowInvokingTransformsInAnalyzer { + plan.transform { case p: Project => p } + plan.transformUp { case p: Project => p } + plan.transformDown { case p: Project => p } + plan.transformAllExpressions { case lit: Literal => lit } + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 39c0e102b69b2..3c9e743106260 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -254,7 +254,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { val writer = ws.createWriter(jobId, df.logicalPlan.schema, mode, options) if (writer.isPresent) { runCommand(df.sparkSession, "save") { - WriteToDataSourceV2(writer.get(), df.planWithBarrier) + WriteToDataSourceV2(writer.get(), df.logicalPlan) } } @@ -275,7 +275,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { sparkSession = df.sparkSession, className = source, partitionColumns = partitioningColumns.getOrElse(Nil), - options = extraOptions.toMap).planForWriting(mode, df.planWithBarrier) + options = extraOptions.toMap).planForWriting(mode, df.logicalPlan) } } @@ -323,7 +323,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { InsertIntoTable( table = UnresolvedRelation(tableIdent), partition = Map.empty[String, Option[String]], - query = df.planWithBarrier, + query = df.logicalPlan, overwrite = mode == SaveMode.Overwrite, ifPartitionNotExists = false) } @@ -351,7 +351,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { private def assertNotPartitioned(operation: String): Unit = { if (partitioningColumns.isDefined) { - throw new AnalysisException( s"'$operation' does not support partitioning") + throw new AnalysisException(s"'$operation' does not support partitioning") } } @@ -459,9 +459,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { partitionColumnNames = partitioningColumns.getOrElse(Nil), bucketSpec = getBucketSpec) - runCommand(df.sparkSession, "saveAsTable") { - CreateTable(tableDesc, mode, Some(df.planWithBarrier)) - } + runCommand(df.sparkSession, "saveAsTable")(CreateTable(tableDesc, mode, Some(df.logicalPlan))) } /** diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index c97246f30220d..b63235ec827c9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -196,7 +196,8 @@ class Dataset[T] private[sql]( } // Wraps analyzed logical plans with an analysis barrier so we won't traverse/resolve it again. - @transient private[sql] val planWithBarrier = AnalysisBarrier(logicalPlan) + // TODO(rxin): remove this later. + @transient private[sql] val planWithBarrier = logicalPlan /** * Currently [[ExpressionEncoder]] is the only implementation of [[Encoder]], here we turn the @@ -1857,7 +1858,7 @@ class Dataset[T] private[sql]( def union(other: Dataset[T]): Dataset[T] = withSetOperator { // This breaks caching, but it's usually ok because it addresses a very specific use case: // using union to union many files or partitions. - CombineUnions(Union(logicalPlan, other.logicalPlan)).mapChildren(AnalysisBarrier) + CombineUnions(Union(logicalPlan, other.logicalPlan)) } /** @@ -1916,7 +1917,7 @@ class Dataset[T] private[sql]( // This breaks caching, but it's usually ok because it addresses a very specific use case: // using union to union many files or partitions. - CombineUnions(Union(logicalPlan, rightChild)).mapChildren(AnalysisBarrier) + CombineUnions(Union(logicalPlan, rightChild)) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index 36f6038aa9485..6bab21dca0cbd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -49,7 +49,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( private implicit val kExprEnc = encoderFor(kEncoder) private implicit val vExprEnc = encoderFor(vEncoder) - private def logicalPlan = AnalysisBarrier(queryExecution.analyzed) + private def logicalPlan = queryExecution.analyzed private def sparkSession = queryExecution.sparkSession /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala index 39d9a95ca4710..ed130dc57ee5b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala @@ -97,7 +97,7 @@ class CacheManager extends Logging { val inMemoryRelation = InMemoryRelation( sparkSession.sessionState.conf.useCompression, sparkSession.sessionState.conf.columnBatchSize, storageLevel, - sparkSession.sessionState.executePlan(AnalysisBarrier(planToCache)).executedPlan, + sparkSession.sessionState.executePlan(planToCache).executedPlan, tableName, planToCache) cachedData.add(CachedData(planToCache, inMemoryRelation)) @@ -173,7 +173,7 @@ class CacheManager extends Logging { // Remove the cache entry before we create a new one, so that we can have a different // physical plan. 
it.remove() - val plan = spark.sessionState.executePlan(AnalysisBarrier(cd.plan)).executedPlan + val plan = spark.sessionState.executePlan(cd.plan).executedPlan val newCache = InMemoryRelation( cacheBuilder = cd.cachedRepresentation.cacheBuilder.withCachedPlan(plan), logicalPlan = cd.plan) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index c7f7e4d755cfd..e1faecedd20ed 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -29,7 +29,7 @@ import org.apache.hadoop.mapred.{FileInputFormat, JobConf} import org.apache.spark.sql.{AnalysisException, Row, SparkSession} import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.analysis.{EliminateBarriers, NoSuchTableException, Resolver} +import org.apache.spark.sql.catalyst.analysis.{NoSuchTableException, Resolver} import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} @@ -891,7 +891,7 @@ object DDLUtils { * Throws exception if outputPath tries to overwrite inputpath. */ def verifyNotReadPath(query: LogicalPlan, outputPath: Path) : Unit = { - val inputPaths = EliminateBarriers(query).collect { + val inputPaths = query.collect { case LogicalRelation(r: HadoopFsRelation, _, _, _) => r.location.rootPaths }.flatten diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 7b129435c45db..e1b049b6ceaba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -131,7 +131,7 @@ case class DataSourceAnalysis(conf: SQLConf) extends Rule[LogicalPlan] with Cast projectList } - override def apply(plan: LogicalPlan): LogicalPlan = plan transform { + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsDown { case CreateTable(tableDesc, mode, None) if DDLUtils.isDatasourceTable(tableDesc) => DDLUtils.checkDataColNames(tableDesc) CreateDataSourceTableCommand(tableDesc, ignoreIfExists = mode == SaveMode.Ignore) @@ -252,7 +252,7 @@ class FindDataSourceTable(sparkSession: SparkSession) extends Rule[LogicalPlan] table.partitionSchema.asNullable.toAttributes) } - override def apply(plan: LogicalPlan): LogicalPlan = plan transform { + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsDown { case i @ InsertIntoTable(UnresolvedCatalogRelation(tableMeta), _, _, _, _) if DDLUtils.isDatasourceTable(tableMeta) => i.copy(table = readDataSourceTable(tableMeta)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index dfcf6c14fbef1..3170180b32b83 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -39,7 +39,7 @@ class ResolveSQLOnFile(sparkSession: SparkSession) extends Rule[LogicalPlan] { sparkSession.sessionState.conf.runSQLonFile && u.tableIdentifier.database.isDefined } - def apply(plan: 
LogicalPlan): LogicalPlan = plan transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case u: UnresolvedRelation if maybeSQLFile(u) => try { val dataSource = DataSource( @@ -73,7 +73,7 @@ case class PreprocessTableCreation(sparkSession: SparkSession) extends Rule[Logi // catalog is a def and not a val/lazy val as the latter would introduce a circular reference private def catalog = sparkSession.sessionState.catalog - def apply(plan: LogicalPlan): LogicalPlan = plan transform { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsDown { // When we CREATE TABLE without specifying the table schema, we should fail the query if // bucketing information is specified, as we can't infer bucketing from data files currently. // Since the runtime inferred partition columns could be different from what user specified, @@ -365,7 +365,7 @@ case class PreprocessTableInsertion(conf: SQLConf) extends Rule[LogicalPlan] { } } - def apply(plan: LogicalPlan): LogicalPlan = plan transform { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsDown { case i @ InsertIntoTable(table, _, query, _, _) if table.resolved && query.resolved => table match { case relation: HiveTableRelation => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/GroupedDatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/GroupedDatasetSuite.scala deleted file mode 100644 index 147c0b61f5017..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/GroupedDatasetSuite.scala +++ /dev/null @@ -1,96 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql - -import org.apache.spark.api.python.PythonEvalType -import org.apache.spark.sql.catalyst.expressions.PythonUDF -import org.apache.spark.sql.catalyst.plans.logical.AnalysisBarrier -import org.apache.spark.sql.functions.udf -import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types.{LongType, StructField, StructType} - -class GroupedDatasetSuite extends QueryTest with SharedSQLContext { - import testImplicits._ - - private val scalaUDF = udf((x: Long) => { x + 1 }) - private lazy val datasetWithUDF = spark.range(1).toDF("s").select($"s", scalaUDF($"s")) - - private def assertContainsAnalysisBarrier(ds: Dataset[_], atLevel: Int = 1): Unit = { - assert(atLevel >= 0) - var children = Seq(ds.queryExecution.logical) - (1 to atLevel).foreach { _ => - children = children.flatMap(_.children) - } - val barriers = children.collect { - case ab: AnalysisBarrier => ab - } - assert(barriers.nonEmpty, s"Plan does not contain AnalysisBarrier at level $atLevel:\n" + - ds.queryExecution.logical) - } - - test("SPARK-24373: avoid running Analyzer rules twice on RelationalGroupedDataset") { - val groupByDataset = datasetWithUDF.groupBy() - val rollupDataset = datasetWithUDF.rollup("s") - val cubeDataset = datasetWithUDF.cube("s") - val pivotDataset = datasetWithUDF.groupBy().pivot("s", Seq(1, 2)) - datasetWithUDF.cache() - Seq(groupByDataset, rollupDataset, cubeDataset, pivotDataset).foreach { rgDS => - val df = rgDS.count() - assertContainsAnalysisBarrier(df) - assertCached(df) - } - - val flatMapGroupsInRDF = datasetWithUDF.groupBy().flatMapGroupsInR( - Array.emptyByteArray, - Array.emptyByteArray, - Array.empty, - StructType(Seq(StructField("s", LongType)))) - val flatMapGroupsInPandasDF = datasetWithUDF.groupBy().flatMapGroupsInPandas(PythonUDF( - "pyUDF", - null, - StructType(Seq(StructField("s", LongType))), - Seq.empty, - PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, - true)) - Seq(flatMapGroupsInRDF, flatMapGroupsInPandasDF).foreach { df => - assertContainsAnalysisBarrier(df, 2) - assertCached(df) - } - datasetWithUDF.unpersist(true) - } - - test("SPARK-24373: avoid running Analyzer rules twice on KeyValueGroupedDataset") { - val kvDasaset = datasetWithUDF.groupByKey(_.getLong(0)) - datasetWithUDF.cache() - val mapValuesKVDataset = kvDasaset.mapValues(_.getLong(0)).reduceGroups(_ + _) - val keysKVDataset = kvDasaset.keys - val flatMapGroupsKVDataset = kvDasaset.flatMapGroups((k, _) => Seq(k)) - val aggKVDataset = kvDasaset.count() - val otherKVDataset = spark.range(1).groupByKey(_ + 1) - val cogroupKVDataset = kvDasaset.cogroup(otherKVDataset)((k, _, _) => Seq(k)) - Seq((mapValuesKVDataset, 1), - (keysKVDataset, 2), - (flatMapGroupsKVDataset, 2), - (aggKVDataset, 1), - (cogroupKVDataset, 2)).foreach { case (df, analysisBarrierDepth) => - assertContainsAnalysisBarrier(df, analysisBarrierDepth) - assertCached(df) - } - datasetWithUDF.unpersist(true) - } -} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index a0c197b06ddab..9fe83bb332a9a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -87,7 +87,7 @@ class ResolveHiveSerdeTable(session: SparkSession) extends Rule[LogicalPlan] { } } - override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + override def apply(plan: LogicalPlan): LogicalPlan = plan 
resolveOperators { case c @ CreateTable(t, _, query) if DDLUtils.isHiveTable(t) => // Finds the database name if the name does not exist. val dbName = t.identifier.database.getOrElse(session.catalog.currentDatabase) @@ -114,7 +114,7 @@ class ResolveHiveSerdeTable(session: SparkSession) extends Rule[LogicalPlan] { } class DetermineTableStats(session: SparkSession) extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case relation: HiveTableRelation if DDLUtils.isHiveTable(relation.tableMeta) && relation.tableMeta.stats.isEmpty => val table = relation.tableMeta @@ -145,7 +145,7 @@ class DetermineTableStats(session: SparkSession) extends Rule[LogicalPlan] { * `PreprocessTableInsertion`. */ object HiveAnalysis extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case InsertIntoTable(r: HiveTableRelation, partSpec, query, overwrite, ifPartitionNotExists) if DDLUtils.isHiveTable(r.tableMeta) => InsertIntoHiveTable(r.tableMeta, partSpec, query, overwrite, @@ -225,7 +225,7 @@ case class RelationConversions( } override def apply(plan: LogicalPlan): LogicalPlan = { - plan transformUp { + plan resolveOperators { // Write path case InsertIntoTable(r: HiveTableRelation, partition, query, overwrite, ifPartitionNotExists) // Inserting into partitioned table is not supported in Parquet/Orc data source (yet). diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala index 5d56f89c2271c..a1ce1ea936bbf 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala @@ -170,21 +170,4 @@ class HiveExplainSuite extends QueryTest with SQLTestUtils with TestHiveSingleto sql("EXPLAIN EXTENDED CODEGEN SELECT 1") } } - - test("SPARK-23021 AnalysisBarrier should not cut off explain output for parsed logical plans") { - val df = Seq((1, 1)).toDF("a", "b").groupBy("a").count().limit(1) - val outputStream = new java.io.ByteArrayOutputStream() - Console.withOut(outputStream) { - df.explain(true) - } - assert(outputStream.toString.replaceAll("""#\d+""", "#0").contains( - s"""== Parsed Logical Plan == - |GlobalLimit 1 - |+- LocalLimit 1 - | +- AnalysisBarrier - | +- Aggregate [a#0], [a#0, count(1) AS count#0L] - | +- Project [_1#0 AS a#0, _2#0 AS b#0] - | +- LocalRelation [_1#0, _2#0] - |""".stripMargin)) - } }
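
The hunks above all make the same mechanical substitution: analyzer-phase rules stop calling transform/transformUp/transformDown and go through the resolveOperators family instead. As a minimal sketch of the resulting shape (the rule and its no-op body below are illustrative, not part of the patch):

    import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}
    import org.apache.spark.sql.catalyst.rules.Rule

    // Illustrative rule, not from the patch. The partial function is written exactly as
    // before; the difference is that resolveOperators skips subtrees the analyzer has
    // already resolved, which is the job AnalysisBarrier used to do by hiding them.
    object NoOpExampleRule extends Rule[LogicalPlan] {
      override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
        case p: Project => p
      }
    }

The choice between resolveOperators and resolveOperatorsDown in the hunks mirrors the original traversal direction of each rule: transformUp calls become resolveOperators, transform/transformDown calls become resolveOperatorsDown.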
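
The skipping itself is implemented by the new AnalysisHelper trait, which is introduced in an earlier part of this patch and is not visible in the hunks here. Purely as a simplified sketch, assuming LogicalPlan exposes an analyzed flag through that trait and ignoring the in-analyzer bookkeeping, the down variant behaves roughly like a transformDown that bails out on resolved subtrees:

    import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

    // Simplified sketch, not the real implementation: stop descending as soon as a
    // subtree is already marked analyzed; otherwise apply the rule and recurse.
    def resolveOperatorsDownSketch(
        plan: LogicalPlan)(rule: PartialFunction[LogicalPlan, LogicalPlan]): LogicalPlan = {
      if (plan.analyzed) {
        plan // resolved in an earlier analyzer pass; leave the whole subtree untouched
      } else {
        val afterRule = rule.applyOrElse(plan, identity[LogicalPlan])
        afterRule.mapChildren(resolveOperatorsDownSketch(_)(rule))
      }
    }

This is what lets the DataFrameWriter, KeyValueGroupedDataset and CacheManager changes above hand plain analyzed plans (df.logicalPlan, queryExecution.analyzed, planToCache) back to the analyzer without wrapping them in AnalysisBarrier first.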
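
The AnalysisHelperSuite additions at the start of this section pin down the enforcement side: once execution is inside markInAnalyzer, raw transform calls are rejected (the tests intercept a RuntimeException) unless a code path explicitly opts back in. Condensed from those tests, and assuming catalyst-internal visibility since markInAnalyzer is not a public API:

    import org.apache.spark.sql.catalyst.plans.logical.{AnalysisHelper, LocalRelation, LogicalPlan, Project}

    val plan: LogicalPlan = Project(Nil, LocalRelation())

    // Outside the analyzer, transform and friends behave exactly as before.
    plan.transformUp { case p: Project => p }

    AnalysisHelper.markInAnalyzer {
      // A raw plan.transformUp here would fail, as exercised by the
      // intercept[RuntimeException] tests above.

      // Legacy code paths can opt back in explicitly:
      AnalysisHelper.allowInvokingTransformsInAnalyzer {
        plan.transformUp { case p: Project => p }
      }

      // resolveOperators* remain the sanctioned way to rewrite plans inside the analyzer.
      plan.resolveOperators { case p: Project => p }
    }

The "allow transform in resolveOperators" test above covers the remaining case: transform calls made from inside a resolveOperators* body are permitted, so individual rules do not have to opt in manually.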
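
On the Dataset side, union no longer wraps the combined children in AnalysisBarrier but still runs CombineUnions eagerly, so chained unions keep collapsing into a single Union operator. A quick, illustrative check (spark here is any SparkSession; this snippet is not part of the patch):

    import org.apache.spark.sql.catalyst.plans.logical.Union

    val df = spark.range(3)
    val unioned = df.union(df).union(df)

    // CombineUnions runs while the Dataset is being built, so the pre-analysis logical
    // plan should already hold one Union with three children, not a nested chain.
    val unionArities = unioned.queryExecution.logical.collect {
      case u: Union => u.children.length
    }
    // expected: Seq(3)

As the inline comment in that hunk notes, building the union this way breaks caching of the intermediate Datasets, an accepted trade-off for the common pattern of unioning many files or partitions.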