From f6f2bccd6d887b50b035c492c88e9e76d1ef4754 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 19 Jul 2018 17:06:06 -0700 Subject: [PATCH 01/12] [SPARK-24865] Remove AnalysisBarrier --- .../sql/catalyst/analysis/Analyzer.scala | 110 +++++++----------- .../sql/catalyst/analysis/CheckAnalysis.scala | 6 +- .../catalyst/analysis/DecimalPrecision.scala | 2 +- .../sql/catalyst/analysis/ResolveHints.scala | 4 +- .../analysis/ResolveInlineTables.scala | 2 +- .../ResolveTableValuedFunctions.scala | 2 +- .../SubstituteUnresolvedOrdinals.scala | 2 +- .../sql/catalyst/analysis/TypeCoercion.scala | 82 ++++++------- .../catalyst/analysis/timeZoneAnalysis.scala | 2 +- .../spark/sql/catalyst/analysis/view.scala | 2 +- .../catalyst/plans/logical/LogicalPlan.scala | 91 +++++++++++++++ .../sql/catalyst/analysis/AnalysisSuite.scala | 14 --- .../sql/catalyst/plans/LogicalPlanSuite.scala | 25 ---- .../apache/spark/sql/DataFrameWriter.scala | 2 +- .../scala/org/apache/spark/sql/Dataset.scala | 7 +- .../spark/sql/KeyValueGroupedDataset.scala | 2 +- .../spark/sql/execution/CacheManager.scala | 4 +- .../datasources/DataSourceStrategy.scala | 4 +- .../sql/execution/datasources/rules.scala | 6 +- .../spark/sql/GroupedDatasetSuite.scala | 96 --------------- .../spark/sql/hive/HiveStrategies.scala | 8 +- 21 files changed, 207 insertions(+), 266 deletions(-) delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/GroupedDatasetSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 59c371eb1557b..f0939dc6ee8d1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -106,7 +106,7 @@ class Analyzer( val analyzed = execute(plan) try { checkAnalysis(analyzed) - EliminateBarriers(analyzed) + analyzed } catch { case e: AnalysisException => val ae = new AnalysisException(e.message, e.line, e.startPosition, Option(analyzed)) @@ -203,7 +203,7 @@ class Analyzer( * Analyze cte definitions and substitute child plan with analyzed cte definitions. */ object CTESubstitution extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case With(child, relations) => substituteCTE(child, relations.foldLeft(Seq.empty[(String, LogicalPlan)]) { case (resolved, (name, relation)) => @@ -213,7 +213,7 @@ class Analyzer( } def substituteCTE(plan: LogicalPlan, cteRelations: Seq[(String, LogicalPlan)]): LogicalPlan = { - plan transformDown { + plan resolveOperatorsDown { case u : UnresolvedRelation => cteRelations.find(x => resolver(x._1, u.tableIdentifier.table)) .map(_._2).getOrElse(u) @@ -231,7 +231,7 @@ class Analyzer( * Substitute child plan with WindowSpecDefinitions. */ object WindowsSubstitution extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { // Lookup WindowSpecDefinitions. This rule works with unresolved children. 
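// Each WindowSpecReference (e.g. `OVER w`) is replaced by the matching definition from the `WINDOW w AS (...)` clause.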
case WithWindowDefinition(windowDefinitions, child) => child.transform { @@ -271,7 +271,7 @@ class Analyzer( private def hasUnresolvedAlias(exprs: Seq[NamedExpression]) = exprs.exists(_.find(_.isInstanceOf[UnresolvedAlias]).isDefined) - def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case Aggregate(groups, aggs, child) if child.resolved && hasUnresolvedAlias(aggs) => Aggregate(groups, assignAliases(aggs), child) @@ -473,7 +473,7 @@ class Analyzer( } // This require transformUp to replace grouping()/grouping_id() in resolved Filter/Sort - def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case a if !a.childrenResolved => a // be sure all of the children are resolved. // Ensure group by expressions and aggregate expressions have been resolved. @@ -506,7 +506,7 @@ class Analyzer( } object ResolvePivot extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transform { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsDown { case p: Pivot if !p.childrenResolved || !p.aggregates.forall(_.resolved) || (p.groupByExprsOpt.isDefined && !p.groupByExprsOpt.get.forall(_.resolved)) || !p.pivotColumn.resolved || !p.pivotValues.forall(_.resolved) => p @@ -676,7 +676,7 @@ class Analyzer( case _ => plan } - def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case i @ InsertIntoTable(u: UnresolvedRelation, parts, child, _, _) if child.resolved => EliminateSubqueryAliases(lookupTableFromCatalog(u)) match { case v: View => @@ -737,12 +737,6 @@ class Analyzer( s"between $left and $right") right.collect { - // For `AnalysisBarrier`, recursively de-duplicate its child. - case oldVersion: AnalysisBarrier - if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty => - val newVersion = dedupRight(left, oldVersion.child) - (oldVersion, AnalysisBarrier(newVersion)) - // Handle base relations that might appear more than once. case oldVersion: MultiInstanceRelation if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty => @@ -787,9 +781,9 @@ class Analyzer( right case Some((oldRelation, newRelation)) => val attributeRewrites = AttributeMap(oldRelation.output.zip(newRelation.output)) - right transformUp { + right resolveOperators { case r if r == oldRelation => newRelation - } transformUp { + } resolveOperators { case other => other transformExpressions { case a: Attribute => dedupAttr(a, attributeRewrites) @@ -847,7 +841,7 @@ class Analyzer( private def dedupOuterReferencesInSubquery( plan: LogicalPlan, attrMap: AttributeMap[Attribute]): LogicalPlan = { - plan transformDown { case currentFragment => + plan resolveOperatorsDown { case currentFragment => currentFragment transformExpressions { case OuterReference(a: Attribute) => OuterReference(dedupAttr(a, attrMap)) @@ -873,7 +867,7 @@ class Analyzer( case _ => e.mapChildren(resolve(_, q)) } - def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case p: LogicalPlan if !p.childrenResolved => p // If the projection list contains Stars, expand it. @@ -1068,7 +1062,7 @@ class Analyzer( * have no effect on the results. 
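* For example, with `spark.sql.orderByOrdinal` enabled, `SELECT a, b FROM t ORDER BY 1` sorts by the first column of the select list, i.e. `a`.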
*/ object ResolveOrdinalInOrderByAndGroupBy extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case p if !p.childrenResolved => p // Replace the index with the related attribute for ORDER BY, // which is a 1-base position of the projection list. @@ -1124,7 +1118,7 @@ class Analyzer( }} } - override def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case agg @ Aggregate(groups, aggs, child) if conf.groupByAliases && child.resolved && aggs.forall(_.resolved) && groups.exists(!_.resolved) => @@ -1148,9 +1142,8 @@ class Analyzer( * The HAVING clause could also used a grouping columns that is not presented in the SELECT. */ object ResolveMissingReferences extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { // Skip sort with aggregate. This will be handled in ResolveAggregateFunctions - case sa @ Sort(_, _, AnalysisBarrier(child: Aggregate)) => sa case sa @ Sort(_, _, child: Aggregate) => sa case s @ Sort(order, _, child) @@ -1190,12 +1183,6 @@ class Analyzer( (exprs, plan) } else { plan match { - // For `AnalysisBarrier`, recursively resolve expressions and add missing attributes via - // its child. - case barrier: AnalysisBarrier => - val (newExprs, newChild) = resolveExprsAndAddMissingAttrs(exprs, barrier.child) - (newExprs, AnalysisBarrier(newChild)) - case p: Project => // Resolving expressions against current plan. val maybeResolvedExprs = exprs.map(resolveExpression(_, p)) @@ -1252,7 +1239,7 @@ class Analyzer( object LookupFunctions extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = { val externalFunctionNameSet = new mutable.HashSet[FunctionIdentifier]() - plan.transformAllExpressions { + plan.resolveExpressions { case f: UnresolvedFunction if externalFunctionNameSet.contains(normalizeFuncName(f.name)) => f case f: UnresolvedFunction if catalog.isRegisteredFunction(f.name) => f @@ -1291,7 +1278,7 @@ class Analyzer( * Replaces [[UnresolvedFunction]]s with concrete [[Expression]]s. */ object ResolveFunctions extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case q: LogicalPlan => q transformExpressions { case u if !u.childrenResolved => u // Skip until children are resolved. @@ -1346,7 +1333,7 @@ class Analyzer( * resolved outer references are wrapped in an [[OuterReference]] */ private def resolveOuterReferences(plan: LogicalPlan, outer: LogicalPlan): LogicalPlan = { - plan transformDown { + plan resolveOperatorsDown { case q: LogicalPlan if q.childrenResolved && !q.resolved => q transformExpressions { case u @ UnresolvedAttribute(nameParts) => @@ -1428,7 +1415,7 @@ class Analyzer( /** * Resolve and rewrite all subqueries in an operator tree.. */ - def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { // In case of HAVING (a filter after an aggregate) we use both the aggregate and // its child for resolution. 
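// (outer references in a correlated subquery under HAVING may point to attributes of the aggregate's child rather than to the aggregate's output.)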
case f @ Filter(_, a: Aggregate) if f.childrenResolved => @@ -1444,7 +1431,7 @@ class Analyzer( */ object ResolveSubqueryColumnAliases extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case u @ UnresolvedSubqueryColumnAliases(columnNames, child) if child.resolved => // Resolves output attributes if a query has alias names in its subquery: // e.g., SELECT * FROM (SELECT 1 AS a, 1 AS b) t(col1, col2) @@ -1467,7 +1454,7 @@ class Analyzer( * Turns projections that contain aggregate expressions into aggregations. */ object GlobalAggregates extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case Project(projectList, child) if containsAggregates(projectList) => Aggregate(Nil, projectList, child) } @@ -1493,9 +1480,7 @@ class Analyzer( * underlying aggregate operator and then projected away after the original operator. */ object ResolveAggregateFunctions extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { - case Filter(cond, AnalysisBarrier(agg: Aggregate)) => - apply(Filter(cond, agg)).mapChildren(AnalysisBarrier) + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case f @ Filter(cond, agg @ Aggregate(grouping, originalAggExprs, child)) if agg.resolved => // Try resolving the condition of the filter as though it is in the aggregate clause @@ -1553,8 +1538,6 @@ class Analyzer( case ae: AnalysisException => f } - case Sort(sortOrder, global, AnalysisBarrier(aggregate: Aggregate)) => - apply(Sort(sortOrder, global, aggregate)).mapChildren(AnalysisBarrier) case sort @ Sort(sortOrder, global, aggregate: Aggregate) if aggregate.resolved => // Try resolving the ordering as though it is in the aggregate clause. @@ -1672,7 +1655,7 @@ class Analyzer( } } - def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case Project(projectList, _) if projectList.exists(hasNestedGenerator) => val nestedGenerator = projectList.find(hasNestedGenerator).get throw new AnalysisException("Generators are not supported when it's nested in " + @@ -1730,7 +1713,7 @@ class Analyzer( * that wrap the [[Generator]]. */ object ResolveGenerate extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case g: Generate if !g.child.resolved || !g.generator.resolved => g case g: Generate if !g.resolved => g.copy(generatorOutput = makeGeneratorOutput(g.generator, g.generatorOutput.map(_.name))) @@ -1771,7 +1754,7 @@ class Analyzer( */ object FixNullability extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case p if !p.resolved => p // Skip unresolved nodes. case p: LogicalPlan if p.resolved => val childrenOutput = p.children.flatMap(c => c.output).groupBy(_.exprId).flatMap { @@ -1995,7 +1978,7 @@ class Analyzer( // We have to use transformDown at here to make sure the rule of // "Aggregate with Having clause" will be triggered. 
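// (a top-down traversal matches the Filter above an Aggregate before the Aggregate itself is rewritten.)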
- def apply(plan: LogicalPlan): LogicalPlan = plan transformDown { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsDown { case Filter(condition, _) if hasWindowFunction(condition) => failAnalysis("It is not allowed to use window functions inside WHERE and HAVING clauses") @@ -2055,7 +2038,7 @@ class Analyzer( * put them into an inner Project and finally project them away at the outer Project. */ object PullOutNondeterministic extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case p if !p.resolved => p // Skip unresolved nodes. case p: Project => p case f: Filter => f @@ -2099,7 +2082,7 @@ class Analyzer( object ResolvedUuidExpressions extends Rule[LogicalPlan] { private lazy val random = new Random() - override def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case p if p.resolved => p case p => p transformExpressionsUp { case Uuid(None) => Uuid(Some(random.nextLong())) @@ -2114,7 +2097,7 @@ class Analyzer( * and we should return null if the input is null. */ object HandleNullInputsForUDF extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case p if !p.resolved => p // Skip unresolved nodes. case p => p transformExpressionsUp { @@ -2139,7 +2122,7 @@ class Analyzer( * Check and add proper window frames for all window functions. */ object ResolveWindowFrame extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transform { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsDown { case logical: LogicalPlan => logical transformExpressions { case WindowExpression(wf: WindowFunction, WindowSpecDefinition(_, _, f: SpecifiedWindowFrame)) @@ -2165,7 +2148,7 @@ class Analyzer( * Check and add order to [[AggregateWindowFunction]]s. */ object ResolveWindowOrder extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transform { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsDown { case logical: LogicalPlan => logical transformExpressions { case WindowExpression(wf: WindowFunction, spec) if spec.orderSpec.isEmpty => failAnalysis(s"Window function $wf requires window to be ordered, please add ORDER BY " + @@ -2183,7 +2166,7 @@ class Analyzer( * Then apply a Project on a normal Join to eliminate natural or using join. */ object ResolveNaturalAndUsingJoin extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case j @ Join(left, right, UsingJoin(joinType, usingCols), condition) if left.resolved && right.resolved && j.duplicateResolved => commonNaturalJoinProcessing(left, right, joinType, usingCols, None) @@ -2248,7 +2231,7 @@ class Analyzer( * to the given input attributes. */ object ResolveDeserializer extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case p if !p.childrenResolved => p case p if p.resolved => p @@ -2334,7 +2317,7 @@ class Analyzer( * constructed is an inner class. 
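* (instantiating an inner class on the JVM requires a reference to an instance of its outer class.)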
*/ object ResolveNewInstance extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case p if !p.childrenResolved => p case p if p.resolved => p @@ -2368,7 +2351,7 @@ class Analyzer( "type of the field in the target object") } - def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case p if !p.childrenResolved => p case p if p.resolved => p @@ -2390,6 +2373,8 @@ class Analyzer( * scoping information for attributes and can be removed once analysis is complete. */ object EliminateSubqueryAliases extends Rule[LogicalPlan] { + // This is actually called in the beginning of the optimization phase, and as a result + // is using transformUp rather than resolveOperators. def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case SubqueryAlias(_, child) => child } @@ -2399,7 +2384,7 @@ object EliminateSubqueryAliases extends Rule[LogicalPlan] { * Removes [[Union]] operators from the plan if it just has one child. */ object EliminateUnions extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transform { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsDown { case Union(children) if children.size == 1 => children.head } } @@ -2427,7 +2412,7 @@ object CleanupAliases extends Rule[LogicalPlan] { case other => trimAliases(other) } - override def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case Project(projectList, child) => val cleanedProjectList = projectList.map(trimNonTopLevelAliases(_).asInstanceOf[NamedExpression]) @@ -2456,19 +2441,12 @@ object CleanupAliases extends Rule[LogicalPlan] { } } -/** Remove the barrier nodes of analysis */ -object EliminateBarriers extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown { - case AnalysisBarrier(child) => child - } -} - /** * Ignore event time watermark in batch query, which is only supported in Structured Streaming. * TODO: add this rule into analyzer rule list. */ object EliminateEventTimeWatermark extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan transform { + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsDown { case EventTimeWatermark(_, _, child) if !child.isStreaming => child } } @@ -2513,7 +2491,7 @@ object TimeWindowing extends Rule[LogicalPlan] { * @return the logical plan that will generate the time windows using the Expand operator, with * the Filter operator for correctness and Project for usability. */ - def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case p: LogicalPlan if p.children.size == 1 => val child = p.children.head val windowExpressions = @@ -2601,7 +2579,7 @@ object TimeWindowing extends Rule[LogicalPlan] { * Resolve a [[CreateNamedStruct]] if it contains [[NamePlaceholder]]s. 
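* e.g. in `struct(a, b)` the field names are unknown until `a` and `b` are resolved; this rule then replaces each [[NamePlaceholder]] with the resolved expression's name.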
*/ object ResolveCreateNamedStruct extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan.transformAllExpressions { + override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveExpressions { case e: CreateNamedStruct if !e.resolved => val children = e.children.grouped(2).flatMap { case Seq(NamePlaceholder, e: NamedExpression) if e.resolved => @@ -2653,7 +2631,7 @@ object UpdateOuterReferences extends Rule[LogicalPlan] { private def updateOuterReferenceInSubquery( plan: LogicalPlan, refExprs: Seq[Expression]): LogicalPlan = { - plan transformAllExpressions { case e => + plan resolveExpressions { case e => val outerAlias = refExprs.find(stripAlias(_).semanticEquals(stripOuterReference(e))) outerAlias match { @@ -2664,7 +2642,7 @@ object UpdateOuterReferences extends Rule[LogicalPlan] { } def apply(plan: LogicalPlan): LogicalPlan = { - plan transform { + plan resolveOperatorsDown { case f @ Filter(_, a: Aggregate) if f.resolved => f transformExpressions { case s: SubqueryExpression if s.children.nonEmpty => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index af256b98b34f3..854e6dd44990b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -79,6 +79,9 @@ trait CheckAnalysis extends PredicateHelper { // We transform up and order the rules so as to catch the first possible failure instead // of the result of cascading resolution failures. plan.foreachUp { + + case p if p.analyzed => // Skip already analyzed sub-plans + case u: UnresolvedRelation => u.failAnalysis(s"Table or view not found: ${u.tableIdentifier}") @@ -364,10 +367,11 @@ trait CheckAnalysis extends PredicateHelper { } extendedCheckRules.foreach(_(plan)) plan.foreachUp { - case AnalysisBarrier(child) if !child.resolved => checkAnalysis(child) case o if !o.resolved => failAnalysis(s"unresolved operator ${o.simpleString}") case _ => } + + plan.foreach(_.setAnalyzed()) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala index ab63131b07573..65a5888222f2e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala @@ -82,7 +82,7 @@ object DecimalPrecision extends TypeCoercionRule { PromotePrecision(Cast(e, dataType)) } - override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan transformUp { + override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan resolveOperators { // fix decimal precision for expressions case q => q.transformExpressionsUp( decimalAndDecimal.orElse(integralAndDecimalLiteral).orElse(nondecimalAndDecimal)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala index f068bce3e9b69..bfe5169c25900 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala @@ -85,7 +85,7 @@ object ResolveHints { } } - def apply(plan: LogicalPlan): 
LogicalPlan = plan transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case h: UnresolvedHint if BROADCAST_HINT_NAMES.contains(h.name.toUpperCase(Locale.ROOT)) => if (h.parameters.isEmpty) { // If there is no table alias specified, turn the entire subtree into a BroadcastHint. @@ -107,7 +107,7 @@ object ResolveHints { * This must be executed after all the other hint rules are executed. */ object RemoveAllHints extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case h: UnresolvedHint => h.child } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala index 71ed75454cd4d..4edfe507a7580 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.types.{StructField, StructType} * An analyzer rule that replaces [[UnresolvedInlineTable]] with [[LocalRelation]]. */ case class ResolveInlineTables(conf: SQLConf) extends Rule[LogicalPlan] with CastSupport { - override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case table: UnresolvedInlineTable if table.expressionsResolved => validateInputDimension(table) validateInputEvaluable(table) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala index a214e59302cd9..7358f9ee36921 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala @@ -103,7 +103,7 @@ object ResolveTableValuedFunctions extends Rule[LogicalPlan] { }) ) - override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case u: UnresolvedTableValuedFunction if u.functionArgs.forall(_.resolved) => val resolvedFunc = builtinFunctions.get(u.functionName.toLowerCase(Locale.ROOT)) match { case Some(tvf) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinals.scala index f9fd0df9e4010..860d20f897690 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinals.scala @@ -33,7 +33,7 @@ class SubstituteUnresolvedOrdinals(conf: SQLConf) extends Rule[LogicalPlan] { case _ => false } - def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case s: Sort if conf.orderByOrdinal && s.order.exists(o => isIntLiteral(o.child)) => val newOrders = s.order.map { case order @ SortOrder(ordinal @ Literal(index: Int, IntegerType), _, _, _) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 316aebdeaffa1..c3607bcc92309 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -318,7 +318,7 @@ object TypeCoercion { */ object WidenSetOperationTypes extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case s @ SetOperation(left, right) if s.childrenResolved && left.output.length == right.output.length && !s.resolved => val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(left :: right :: Nil) @@ -391,7 +391,7 @@ object TypeCoercion { } override protected def coerceTypes( - plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + plan: LogicalPlan): LogicalPlan = plan resolveExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e @@ -453,7 +453,7 @@ object TypeCoercion { } override protected def coerceTypes( - plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + plan: LogicalPlan): LogicalPlan = plan resolveExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e @@ -512,7 +512,7 @@ object TypeCoercion { private val trueValues = Seq(1.toByte, 1.toShort, 1, 1L, Decimal.ONE) private val falseValues = Seq(0.toByte, 0.toShort, 0, 0L, Decimal.ZERO) - def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e @@ -555,7 +555,7 @@ object TypeCoercion { object FunctionArgumentConversion extends TypeCoercionRule { override protected def coerceTypes( - plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + plan: LogicalPlan): LogicalPlan = plan resolveExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e @@ -670,7 +670,7 @@ object TypeCoercion { */ object Division extends TypeCoercionRule { override protected def coerceTypes( - plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + plan: LogicalPlan): LogicalPlan = plan resolveExpressions { // Skip nodes who has not been resolved yet, // as this is an extra rule which should be applied at last. case e if !e.childrenResolved => e @@ -693,7 +693,7 @@ object TypeCoercion { */ object CaseWhenCoercion extends TypeCoercionRule { override protected def coerceTypes( - plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + plan: LogicalPlan): LogicalPlan = plan resolveExpressions { case c: CaseWhen if c.childrenResolved && !haveSameType(c.inputTypesForMerging) => val maybeCommonType = findWiderCommonType(c.inputTypesForMerging) maybeCommonType.map { commonType => @@ -711,7 +711,7 @@ object TypeCoercion { */ object IfCoercion extends TypeCoercionRule { override protected def coerceTypes( - plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + plan: LogicalPlan): LogicalPlan = plan resolveExpressions { case e if !e.childrenResolved => e // Find tightest common type for If, if the true value and false value have different types. 
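// e.g. `If(cond, 1, 2L)` is coerced to `If(cond, cast(1 as bigint), 2L)`.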
case i @ If(pred, left, right) if !haveSameType(i.inputTypesForMerging) => @@ -731,7 +731,7 @@ object TypeCoercion { * Coerces NullTypes in the Stack expression to the column types of the corresponding positions. */ object StackCoercion extends TypeCoercionRule { - override def coerceTypes(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + override def coerceTypes(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { case s @ Stack(children) if s.childrenResolved && s.hasFoldableNumRows => Stack(children.zipWithIndex.map { // The first child is the number of rows for stack. @@ -751,17 +751,18 @@ object TypeCoercion { */ case class ConcatCoercion(conf: SQLConf) extends TypeCoercionRule { - override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan transform { case p => - p transformExpressionsUp { - // Skip nodes if unresolved or empty children - case c @ Concat(children) if !c.childrenResolved || children.isEmpty => c - case c @ Concat(children) if conf.concatBinaryAsString || - !children.map(_.dataType).forall(_ == BinaryType) => - val newChildren = c.children.map { e => - ImplicitTypeCasts.implicitCast(e, StringType).getOrElse(e) - } - c.copy(children = newChildren) - } + override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsDown { + case p => + p transformExpressionsUp { + // Skip nodes if unresolved or empty children + case c @ Concat(children) if !c.childrenResolved || children.isEmpty => c + case c @ Concat(children) if conf.concatBinaryAsString || + !children.map(_.dataType).forall(_ == BinaryType) => + val newChildren = c.children.map { e => + ImplicitTypeCasts.implicitCast(e, StringType).getOrElse(e) + } + c.copy(children = newChildren) + } } } @@ -773,23 +774,24 @@ object TypeCoercion { */ case class EltCoercion(conf: SQLConf) extends TypeCoercionRule { - override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan transform { case p => - p transformExpressionsUp { - // Skip nodes if unresolved or not enough children - case c @ Elt(children) if !c.childrenResolved || children.size < 2 => c - case c @ Elt(children) => - val index = children.head - val newIndex = ImplicitTypeCasts.implicitCast(index, IntegerType).getOrElse(index) - val newInputs = if (conf.eltOutputAsString || - !children.tail.map(_.dataType).forall(_ == BinaryType)) { - children.tail.map { e => - ImplicitTypeCasts.implicitCast(e, StringType).getOrElse(e) + override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsDown { + case p => + p transformExpressionsUp { + // Skip nodes if unresolved or not enough children + case c @ Elt(children) if !c.childrenResolved || children.size < 2 => c + case c @ Elt(children) => + val index = children.head + val newIndex = ImplicitTypeCasts.implicitCast(index, IntegerType).getOrElse(index) + val newInputs = if (conf.eltOutputAsString || + !children.tail.map(_.dataType).forall(_ == BinaryType)) { + children.tail.map { e => + ImplicitTypeCasts.implicitCast(e, StringType).getOrElse(e) + } + } else { + children.tail } - } else { - children.tail - } - c.copy(children = newIndex +: newInputs) - } + c.copy(children = newIndex +: newInputs) + } } } @@ -801,7 +803,7 @@ object TypeCoercion { private val acceptedTypes = Seq(DateType, TimestampType, StringType) - def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { // Skip nodes who's children have not been resolved yet. 
case e if !e.childrenResolved => e @@ -822,7 +824,7 @@ object TypeCoercion { private def rejectTzInString = conf.getConf(SQLConf.REJECT_TIMEZONE_IN_STRING) override protected def coerceTypes( - plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + plan: LogicalPlan): LogicalPlan = plan resolveExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e @@ -961,7 +963,7 @@ object TypeCoercion { */ object WindowFrameCoercion extends TypeCoercionRule { override protected def coerceTypes( - plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + plan: LogicalPlan): LogicalPlan = plan resolveExpressions { case s @ WindowSpecDefinition(_, Seq(order), SpecifiedWindowFrame(RangeFrame, lower, upper)) if order.resolved => s.copy(frameSpecification = SpecifiedWindowFrame( @@ -999,7 +1001,7 @@ trait TypeCoercionRule extends Rule[LogicalPlan] with Logging { protected def coerceTypes(plan: LogicalPlan): LogicalPlan - private def propagateTypes(plan: LogicalPlan): LogicalPlan = plan transformUp { + private def propagateTypes(plan: LogicalPlan): LogicalPlan = plan resolveOperators { // No propagation required for leaf nodes. case q: LogicalPlan if q.children.isEmpty => q diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/timeZoneAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/timeZoneAnalysis.scala index af1f9165b0044..a27aa845bf0ae 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/timeZoneAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/timeZoneAnalysis.scala @@ -38,7 +38,7 @@ case class ResolveTimeZone(conf: SQLConf) extends Rule[LogicalPlan] { } override def apply(plan: LogicalPlan): LogicalPlan = - plan.transformAllExpressions(transformTimeZoneExprs) + plan.resolveExpressions(transformTimeZoneExprs) def resolveTimeZones(e: Expression): Expression = e.transform(transformTimeZoneExprs) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala index 20216087b0158..7dc3470bc5ec5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala @@ -48,7 +48,7 @@ import org.apache.spark.sql.internal.SQLConf * completely resolved during the batch of Resolution. 
*/ case class AliasViewChild(conf: SQLConf) extends Rule[LogicalPlan] with CastSupport { - override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case v @ View(desc, output, child) if child.resolved && output != child.output => val resolver = conf.resolver val queryColumnNames = desc.viewQueryColumnNames diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index c486ad700f362..a64a50cfb57b3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.LogicalPlanStats import org.apache.spark.sql.catalyst.trees.CurrentOrigin import org.apache.spark.sql.types.StructType +import org.apache.spark.util.Utils abstract class LogicalPlan @@ -33,6 +34,96 @@ abstract class LogicalPlan with QueryPlanConstraints with Logging { + private var _analyzed: Boolean = false + + /** + * Marks this plan as already analyzed. This should only be called by [[CheckAnalysis]]. + */ + private[catalyst] def setAnalyzed(): Unit = { _analyzed = true } + + /** + * Returns true if this node and its children have already been gone through analysis and + * verification. Note that this is only an optimization used to avoid analyzing trees that + * have already been analyzed, and can be reset by transformations. + */ + def analyzed: Boolean = _analyzed + + /** + * Returns a copy of this node where `rule` has been recursively applied first to all of its + * children and then itself (post-order, bottom-up). When `rule` does not apply to a given node, + * it is left unchanged. This function is similar to `transformUp`, but skips sub-trees that + * have already been marked as analyzed. + * + * @param rule the function use to transform this nodes children + */ + def resolveOperators(rule: PartialFunction[LogicalPlan, LogicalPlan]): LogicalPlan = { + if (!analyzed) { + val afterRuleOnChildren = mapChildren(_.resolveOperators(rule)) + if (this fastEquals afterRuleOnChildren) { + CurrentOrigin.withOrigin(origin) { + rule.applyOrElse(this, identity[LogicalPlan]) + } + } else { + CurrentOrigin.withOrigin(origin) { + rule.applyOrElse(afterRuleOnChildren, identity[LogicalPlan]) + } + } + } else { + this + } + } + + /** Similar to [[resolveOperators]], but does it top-down. */ + def resolveOperatorsDown(rule: PartialFunction[LogicalPlan, LogicalPlan]): LogicalPlan = { + if (!analyzed) { + val afterRule = CurrentOrigin.withOrigin(origin) { + rule.applyOrElse(this, identity[LogicalPlan]) + } + + // Check if unchanged and then possibly return old copy to avoid gc churn. + if (this fastEquals afterRule) { + mapChildren(_.resolveOperatorsDown(rule)) + } else { + afterRule.mapChildren(_.resolveOperatorsDown(rule)) + } + } else { + this + } + } + + /** + * Recursively transforms the expressions of a tree, skipping nodes that have already + * been analyzed. 
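+ * A minimal usage sketch (`lookup` is a hypothetical resolver, shown only to illustrate
+ * that sub-plans already marked as analyzed are skipped):
+ * {{{
+ *   plan.resolveExpressions {
+ *     case u: UnresolvedAttribute => lookup(u)
+ *   }
+ * }}}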
+ */ + def resolveExpressions(r: PartialFunction[Expression, Expression]): LogicalPlan = { + this resolveOperators { + case p => p.transformExpressions(r) + } + } + + protected def assertNotAnalysisRule(): Unit = { + if (Utils.isTesting && getClass.getName.contains("Logical")) { + if (Thread.currentThread.getStackTrace.exists(_.getClassName.contains("Analyzer"))) { + sys.error("This method should not be called in the analyzer") + } + } + } + + override def transformDown(rule: PartialFunction[LogicalPlan, LogicalPlan]): LogicalPlan = { + assertNotAnalysisRule() + super.transformDown(rule) + } + + override def transformUp(rule: PartialFunction[LogicalPlan, LogicalPlan]): LogicalPlan = { + assertNotAnalysisRule() + super.transformUp(rule) + } + + override def transformAllExpressions(rule: PartialFunction[Expression, Expression]): this.type = { + assertNotAnalysisRule() + super.transformAllExpressions(rule) + } + /** Returns true if this subtree has data from a streaming data source. */ def isStreaming: Boolean = children.exists(_.isStreaming == true) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index bbcdf6c1b8481..2f9d05effc33e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -545,20 +545,6 @@ class AnalysisSuite extends AnalysisTest with Matchers { } } - test("SPARK-20392: analysis barrier") { - // [[AnalysisBarrier]] will be removed after analysis - checkAnalysis( - Project(Seq(UnresolvedAttribute("tbl.a")), - AnalysisBarrier(SubqueryAlias("tbl", testRelation))), - Project(testRelation.output, SubqueryAlias("tbl", testRelation))) - - // Verify we won't go through a plan wrapped in a barrier. - // Since we wrap an unresolved plan and analyzer won't go through it. It remains unresolved. 
- val barrier = AnalysisBarrier(Project(Seq(UnresolvedAttribute("tbl.b")), - SubqueryAlias("tbl", testRelation))) - assertAnalysisError(barrier, Seq("cannot resolve '`tbl.b`'")) - } - test("SPARK-24208: analysis fails on self-join with FlatMapGroupsInPandas") { val pythonUdf = PythonUDF("pyUDF", null, StructType(Seq(StructField("a", LongType))), diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala index bf569cb869428..a3b9d52ddf719 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala @@ -60,31 +60,6 @@ class LogicalPlanSuite extends SparkFunSuite { assert(invocationCount === 2) } - test("transformUp skips all ready resolved plans wrapped in analysis barrier") { - invocationCount = 0 - val plan = AnalysisBarrier(Project(Nil, Project(Nil, testRelation))) - plan transformUp function - - assert(invocationCount === 0) - - invocationCount = 0 - plan transformDown function - assert(invocationCount === 0) - } - - test("transformUp skips partially resolved plans wrapped in analysis barrier") { - invocationCount = 0 - val plan1 = AnalysisBarrier(Project(Nil, testRelation)) - val plan2 = Project(Nil, plan1) - plan2 transformUp function - - assert(invocationCount === 1) - - invocationCount = 0 - plan2 transformDown function - assert(invocationCount === 1) - } - test("isStreaming") { val relation = LocalRelation(AttributeReference("a", IntegerType, nullable = true)()) val incrementalRelation = LocalRelation( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 90bea2d676e22..6e8584ced30e7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -275,7 +275,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { sparkSession = df.sparkSession, className = source, partitionColumns = partitioningColumns.getOrElse(Nil), - options = extraOptions.toMap).planForWriting(mode, AnalysisBarrier(df.logicalPlan)) + options = extraOptions.toMap).planForWriting(mode, df.logicalPlan) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index c97246f30220d..b63235ec827c9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -196,7 +196,8 @@ class Dataset[T] private[sql]( } // Wraps analyzed logical plans with an analysis barrier so we won't traverse/resolve it again. - @transient private[sql] val planWithBarrier = AnalysisBarrier(logicalPlan) + // TODO(rxin): remove this later. + @transient private[sql] val planWithBarrier = logicalPlan /** * Currently [[ExpressionEncoder]] is the only implementation of [[Encoder]], here we turn the @@ -1857,7 +1858,7 @@ class Dataset[T] private[sql]( def union(other: Dataset[T]): Dataset[T] = withSetOperator { // This breaks caching, but it's usually ok because it addresses a very specific use case: // using union to union many files or partitions. 
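// (CombineUnions flattens adjacent Unions eagerly, so chaining many unions does not build a deeply nested plan.)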
- CombineUnions(Union(logicalPlan, other.logicalPlan)).mapChildren(AnalysisBarrier) + CombineUnions(Union(logicalPlan, other.logicalPlan)) } /** @@ -1916,7 +1917,7 @@ class Dataset[T] private[sql]( // This breaks caching, but it's usually ok because it addresses a very specific use case: // using union to union many files or partitions. - CombineUnions(Union(logicalPlan, rightChild)).mapChildren(AnalysisBarrier) + CombineUnions(Union(logicalPlan, rightChild)) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index 36f6038aa9485..6bab21dca0cbd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -49,7 +49,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( private implicit val kExprEnc = encoderFor(kEncoder) private implicit val vExprEnc = encoderFor(vEncoder) - private def logicalPlan = AnalysisBarrier(queryExecution.analyzed) + private def logicalPlan = queryExecution.analyzed private def sparkSession = queryExecution.sparkSession /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala index 39d9a95ca4710..ed130dc57ee5b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala @@ -97,7 +97,7 @@ class CacheManager extends Logging { val inMemoryRelation = InMemoryRelation( sparkSession.sessionState.conf.useCompression, sparkSession.sessionState.conf.columnBatchSize, storageLevel, - sparkSession.sessionState.executePlan(AnalysisBarrier(planToCache)).executedPlan, + sparkSession.sessionState.executePlan(planToCache).executedPlan, tableName, planToCache) cachedData.add(CachedData(planToCache, inMemoryRelation)) @@ -173,7 +173,7 @@ class CacheManager extends Logging { // Remove the cache entry before we create a new one, so that we can have a different // physical plan. 
it.remove() - val plan = spark.sessionState.executePlan(AnalysisBarrier(cd.plan)).executedPlan + val plan = spark.sessionState.executePlan(cd.plan).executedPlan val newCache = InMemoryRelation( cacheBuilder = cd.cachedRepresentation.cacheBuilder.withCachedPlan(plan), logicalPlan = cd.plan) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 7b129435c45db..e1b049b6ceaba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -131,7 +131,7 @@ case class DataSourceAnalysis(conf: SQLConf) extends Rule[LogicalPlan] with Cast projectList } - override def apply(plan: LogicalPlan): LogicalPlan = plan transform { + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsDown { case CreateTable(tableDesc, mode, None) if DDLUtils.isDatasourceTable(tableDesc) => DDLUtils.checkDataColNames(tableDesc) CreateDataSourceTableCommand(tableDesc, ignoreIfExists = mode == SaveMode.Ignore) @@ -252,7 +252,7 @@ class FindDataSourceTable(sparkSession: SparkSession) extends Rule[LogicalPlan] table.partitionSchema.asNullable.toAttributes) } - override def apply(plan: LogicalPlan): LogicalPlan = plan transform { + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsDown { case i @ InsertIntoTable(UnresolvedCatalogRelation(tableMeta), _, _, _, _) if DDLUtils.isDatasourceTable(tableMeta) => i.copy(table = readDataSourceTable(tableMeta)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index cab00251622b8..b5ce6427f82e5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -39,7 +39,7 @@ class ResolveSQLOnFile(sparkSession: SparkSession) extends Rule[LogicalPlan] { sparkSession.sessionState.conf.runSQLonFile && u.tableIdentifier.database.isDefined } - def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case u: UnresolvedRelation if maybeSQLFile(u) => try { val dataSource = DataSource( @@ -73,7 +73,7 @@ case class PreprocessTableCreation(sparkSession: SparkSession) extends Rule[Logi // catalog is a def and not a val/lazy val as the latter would introduce a circular reference private def catalog = sparkSession.sessionState.catalog - def apply(plan: LogicalPlan): LogicalPlan = plan transform { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsDown { // When we CREATE TABLE without specifying the table schema, we should fail the query if // bucketing information is specified, as we can't infer bucketing from data files currently. 
// Since the runtime inferred partition columns could be different from what user specified, @@ -365,7 +365,7 @@ case class PreprocessTableInsertion(conf: SQLConf) extends Rule[LogicalPlan] { } } - def apply(plan: LogicalPlan): LogicalPlan = plan transform { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsDown { case i @ InsertIntoTable(table, _, query, _, _) if table.resolved && query.resolved => table match { case relation: HiveTableRelation => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/GroupedDatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/GroupedDatasetSuite.scala deleted file mode 100644 index 147c0b61f5017..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/GroupedDatasetSuite.scala +++ /dev/null @@ -1,96 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql - -import org.apache.spark.api.python.PythonEvalType -import org.apache.spark.sql.catalyst.expressions.PythonUDF -import org.apache.spark.sql.catalyst.plans.logical.AnalysisBarrier -import org.apache.spark.sql.functions.udf -import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types.{LongType, StructField, StructType} - -class GroupedDatasetSuite extends QueryTest with SharedSQLContext { - import testImplicits._ - - private val scalaUDF = udf((x: Long) => { x + 1 }) - private lazy val datasetWithUDF = spark.range(1).toDF("s").select($"s", scalaUDF($"s")) - - private def assertContainsAnalysisBarrier(ds: Dataset[_], atLevel: Int = 1): Unit = { - assert(atLevel >= 0) - var children = Seq(ds.queryExecution.logical) - (1 to atLevel).foreach { _ => - children = children.flatMap(_.children) - } - val barriers = children.collect { - case ab: AnalysisBarrier => ab - } - assert(barriers.nonEmpty, s"Plan does not contain AnalysisBarrier at level $atLevel:\n" + - ds.queryExecution.logical) - } - - test("SPARK-24373: avoid running Analyzer rules twice on RelationalGroupedDataset") { - val groupByDataset = datasetWithUDF.groupBy() - val rollupDataset = datasetWithUDF.rollup("s") - val cubeDataset = datasetWithUDF.cube("s") - val pivotDataset = datasetWithUDF.groupBy().pivot("s", Seq(1, 2)) - datasetWithUDF.cache() - Seq(groupByDataset, rollupDataset, cubeDataset, pivotDataset).foreach { rgDS => - val df = rgDS.count() - assertContainsAnalysisBarrier(df) - assertCached(df) - } - - val flatMapGroupsInRDF = datasetWithUDF.groupBy().flatMapGroupsInR( - Array.emptyByteArray, - Array.emptyByteArray, - Array.empty, - StructType(Seq(StructField("s", LongType)))) - val flatMapGroupsInPandasDF = datasetWithUDF.groupBy().flatMapGroupsInPandas(PythonUDF( - "pyUDF", - null, - StructType(Seq(StructField("s", LongType))), - Seq.empty, - PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, - 
true)) - Seq(flatMapGroupsInRDF, flatMapGroupsInPandasDF).foreach { df => - assertContainsAnalysisBarrier(df, 2) - assertCached(df) - } - datasetWithUDF.unpersist(true) - } - - test("SPARK-24373: avoid running Analyzer rules twice on KeyValueGroupedDataset") { - val kvDasaset = datasetWithUDF.groupByKey(_.getLong(0)) - datasetWithUDF.cache() - val mapValuesKVDataset = kvDasaset.mapValues(_.getLong(0)).reduceGroups(_ + _) - val keysKVDataset = kvDasaset.keys - val flatMapGroupsKVDataset = kvDasaset.flatMapGroups((k, _) => Seq(k)) - val aggKVDataset = kvDasaset.count() - val otherKVDataset = spark.range(1).groupByKey(_ + 1) - val cogroupKVDataset = kvDasaset.cogroup(otherKVDataset)((k, _, _) => Seq(k)) - Seq((mapValuesKVDataset, 1), - (keysKVDataset, 2), - (flatMapGroupsKVDataset, 2), - (aggKVDataset, 1), - (cogroupKVDataset, 2)).foreach { case (df, analysisBarrierDepth) => - assertContainsAnalysisBarrier(df, analysisBarrierDepth) - assertCached(df) - } - datasetWithUDF.unpersist(true) - } -} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index a0c197b06ddab..9fe83bb332a9a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -87,7 +87,7 @@ class ResolveHiveSerdeTable(session: SparkSession) extends Rule[LogicalPlan] { } } - override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case c @ CreateTable(t, _, query) if DDLUtils.isHiveTable(t) => // Finds the database name if the name does not exist. val dbName = t.identifier.database.getOrElse(session.catalog.currentDatabase) @@ -114,7 +114,7 @@ class ResolveHiveSerdeTable(session: SparkSession) extends Rule[LogicalPlan] { } class DetermineTableStats(session: SparkSession) extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case relation: HiveTableRelation if DDLUtils.isHiveTable(relation.tableMeta) && relation.tableMeta.stats.isEmpty => val table = relation.tableMeta @@ -145,7 +145,7 @@ class DetermineTableStats(session: SparkSession) extends Rule[LogicalPlan] { * `PreprocessTableInsertion`. */ object HiveAnalysis extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case InsertIntoTable(r: HiveTableRelation, partSpec, query, overwrite, ifPartitionNotExists) if DDLUtils.isHiveTable(r.tableMeta) => InsertIntoHiveTable(r.tableMeta, partSpec, query, overwrite, @@ -225,7 +225,7 @@ case class RelationConversions( } override def apply(plan: LogicalPlan): LogicalPlan = { - plan transformUp { + plan resolveOperators { // Write path case InsertIntoTable(r: HiveTableRelation, partition, query, overwrite, ifPartitionNotExists) // Inserting into partitioned table is not supported in Parquet/Orc data source (yet). 
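// so the write-path conversion below skips partitioned Hive tables.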
From 8ccafcab20a70df7a625912fdf4e43be7fb87954 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 19 Jul 2018 17:08:47 -0700 Subject: [PATCH 02/12] Minimize change --- .../sql/catalyst/analysis/TypeCoercion.scala | 58 +++++++++---------- 1 file changed, 29 insertions(+), 29 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index c3607bcc92309..6bdb639011a17 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -751,18 +751,18 @@ object TypeCoercion { */ case class ConcatCoercion(conf: SQLConf) extends TypeCoercionRule { - override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsDown { - case p => - p transformExpressionsUp { - // Skip nodes if unresolved or empty children - case c @ Concat(children) if !c.childrenResolved || children.isEmpty => c - case c @ Concat(children) if conf.concatBinaryAsString || - !children.map(_.dataType).forall(_ == BinaryType) => - val newChildren = c.children.map { e => - ImplicitTypeCasts.implicitCast(e, StringType).getOrElse(e) - } - c.copy(children = newChildren) - } + override protected def coerceTypes( + plan: LogicalPlan): LogicalPlan = plan resolveOperatorsDown { case p => + p transformExpressionsUp { + // Skip nodes if unresolved or empty children + case c @ Concat(children) if !c.childrenResolved || children.isEmpty => c + case c @ Concat(children) if conf.concatBinaryAsString || + !children.map(_.dataType).forall(_ == BinaryType) => + val newChildren = c.children.map { e => + ImplicitTypeCasts.implicitCast(e, StringType).getOrElse(e) + } + c.copy(children = newChildren) + } } } @@ -774,24 +774,24 @@ object TypeCoercion { */ case class EltCoercion(conf: SQLConf) extends TypeCoercionRule { - override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsDown { - case p => - p transformExpressionsUp { - // Skip nodes if unresolved or not enough children - case c @ Elt(children) if !c.childrenResolved || children.size < 2 => c - case c @ Elt(children) => - val index = children.head - val newIndex = ImplicitTypeCasts.implicitCast(index, IntegerType).getOrElse(index) - val newInputs = if (conf.eltOutputAsString || - !children.tail.map(_.dataType).forall(_ == BinaryType)) { - children.tail.map { e => - ImplicitTypeCasts.implicitCast(e, StringType).getOrElse(e) - } - } else { - children.tail + override protected def coerceTypes( + plan: LogicalPlan): LogicalPlan = plan resolveOperatorsDown { case p => + p transformExpressionsUp { + // Skip nodes if unresolved or not enough children + case c @ Elt(children) if !c.childrenResolved || children.size < 2 => c + case c @ Elt(children) => + val index = children.head + val newIndex = ImplicitTypeCasts.implicitCast(index, IntegerType).getOrElse(index) + val newInputs = if (conf.eltOutputAsString || + !children.tail.map(_.dataType).forall(_ == BinaryType)) { + children.tail.map { e => + ImplicitTypeCasts.implicitCast(e, StringType).getOrElse(e) } - c.copy(children = newIndex +: newInputs) - } + } else { + children.tail + } + c.copy(children = newIndex +: newInputs) + } } } From 0afa7eadb6ae6b97e28eae93d6d2eb83b1f930fd Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 19 Jul 2018 17:20:35 -0700 Subject: [PATCH 03/12] Fix bug --- 
.../org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index f0939dc6ee8d1..63f552c19845a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -781,9 +781,10 @@ class Analyzer( right case Some((oldRelation, newRelation)) => val attributeRewrites = AttributeMap(oldRelation.output.zip(newRelation.output)) - right resolveOperators { + // TODO(rxin): Why do we need transformUp here? + right transformUp { case r if r == oldRelation => newRelation - } resolveOperators { + } transformUp { case other => other transformExpressions { case a: Attribute => dedupAttr(a, attributeRewrites) From 738e99c615f45cfb6e5882a822dafcb8472b78ea Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 19 Jul 2018 18:10:57 -0700 Subject: [PATCH 04/12] More bug fixes --- .../sql/catalyst/analysis/Analyzer.scala | 29 ++++++++++++------- .../sql/catalyst/analysis/CheckAnalysis.scala | 3 +- .../catalyst/plans/logical/LogicalPlan.scala | 20 +++++++++++-- .../inputs/table-valued-functions.sql | 1 + 4 files changed, 39 insertions(+), 14 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 63f552c19845a..7ab6ffbb977cf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -234,7 +234,8 @@ class Analyzer( def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { // Lookup WindowSpecDefinitions. This rule works with unresolved children. case WithWindowDefinition(windowDefinitions, child) => - child.transform { + // TODO(rxin): Check with Herman whether the next line is OK. + child.resolveOperators { case p => p.transformExpressions { case UnresolvedWindowExpression(c, WindowSpecReference(windowName)) => val errorMessage = @@ -782,14 +783,16 @@ class Analyzer( case Some((oldRelation, newRelation)) => val attributeRewrites = AttributeMap(oldRelation.output.zip(newRelation.output)) // TODO(rxin): Why do we need transformUp here? - right transformUp { - case r if r == oldRelation => newRelation - } transformUp { - case other => other transformExpressions { - case a: Attribute => - dedupAttr(a, attributeRewrites) - case s: SubqueryExpression => - s.withNewPlan(dedupOuterReferencesInSubquery(s.plan, attributeRewrites)) + LogicalPlan.bypassTransformAnalyzerCheck { + right transformUp { + case r if r == oldRelation => newRelation + } transformUp { + case other => other transformExpressions { + case a: Attribute => + dedupAttr(a, attributeRewrites) + case s: SubqueryExpression => + s.withNewPlan(dedupOuterReferencesInSubquery(s.plan, attributeRewrites)) + } } } } @@ -2376,8 +2379,12 @@ class Analyzer( object EliminateSubqueryAliases extends Rule[LogicalPlan] { // This is actually called in the beginning of the optimization phase, and as a result // is using transformUp rather than resolveOperators. 
- def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { - case SubqueryAlias(_, child) => child + def apply(plan: LogicalPlan): LogicalPlan = { + LogicalPlan.bypassTransformAnalyzerCheck { + plan transformUp { + case SubqueryAlias(_, child) => child + } + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 854e6dd44990b..056be4086174e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -537,7 +537,8 @@ trait CheckAnalysis extends PredicateHelper { // Simplify the predicates before validating any unsupported correlation patterns // in the plan. - BooleanSimplification(sub).foreachUp { + // TODO(rxin): Why did this need to call BooleanSimplification??? + sub.foreachUp { // Whitelist operators allowed in a correlated subquery // There are 4 categories: // 1. Operators that are allowed anywhere in a correlated subquery, and, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index a64a50cfb57b3..a0a18ccfc94a4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -28,6 +28,20 @@ import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils +object LogicalPlan { + private val bypassTransformAnalyzerCheckFlag = new ThreadLocal[Boolean] { + override def initialValue(): Boolean = false + } + + def bypassTransformAnalyzerCheck[T](p: => T): T = { + bypassTransformAnalyzerCheckFlag.set(true) + try p finally { + bypassTransformAnalyzerCheckFlag.set(false) + } + } +} + + abstract class LogicalPlan extends QueryPlan[LogicalPlan] with LogicalPlanStats @@ -102,9 +116,11 @@ abstract class LogicalPlan } protected def assertNotAnalysisRule(): Unit = { - if (Utils.isTesting && getClass.getName.contains("Logical")) { + if (Utils.isTesting && !LogicalPlan.bypassTransformAnalyzerCheckFlag.get) { if (Thread.currentThread.getStackTrace.exists(_.getClassName.contains("Analyzer"))) { - sys.error("This method should not be called in the analyzer") + val e = new RuntimeException("This method should not be called in the analyzer") + e.printStackTrace() + throw e } } } diff --git a/sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql index 72cd8ca9d8722..26b7668b1fe44 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql @@ -27,3 +27,4 @@ EXPLAIN select * from RaNgE(2); -- cross-join table valued functions EXPLAIN EXTENDED SELECT * FROM range(3) CROSS JOIN range(3); +table-valued-functions.sql \ No newline at end of file From 83ffa51f4b165152dea214be4d73dd518d742a56 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 19 Jul 2018 22:21:31 -0700 Subject: [PATCH 05/12] more fixes --- .../sql/catalyst/analysis/Analyzer.scala | 26 ++++------- .../catalyst/plans/logical/LogicalPlan.scala | 46 ++++++++++--------- .../sql/hive/execution/HiveExplainSuite.scala | 17 ------- 3 files changed, 35 insertions(+), 54 deletions(-) diff 
--git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 7ab6ffbb977cf..5a461bd07299a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -783,16 +783,14 @@ class Analyzer( case Some((oldRelation, newRelation)) => val attributeRewrites = AttributeMap(oldRelation.output.zip(newRelation.output)) // TODO(rxin): Why do we need transformUp here? - LogicalPlan.bypassTransformAnalyzerCheck { - right transformUp { - case r if r == oldRelation => newRelation - } transformUp { - case other => other transformExpressions { - case a: Attribute => - dedupAttr(a, attributeRewrites) - case s: SubqueryExpression => - s.withNewPlan(dedupOuterReferencesInSubquery(s.plan, attributeRewrites)) - } + right transformUp { + case r if r == oldRelation => newRelation + } transformUp { + case other => other transformExpressions { + case a: Attribute => + dedupAttr(a, attributeRewrites) + case s: SubqueryExpression => + s.withNewPlan(dedupOuterReferencesInSubquery(s.plan, attributeRewrites)) } } } @@ -2379,12 +2377,8 @@ class Analyzer( object EliminateSubqueryAliases extends Rule[LogicalPlan] { // This is actually called in the beginning of the optimization phase, and as a result // is using transformUp rather than resolveOperators. - def apply(plan: LogicalPlan): LogicalPlan = { - LogicalPlan.bypassTransformAnalyzerCheck { - plan transformUp { - case SubqueryAlias(_, child) => child - } - } + def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + case SubqueryAlias(_, child) => child } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index a0a18ccfc94a4..581e6b942320b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -29,14 +29,14 @@ import org.apache.spark.util.Utils object LogicalPlan { - private val bypassTransformAnalyzerCheckFlag = new ThreadLocal[Boolean] { - override def initialValue(): Boolean = false + private val bypassTransformAnalyzerCheckFlag = new ThreadLocal[Int] { + override def initialValue(): Int = 0 } def bypassTransformAnalyzerCheck[T](p: => T): T = { - bypassTransformAnalyzerCheckFlag.set(true) + bypassTransformAnalyzerCheckFlag.set(bypassTransformAnalyzerCheckFlag.get() + 1) try p finally { - bypassTransformAnalyzerCheckFlag.set(false) + bypassTransformAnalyzerCheckFlag.set(bypassTransformAnalyzerCheckFlag.get() - 1) } } } @@ -72,14 +72,16 @@ abstract class LogicalPlan */ def resolveOperators(rule: PartialFunction[LogicalPlan, LogicalPlan]): LogicalPlan = { if (!analyzed) { - val afterRuleOnChildren = mapChildren(_.resolveOperators(rule)) - if (this fastEquals afterRuleOnChildren) { - CurrentOrigin.withOrigin(origin) { - rule.applyOrElse(this, identity[LogicalPlan]) - } - } else { - CurrentOrigin.withOrigin(origin) { - rule.applyOrElse(afterRuleOnChildren, identity[LogicalPlan]) + LogicalPlan.bypassTransformAnalyzerCheck { + val afterRuleOnChildren = mapChildren(_.resolveOperators(rule)) + if (this fastEquals afterRuleOnChildren) { + CurrentOrigin.withOrigin(origin) { + rule.applyOrElse(this, identity[LogicalPlan]) + } + } else { + 
CurrentOrigin.withOrigin(origin) { + rule.applyOrElse(afterRuleOnChildren, identity[LogicalPlan]) + } } } } else { @@ -90,15 +92,17 @@ abstract class LogicalPlan /** Similar to [[resolveOperators]], but does it top-down. */ def resolveOperatorsDown(rule: PartialFunction[LogicalPlan, LogicalPlan]): LogicalPlan = { if (!analyzed) { - val afterRule = CurrentOrigin.withOrigin(origin) { - rule.applyOrElse(this, identity[LogicalPlan]) - } + LogicalPlan.bypassTransformAnalyzerCheck { + val afterRule = CurrentOrigin.withOrigin(origin) { + rule.applyOrElse(this, identity[LogicalPlan]) + } - // Check if unchanged and then possibly return old copy to avoid gc churn. - if (this fastEquals afterRule) { - mapChildren(_.resolveOperatorsDown(rule)) - } else { - afterRule.mapChildren(_.resolveOperatorsDown(rule)) + // Check if unchanged and then possibly return old copy to avoid gc churn. + if (this fastEquals afterRule) { + mapChildren(_.resolveOperatorsDown(rule)) + } else { + afterRule.mapChildren(_.resolveOperatorsDown(rule)) + } } } else { this @@ -116,7 +120,7 @@ abstract class LogicalPlan } protected def assertNotAnalysisRule(): Unit = { - if (Utils.isTesting && !LogicalPlan.bypassTransformAnalyzerCheckFlag.get) { + if (Utils.isTesting && LogicalPlan.bypassTransformAnalyzerCheckFlag.get == 0) { if (Thread.currentThread.getStackTrace.exists(_.getClassName.contains("Analyzer"))) { val e = new RuntimeException("This method should not be called in the analyzer") e.printStackTrace() diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala index 5d56f89c2271c..a1ce1ea936bbf 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala @@ -170,21 +170,4 @@ class HiveExplainSuite extends QueryTest with SQLTestUtils with TestHiveSingleto sql("EXPLAIN EXTENDED CODEGEN SELECT 1") } } - - test("SPARK-23021 AnalysisBarrier should not cut off explain output for parsed logical plans") { - val df = Seq((1, 1)).toDF("a", "b").groupBy("a").count().limit(1) - val outputStream = new java.io.ByteArrayOutputStream() - Console.withOut(outputStream) { - df.explain(true) - } - assert(outputStream.toString.replaceAll("""#\d+""", "#0").contains( - s"""== Parsed Logical Plan == - |GlobalLimit 1 - |+- LocalLimit 1 - | +- AnalysisBarrier - | +- Aggregate [a#0], [a#0, count(1) AS count#0L] - | +- Project [_1#0 AS a#0, _2#0 AS b#0] - | +- LocalRelation [_1#0, _2#0] - |""".stripMargin)) - } } From 7c76c83fe89f3e5aa28540fd76bdfc6016c35749 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 20 Jul 2018 09:13:00 -0700 Subject: [PATCH 06/12] bypass EliminateSubqueryAliases --- .../sql/catalyst/analysis/Analyzer.scala | 9 +++-- .../catalyst/plans/logical/LogicalPlan.scala | 33 ++++++++++++++----- 2 files changed, 30 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 5a461bd07299a..58e2ee8abcf5d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -2376,9 +2376,12 @@ class Analyzer( */ object EliminateSubqueryAliases extends Rule[LogicalPlan] { // This is actually called in the beginning of the 
optimization phase, and as a result
-  // is using transformUp rather than resolveOperators.
-  def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
-    case SubqueryAlias(_, child) => child
+  // is using transformUp rather than resolveOperators. This is also often called in the
+  // analyzer itself, e.g. when resolving relations, hence the explicit bypass below.
+  def apply(plan: LogicalPlan): LogicalPlan = LogicalPlan.allowInvokingTransformsInAnalyzer {
+    plan transformUp {
+      case SubqueryAlias(_, child) => child
+    }
   }
 }

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index 581e6b942320b..c93be2daca849 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -23,20 +23,21 @@ import org.apache.spark.sql.catalyst.analysis._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.LogicalPlanStats
-import org.apache.spark.sql.catalyst.trees.CurrentOrigin
+import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, TreeNode}
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.Utils


 object LogicalPlan {
-  private val bypassTransformAnalyzerCheckFlag = new ThreadLocal[Int] {
+
+  private val resolveOperatorDepth = new ThreadLocal[Int] {
     override def initialValue(): Int = 0
   }

-  def bypassTransformAnalyzerCheck[T](p: => T): T = {
-    bypassTransformAnalyzerCheckFlag.set(bypassTransformAnalyzerCheckFlag.get() + 1)
-    try p finally {
-      bypassTransformAnalyzerCheckFlag.set(bypassTransformAnalyzerCheckFlag.get() - 1)
+  def allowInvokingTransformsInAnalyzer[T](f: => T): T = {
+    resolveOperatorDepth.set(resolveOperatorDepth.get + 1)
+    try f finally {
+      resolveOperatorDepth.set(resolveOperatorDepth.get - 1)
     }
   }
 }
@@ -72,7 +73,7 @@ abstract class LogicalPlan
    */
   def resolveOperators(rule: PartialFunction[LogicalPlan, LogicalPlan]): LogicalPlan = {
     if (!analyzed) {
-      LogicalPlan.bypassTransformAnalyzerCheck {
+      LogicalPlan.allowInvokingTransformsInAnalyzer {
         val afterRuleOnChildren = mapChildren(_.resolveOperators(rule))
         if (this fastEquals afterRuleOnChildren) {
           CurrentOrigin.withOrigin(origin) {
@@ -92,7 +93,7 @@ abstract class LogicalPlan
   /** Similar to [[resolveOperators]], but does it top-down. */
   def resolveOperatorsDown(rule: PartialFunction[LogicalPlan, LogicalPlan]): LogicalPlan = {
     if (!analyzed) {
-      LogicalPlan.bypassTransformAnalyzerCheck {
+      LogicalPlan.allowInvokingTransformsInAnalyzer {
         val afterRule = CurrentOrigin.withOrigin(origin) {
           rule.applyOrElse(this, identity[LogicalPlan])
         }
@@ -120,7 +121,7 @@ abstract class LogicalPlan
   }

   protected def assertNotAnalysisRule(): Unit = {
-    if (Utils.isTesting && LogicalPlan.bypassTransformAnalyzerCheckFlag.get == 0) {
+    if (Utils.isTesting && LogicalPlan.resolveOperatorDepth.get == 0) {
       if (Thread.currentThread.getStackTrace.exists(_.getClassName.contains("Analyzer"))) {
         val e = new RuntimeException("This method should not be called in the analyzer")
         e.printStackTrace()
@@ -129,16 +130,30 @@ abstract class LogicalPlan
     }
   }

+  /**
+   * In analyzer, use [[resolveOperatorsDown()]] instead. If this is used in the analyzer,
+   * an exception will be thrown in test mode. It is however OK to call this function within
+   * the scope of a [[resolveOperatorsDown()]] call.
+   * @see [[TreeNode.transformDown()]].
+ */ override def transformDown(rule: PartialFunction[LogicalPlan, LogicalPlan]): LogicalPlan = { assertNotAnalysisRule() super.transformDown(rule) } + /** + * Use [[resolveOperators()]] in the analyzer. + * @see [[TreeNode.transformUp()]] + */ override def transformUp(rule: PartialFunction[LogicalPlan, LogicalPlan]): LogicalPlan = { assertNotAnalysisRule() super.transformUp(rule) } + /** + * Use [[resolveExpressions()]] in the analyzer. + * @see [[QueryPlan.transformAllExpressions()]] + */ override def transformAllExpressions(rule: PartialFunction[Expression, Expression]): this.type = { assertNotAnalysisRule() super.transformAllExpressions(rule) From 14ac09c11f623020028de08040f05f5a83c43763 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 20 Jul 2018 13:32:06 -0700 Subject: [PATCH 07/12] Added BooleanSimplification back --- .../spark/sql/catalyst/analysis/CheckAnalysis.scala | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 056be4086174e..5f7916863d98c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -535,10 +535,8 @@ trait CheckAnalysis extends PredicateHelper { var foundNonEqualCorrelatedPred: Boolean = false - // Simplify the predicates before validating any unsupported correlation patterns - // in the plan. - // TODO(rxin): Why did this need to call BooleanSimplification??? - sub.foreachUp { + // Simplify the predicates before validating any unsupported correlation patterns in the plan. + LogicalPlan.allowInvokingTransformsInAnalyzer { BooleanSimplification(sub).foreachUp { // Whitelist operators allowed in a correlated subquery // There are 4 categories: // 1. Operators that are allowed anywhere in a correlated subquery, and, @@ -640,6 +638,6 @@ trait CheckAnalysis extends PredicateHelper { // are not allowed to have any correlated expressions. case p => failOnOuterReferenceInSubTree(p) - } + }} } } From 38980ad066d26327387673910e0dfd981102cab9 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 20 Jul 2018 13:36:08 -0700 Subject: [PATCH 08/12] revert mistake --- .../test/resources/sql-tests/inputs/table-valued-functions.sql | 1 - 1 file changed, 1 deletion(-) diff --git a/sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql index 26b7668b1fe44..72cd8ca9d8722 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql @@ -27,4 +27,3 @@ EXPLAIN select * from RaNgE(2); -- cross-join table valued functions EXPLAIN EXTENDED SELECT * FROM range(3) CROSS JOIN range(3); -table-valued-functions.sql \ No newline at end of file From abfd0a8cd16c54fa19e7c66bfc2fb1f1c6b85a12 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 24 Jul 2018 17:40:21 -0700 Subject: [PATCH 09/12] Switch to use thread local to make things go faster. 
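Inspecting the stack trace on every transform call is slow; the guard now consults thread-local depth counters, set by the analyzer entry point and by the resolve* wrappers. The pattern, reduced to its essentials (names mirror the AnalysisHelper object added below):

// Counters rather than booleans, so nested calls unwind correctly when a
// rule triggers another query compilation on the same thread.
val inAnalyzer = new ThreadLocal[Int] { override def initialValue(): Int = 0 }

def markInAnalyzer[T](f: => T): T = {
  inAnalyzer.set(inAnalyzer.get + 1)
  try f finally inAnalyzer.set(inAnalyzer.get - 1)
}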
---
 .../sql/catalyst/analysis/Analyzer.scala      |   4 +-
 .../sql/catalyst/analysis/CheckAnalysis.scala |   4 +-
 .../plans/logical/AnalysisHelper.scala        | 194 ++++++++++++++++++
 .../catalyst/plans/logical/LogicalPlan.scala  | 128 +-----------
 4 files changed, 199 insertions(+), 131 deletions(-)
 create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/AnalysisHelper.scala

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 58e2ee8abcf5d..f6c6c7a742ac6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -102,7 +102,7 @@ class Analyzer(
     this(catalog, conf, conf.optimizerMaxIterations)
   }

-  def executeAndCheck(plan: LogicalPlan): LogicalPlan = {
+  def executeAndCheck(plan: LogicalPlan): LogicalPlan = AnalysisHelper.markInAnalyzer {
     val analyzed = execute(plan)
     try {
       checkAnalysis(analyzed)
@@ -2378,7 +2378,7 @@ object EliminateSubqueryAliases extends Rule[LogicalPlan] {
   // This is actually called in the beginning of the optimization phase, and as a result
   // is using transformUp rather than resolveOperators. This is also often called in the
   // analyzer itself, e.g. when resolving relations, hence the explicit bypass below.
-  def apply(plan: LogicalPlan): LogicalPlan = LogicalPlan.allowInvokingTransformsInAnalyzer {
+  def apply(plan: LogicalPlan): LogicalPlan = AnalysisHelper.allowInvokingTransformsInAnalyzer {
     plan transformUp {
       case SubqueryAlias(_, child) => child
     }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 5f7916863d98c..90ddc31e45356 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -371,7 +371,7 @@ trait CheckAnalysis extends PredicateHelper {
       case _ =>
     }

-    plan.foreach(_.setAnalyzed())
+    plan.setAnalyzed()
   }

   /**
@@ -536,7 +536,7 @@ trait CheckAnalysis extends PredicateHelper {
     var foundNonEqualCorrelatedPred: Boolean = false

     // Simplify the predicates before validating any unsupported correlation patterns in the plan.
-    LogicalPlan.allowInvokingTransformsInAnalyzer { BooleanSimplification(sub).foreachUp {
+    AnalysisHelper.allowInvokingTransformsInAnalyzer { BooleanSimplification(sub).foreachUp {
       // Whitelist operators allowed in a correlated subquery
       // There are 4 categories:
       // 1. Operators that are allowed anywhere in a correlated subquery, and,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/AnalysisHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/AnalysisHelper.scala
new file mode 100644
index 0000000000000..15f2da608a8eb
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/AnalysisHelper.scala
@@ -0,0 +1,194 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.plans.logical
+
+import org.apache.spark.sql.catalyst.analysis.CheckAnalysis
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.plans.QueryPlan
+import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, TreeNode}
+import org.apache.spark.util.Utils
+
+
+/**
+ * [[AnalysisHelper]] defines some infrastructure for the query analyzer. In particular, in query
+ * analysis we don't want to repeatedly re-analyze sub-plans that have previously been analyzed.
+ *
+ * This trait defines a flag `analyzed` that can be set to true once analysis is done on the tree.
+ * This also provides a set of resolve methods that do not recurse down to sub-plans that have the
+ * analyzed flag set to true.
+ *
+ * The analyzer rules should use the various resolve methods, in lieu of the various transform
+ * methods defined in [[TreeNode]] and [[QueryPlan]].
+ *
+ * To prevent accidental use of the transform methods, this trait also overrides the transform
+ * methods to throw exceptions in test mode, if they are used in the analyzer.
+ */
+trait AnalysisHelper extends QueryPlan[LogicalPlan] { self: LogicalPlan =>
+
+  private var _analyzed: Boolean = false
+
+  /**
+   * Recursively marks all nodes in this plan tree as analyzed.
+   * This should only be called by [[CheckAnalysis]].
+   */
+  private[catalyst] def setAnalyzed(): Unit = {
+    if (!_analyzed) {
+      _analyzed = true
+      children.foreach(_.setAnalyzed())
+    }
+  }
+
+  /**
+   * Returns true if this node and its children have already gone through analysis and
+   * verification. Note that this is only an optimization used to avoid analyzing trees that
+   * have already been analyzed, and can be reset by transformations.
+   */
+  def analyzed: Boolean = _analyzed
+
+  /**
+   * Returns a copy of this node where `rule` has been recursively applied first to all of its
+   * children and then itself (post-order, bottom-up). When `rule` does not apply to a given node,
+   * it is left unchanged. This function is similar to `transformUp`, but skips sub-trees that
+   * have already been marked as analyzed.
+   *
+   * @param rule the function used to transform this node's children
+   */
+  def resolveOperators(rule: PartialFunction[LogicalPlan, LogicalPlan]): LogicalPlan = {
+    if (!analyzed) {
+      AnalysisHelper.allowInvokingTransformsInAnalyzer {
+        val afterRuleOnChildren = mapChildren(_.resolveOperators(rule))
+        if (self fastEquals afterRuleOnChildren) {
+          CurrentOrigin.withOrigin(origin) {
+            rule.applyOrElse(self, identity[LogicalPlan])
+          }
+        } else {
+          CurrentOrigin.withOrigin(origin) {
+            rule.applyOrElse(afterRuleOnChildren, identity[LogicalPlan])
+          }
+        }
+      }
+    } else {
+      self
+    }
+  }
+
+  /** Similar to [[resolveOperators]], but does it top-down. */
+  def resolveOperatorsDown(rule: PartialFunction[LogicalPlan, LogicalPlan]): LogicalPlan = {
+    if (!analyzed) {
+      AnalysisHelper.allowInvokingTransformsInAnalyzer {
+        val afterRule = CurrentOrigin.withOrigin(origin) {
+          rule.applyOrElse(self, identity[LogicalPlan])
+        }
+
+        // Check if unchanged and then possibly return old copy to avoid gc churn.
+        if (self fastEquals afterRule) {
+          mapChildren(_.resolveOperatorsDown(rule))
+        } else {
+          afterRule.mapChildren(_.resolveOperatorsDown(rule))
+        }
+      }
+    } else {
+      self
+    }
+  }
+
+  /**
+   * Recursively transforms the expressions of a tree, skipping nodes that have already
+   * been analyzed.
+   */
+  def resolveExpressions(r: PartialFunction[Expression, Expression]): LogicalPlan = {
+    resolveOperators {
+      case p => p.transformExpressions(r)
+    }
+  }
+
+  protected def assertNotAnalysisRule(): Unit = {
+    if (Utils.isTesting &&
+        AnalysisHelper.inAnalyzer.get > 0 &&
+        AnalysisHelper.resolveOperatorDepth.get == 0) {
+      val e = new RuntimeException("This method should not be called in the analyzer")
+      e.printStackTrace()
+      throw e
+    }
+  }
+
+  /**
+   * In analyzer, use [[resolveOperatorsDown()]] instead. If this is used in the analyzer,
+   * an exception will be thrown in test mode. It is however OK to call this function within
+   * the scope of a [[resolveOperatorsDown()]] call.
+   * @see [[TreeNode.transformDown()]].
+   */
+  override def transformDown(rule: PartialFunction[LogicalPlan, LogicalPlan]): LogicalPlan = {
+    assertNotAnalysisRule()
+    super.transformDown(rule)
+  }
+
+  /**
+   * Use [[resolveOperators()]] in the analyzer.
+   * @see [[TreeNode.transformUp()]]
+   */
+  override def transformUp(rule: PartialFunction[LogicalPlan, LogicalPlan]): LogicalPlan = {
+    assertNotAnalysisRule()
+    super.transformUp(rule)
+  }
+
+  /**
+   * Use [[resolveExpressions()]] in the analyzer.
+   * @see [[QueryPlan.transformAllExpressions()]]
+   */
+  override def transformAllExpressions(rule: PartialFunction[Expression, Expression]): this.type = {
+    assertNotAnalysisRule()
+    super.transformAllExpressions(rule)
+  }
+
+}
+
+
+object AnalysisHelper {
+
+  /**
+   * A thread local to track whether we are in a resolveOperator call (for the purpose of analysis).
+   * This is an int because resolve* calls might be nested (e.g. a rule might trigger another
+   * query compilation within the rule itself), so we are tracking the depth here.
+   */
+  val resolveOperatorDepth: ThreadLocal[Int] = new ThreadLocal[Int] {
+    override def initialValue(): Int = 0
+  }
+
+  /**
+   * A thread local to track whether we are in the analysis phase of query compilation. This is an
+   * int rather than a boolean in case our analyzer recursively calls itself.
+ */ + val inAnalyzer: ThreadLocal[Int] = new ThreadLocal[Int] { + override def initialValue(): Int = 0 + } + + def allowInvokingTransformsInAnalyzer[T](f: => T): T = { + resolveOperatorDepth.set(resolveOperatorDepth.get + 1) + try f finally { + resolveOperatorDepth.set(resolveOperatorDepth.get - 1) + } + } + + def markInAnalyzer[T](f: => T): T = { + inAnalyzer.set(inAnalyzer.get + 1) + try f finally { + inAnalyzer.set(inAnalyzer.get - 1) + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index c93be2daca849..0e4456ac0e6a9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -23,142 +23,16 @@ import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.LogicalPlanStats -import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, TreeNode} import org.apache.spark.sql.types.StructType -import org.apache.spark.util.Utils - - -object LogicalPlan { - - private val resolveOperatorDepth = new ThreadLocal[Int] { - override def initialValue(): Int = 0 - } - - def allowInvokingTransformsInAnalyzer[T](f: => T): T = { - resolveOperatorDepth.set(resolveOperatorDepth.get + 1) - try f finally { - resolveOperatorDepth.set(resolveOperatorDepth.get - 1) - } - } -} abstract class LogicalPlan extends QueryPlan[LogicalPlan] + with AnalysisHelper with LogicalPlanStats with QueryPlanConstraints with Logging { - private var _analyzed: Boolean = false - - /** - * Marks this plan as already analyzed. This should only be called by [[CheckAnalysis]]. - */ - private[catalyst] def setAnalyzed(): Unit = { _analyzed = true } - - /** - * Returns true if this node and its children have already been gone through analysis and - * verification. Note that this is only an optimization used to avoid analyzing trees that - * have already been analyzed, and can be reset by transformations. - */ - def analyzed: Boolean = _analyzed - - /** - * Returns a copy of this node where `rule` has been recursively applied first to all of its - * children and then itself (post-order, bottom-up). When `rule` does not apply to a given node, - * it is left unchanged. This function is similar to `transformUp`, but skips sub-trees that - * have already been marked as analyzed. - * - * @param rule the function use to transform this nodes children - */ - def resolveOperators(rule: PartialFunction[LogicalPlan, LogicalPlan]): LogicalPlan = { - if (!analyzed) { - LogicalPlan.allowInvokingTransformsInAnalyzer { - val afterRuleOnChildren = mapChildren(_.resolveOperators(rule)) - if (this fastEquals afterRuleOnChildren) { - CurrentOrigin.withOrigin(origin) { - rule.applyOrElse(this, identity[LogicalPlan]) - } - } else { - CurrentOrigin.withOrigin(origin) { - rule.applyOrElse(afterRuleOnChildren, identity[LogicalPlan]) - } - } - } - } else { - this - } - } - - /** Similar to [[resolveOperators]], but does it top-down. 
*/ - def resolveOperatorsDown(rule: PartialFunction[LogicalPlan, LogicalPlan]): LogicalPlan = { - if (!analyzed) { - LogicalPlan.allowInvokingTransformsInAnalyzer { - val afterRule = CurrentOrigin.withOrigin(origin) { - rule.applyOrElse(this, identity[LogicalPlan]) - } - - // Check if unchanged and then possibly return old copy to avoid gc churn. - if (this fastEquals afterRule) { - mapChildren(_.resolveOperatorsDown(rule)) - } else { - afterRule.mapChildren(_.resolveOperatorsDown(rule)) - } - } - } else { - this - } - } - - /** - * Recursively transforms the expressions of a tree, skipping nodes that have already - * been analyzed. - */ - def resolveExpressions(r: PartialFunction[Expression, Expression]): LogicalPlan = { - this resolveOperators { - case p => p.transformExpressions(r) - } - } - - protected def assertNotAnalysisRule(): Unit = { - if (Utils.isTesting && LogicalPlan.resolveOperatorDepth.get == 0) { - if (Thread.currentThread.getStackTrace.exists(_.getClassName.contains("Analyzer"))) { - val e = new RuntimeException("This method should not be called in the analyzer") - e.printStackTrace() - throw e - } - } - } - - /** - * In analyzer, use [[resolveOperatorsDown()]] instead. If this is used in the analyzer, - * an exception will be thrown in test mode. It is however OK to call this function within - * the scope of a [[resolveOperatorsDown()]] call. - * @see [[TreeNode.transformDown()]]. - */ - override def transformDown(rule: PartialFunction[LogicalPlan, LogicalPlan]): LogicalPlan = { - assertNotAnalysisRule() - super.transformDown(rule) - } - - /** - * Use [[resolveOperators()]] in the analyzer. - * @see [[TreeNode.transformUp()]] - */ - override def transformUp(rule: PartialFunction[LogicalPlan, LogicalPlan]): LogicalPlan = { - assertNotAnalysisRule() - super.transformUp(rule) - } - - /** - * Use [[resolveExpressions()]] in the analyzer. - * @see [[QueryPlan.transformAllExpressions()]] - */ - override def transformAllExpressions(rule: PartialFunction[Expression, Expression]): this.type = { - assertNotAnalysisRule() - super.transformAllExpressions(rule) - } - /** Returns true if this subtree has data from a streaming data source. 
 */
  def isStreaming: Boolean = children.exists(_.isStreaming == true)

From 75fb1145fa84725fb931752906ece34ae6028290 Mon Sep 17 00:00:00 2001
From: Reynold Xin
Date: Tue, 24 Jul 2018 18:01:31 -0700
Subject: [PATCH 10/12] Test cases

---
 .../plans/logical/AnalysisHelper.scala        |   8 +-
 .../sql/catalyst/plans/LogicalPlanSuite.scala |   5 +-
 .../plans/logical/AnalysisHelperSuite.scala   | 159 ++++++++++++++++++
 3 files changed, 164 insertions(+), 8 deletions(-)
 create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/logical/AnalysisHelperSuite.scala

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/AnalysisHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/AnalysisHelper.scala
index 15f2da608a8eb..039acc1ea4fa8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/AnalysisHelper.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/AnalysisHelper.scala
@@ -121,9 +121,7 @@ trait AnalysisHelper extends QueryPlan[LogicalPlan] { self: LogicalPlan =>
     if (Utils.isTesting &&
         AnalysisHelper.inAnalyzer.get > 0 &&
         AnalysisHelper.resolveOperatorDepth.get == 0) {
-      val e = new RuntimeException("This method should not be called in the analyzer")
-      e.printStackTrace()
-      throw e
+      throw new RuntimeException("This method should not be called in the analyzer")
     }
   }

@@ -166,7 +164,7 @@ object AnalysisHelper {
    * This is an int because resolve* calls might be nested (e.g. a rule might trigger another
    * query compilation within the rule itself), so we are tracking the depth here.
    */
-  val resolveOperatorDepth: ThreadLocal[Int] = new ThreadLocal[Int] {
+  private val resolveOperatorDepth: ThreadLocal[Int] = new ThreadLocal[Int] {
     override def initialValue(): Int = 0
   }

   /**
    * A thread local to track whether we are in the analysis phase of query compilation. This is an
    * int rather than a boolean in case our analyzer recursively calls itself.
    */
-  val inAnalyzer: ThreadLocal[Int] = new ThreadLocal[Int] {
+  private val inAnalyzer: ThreadLocal[Int] = new ThreadLocal[Int] {
     override def initialValue(): Int = 0
   }

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala
index a3b9d52ddf719..aaab3ff1bf128 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala
@@ -18,13 +18,12 @@
 package org.apache.spark.sql.catalyst.plans

 import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Coalesce, Literal, NamedExpression}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Literal, NamedExpression}
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.types.IntegerType

 /**
- * This suite is used to test [[LogicalPlan]]'s `transformUp/transformDown` plus analysis barrier
- * and make sure it can correctly skip sub-trees that have already been analyzed.
+ * This suite is used to test [[LogicalPlan]]'s `transformUp/transformDown`.
 */
 class LogicalPlanSuite extends SparkFunSuite {
   private var invocationCount = 0
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/logical/AnalysisHelperSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/logical/AnalysisHelperSuite.scala
new file mode 100644
index 0000000000000..9100e10ca0c09
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/logical/AnalysisHelperSuite.scala
@@ -0,0 +1,159 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.plans.logical
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.expressions.{Alias, Expression, Literal, NamedExpression}
+
+
+class AnalysisHelperSuite extends SparkFunSuite {
+
+  private var invocationCount = 0
+  private val function: PartialFunction[LogicalPlan, LogicalPlan] = {
+    case p: Project =>
+      invocationCount += 1
+      p
+  }
+
+  private val exprFunction: PartialFunction[Expression, Expression] = {
+    case e: Literal =>
+      invocationCount += 1
+      Literal.TrueLiteral
+  }
+
+  private def projectExprs: Seq[NamedExpression] = Alias(Literal.TrueLiteral, "A")() :: Nil
+
+  test("setAnalyzed is recursive") {
+    val plan = Project(Nil, LocalRelation())
+    plan.setAnalyzed()
+    assert(plan.find(!_.analyzed).isEmpty)
+  }
+
+  test("resolveOperator runs on operators recursively") {
+    invocationCount = 0
+    val plan = Project(Nil, Project(Nil, LocalRelation()))
+    plan.resolveOperators(function)
+    assert(invocationCount === 2)
+  }
+
+  test("resolveOperatorsDown runs on operators recursively") {
+    invocationCount = 0
+    val plan = Project(Nil, Project(Nil, LocalRelation()))
+    plan.resolveOperatorsDown(function)
+    assert(invocationCount === 2)
+  }
+
+  test("resolveExpressions runs on operators recursively") {
+    invocationCount = 0
+    val plan = Project(projectExprs, Project(projectExprs, LocalRelation()))
+    plan.resolveExpressions(exprFunction)
+    assert(invocationCount === 2)
+  }
+
+  test("resolveOperator skips already resolved plans") {
+    invocationCount = 0
+    val plan = Project(Nil, Project(Nil, LocalRelation()))
+    plan.setAnalyzed()
+    plan.resolveOperators(function)
+    assert(invocationCount === 0)
+  }
+
+  test("resolveOperatorsDown skips already resolved plans") {
+    invocationCount = 0
+    val plan = Project(Nil, Project(Nil, LocalRelation()))
+    plan.setAnalyzed()
+    plan.resolveOperatorsDown(function)
+    assert(invocationCount === 0)
+  }
+
+  test("resolveExpressions skips already resolved plans") {
+    invocationCount = 0
+    val plan = Project(projectExprs, Project(projectExprs, LocalRelation()))
+    plan.setAnalyzed()
+    plan.resolveExpressions(exprFunction)
+    assert(invocationCount === 0)
+  }
+
+  test("resolveOperator skips partially resolved plans") {
+    invocationCount = 0
+    val plan1 = Project(Nil, LocalRelation())
+    val plan2 = Project(Nil, plan1)
+    plan1.setAnalyzed()
+    plan2.resolveOperators(function)
+    assert(invocationCount === 1)
+  }
+
+  test("resolveOperatorsDown skips partially resolved plans") {
+    invocationCount = 0
+    val plan1 = Project(Nil, LocalRelation())
+    val plan2 = Project(Nil, plan1)
+    plan1.setAnalyzed()
+    plan2.resolveOperatorsDown(function)
+    assert(invocationCount === 1)
+  }
+
+  test("resolveExpressions skips partially resolved plans") {
+    invocationCount = 0
+    val plan1 = Project(projectExprs, LocalRelation())
+    val plan2 = Project(projectExprs, plan1)
+    plan1.setAnalyzed()
+    plan2.resolveExpressions(exprFunction)
+    assert(invocationCount === 1)
+  }
+
+  test("do not allow transform in analyzer") {
+    val plan = Project(Nil, LocalRelation())
+    // These should be OK since we are not in the analyzer
+    plan.transform { case p: Project => p }
+    plan.transformUp { case p: Project => p }
+    plan.transformDown { case p: Project => p }
+    plan.transformAllExpressions { case lit: Literal => lit }
+
+    // The following should fail in the analyzer scope
+    AnalysisHelper.markInAnalyzer {
+      intercept[RuntimeException] { plan.transform { case p: Project => p } }
+      intercept[RuntimeException] { plan.transformUp { case p: Project => p } }
+      intercept[RuntimeException] { plan.transformDown { case p: Project => p } }
+      intercept[RuntimeException] { plan.transformAllExpressions { case lit: Literal => lit } }
+    }
+  }
+
+  test("allow transform in resolveOperators in the analyzer") {
+    val plan = Project(Nil, LocalRelation())
+    AnalysisHelper.markInAnalyzer {
+      plan.resolveOperators { case p: Project => p.transform { case p: Project => p } }
+      plan.resolveOperatorsDown { case p: Project => p.transform { case p: Project => p } }
+      plan.resolveExpressions { case lit: Literal =>
+        Project(Nil, LocalRelation()).transform { case p: Project => p }
+        lit
+      }
+    }
+  }
+
+  test("allow transform with allowInvokingTransformsInAnalyzer in the analyzer") {
+    val plan = Project(Nil, LocalRelation())
+    AnalysisHelper.markInAnalyzer {
+      AnalysisHelper.allowInvokingTransformsInAnalyzer {
+        plan.transform { case p: Project => p }
+        plan.transformUp { case p: Project => p }
+        plan.transformDown { case p: Project => p }
+        plan.transformAllExpressions { case lit: Literal => lit }
+      }
+    }
+  }
+}

From f2f1a97e447a41e8b9b6c094376d32b32af00991 Mon Sep 17 00:00:00 2001
From: Reynold Xin
Date: Tue, 24 Jul 2018 18:03:02 -0700
Subject: [PATCH 11/12] Remove one TODO

---
 .../scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 1 -
 1 file changed, 1 deletion(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index f6c6c7a742ac6..15dd64e7e7ecc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -234,7 +234,6 @@ class Analyzer(
     def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
       // Lookup WindowSpecDefinitions. This rule works with unresolved children.
       case WithWindowDefinition(windowDefinitions, child) =>
-        // TODO(rxin): Check with Herman whether the next line is OK.
child.resolveOperators { case p => p.transformExpressions { case UnresolvedWindowExpression(c, WindowSpecReference(windowName)) => From 7995272fe2fd929077ec056a53d26ee6730b685d Mon Sep 17 00:00:00 2001 From: Xiao Li Date: Thu, 26 Jul 2018 12:01:50 -0700 Subject: [PATCH 12/12] add a dummy fix --- .../src/main/scala/org/apache/spark/sql/DataFrameWriter.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 3a828ce6ba694..3c9e743106260 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -351,7 +351,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { private def assertNotPartitioned(operation: String): Unit = { if (partitioningColumns.isDefined) { - throw new AnalysisException( s"'$operation' does not support partitioning") + throw new AnalysisException(s"'$operation' does not support partitioning") } }
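Taken together, the series replaces AnalysisBarrier with an analyzed flag on LogicalPlan plus a test-mode guard against raw transforms inside the analyzer. The resulting contract, as exercised by AnalysisHelperSuite above (a sketch; the guard only throws when Utils.isTesting is true, as it is under the test runner):

import org.apache.spark.sql.catalyst.plans.logical.{AnalysisHelper, LocalRelation, LogicalPlan, Project}

val plan: LogicalPlan = Project(Nil, LocalRelation())
AnalysisHelper.markInAnalyzer {
  plan.resolveOperators { case p: Project => p }  // fine: the sanctioned API
  plan.transformUp { case p: Project => p }       // throws RuntimeException in test mode
}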