From c2fcaa8e488d12419c7b7c5032ccadab38f20b68 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Sat, 9 Jan 2016 19:21:14 -0800 Subject: [PATCH 01/17] window function: Sorting columns are not in Project --- .../sql/catalyst/analysis/Analyzer.scala | 27 ++++++++++++- .../sql/hive/execution/SQLQuerySuite.scala | 40 ++++++++++++++++++- 2 files changed, 64 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 8a33af8207350..87ab08d1f042b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -523,14 +523,37 @@ class Analyzer( def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case s @ Sort(ordering, global, p @ Project(projectList, child)) if !s.resolved && p.resolved => - val (newOrdering, missing) = resolveAndFindMissing(ordering, p, child) + val (newOrdering, missing, newChild): (Seq[SortOrder], Seq[Attribute], LogicalPlan) = + child match { + case Project( + projectListAboveWindow, + w @ Window( + projectListInWindow, + windowExpressions, + partitionSpec, + orderSpec, + pW @ Project( + projectListBelowWindow, + childBelowWindow))) => + val (newOrdering, missingAttrs) = + resolveAndFindMissing(ordering, pW, childBelowWindow) + (newOrdering, missingAttrs, + Project(projectListAboveWindow ++ missingAttrs, + Window( + projectListInWindow ++ missingAttrs, + windowExpressions, partitionSpec, orderSpec, + Project(projectListBelowWindow ++ missingAttrs, childBelowWindow)))) + case _ => + val (newOrdering, missingAttrs) = resolveAndFindMissing(ordering, p, child) + (newOrdering, missingAttrs, child) + } // If this rule was not a no-op, return the transformed plan, otherwise return the original. if (missing.nonEmpty) { // Add missing attributes and then project them away after the sort. Project(p.output, Sort(newOrdering, global, - Project(projectList ++ missing, child))) + Project(projectList ++ missing, newChild))) } else { logDebug(s"Failed to find $missing in ${p.output.mkString(", ")}") s // Nothing we can do here. Return original plan. 
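The hunk above targets the plan shape the analyzer produces for window queries, Sort over Project over Window over Project: when the ORDER BY column appears in none of those project lists, the old rule gave up. For context, a minimal repro sketch, assuming a Spark 1.6-era sqlContext and the windowData temp table registered as in the test below:

    // Before this patch, the ORDER BY column 'area' fails to resolve: it is
    // absent from the Project above the Window and from the Project below it.
    sqlContext.sql(
      """select month, product, sum(product + 1) over()
        |from windowData
        |order by area""".stripMargin)

With the fix, the missing attribute is appended down the chain of project lists, the Sort is resolved against it, and an outer Project restores the original output.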
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 593fac2c32817..cf93d3cd8e1b8 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -773,7 +773,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { """.stripMargin), (2 to 6).map(i => Row(i))) } - test("window function: udaf with aggregate expressin") { + test("window function: udaf with aggregate expression") { val data = Seq( WindowData(1, "a", 5), WindowData(2, "a", 6), @@ -964,6 +964,44 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { ).map(i => Row(i._1, i._2, i._3, i._4))) } + test("window function: Sorting columns are not in Project") { + val data = Seq( + WindowData(1, "d", 10), + WindowData(2, "a", 6), + WindowData(3, "b", 7), + WindowData(4, "b", 8), + WindowData(5, "c", 9), + WindowData(6, "c", 10) + ) + sparkContext.parallelize(data).toDF().registerTempTable("windowData") + + checkAnswer( + sql("select month, product, sum(product + 1) over() from windowData order by area"), + Seq( + (2, 6, 56), + (3, 7, 56), + (4, 8, 56), + (5, 9, 56), + (6, 10, 56), + (1, 10, 56) + ).map(i => Row(i._1, i._2, i._3))) + + checkAnswer( + sql( + """ + |select area, rank() over (partition by area order by tmp.month) + tmp.tmp1 as c1 + |from (select month, area, product as p, 1 as tmp1 from windowData) tmp order by p + """.stripMargin), + Seq( + ("a", 2), + ("b", 2), + ("b", 3), + ("c", 2), + ("d", 2), + ("c", 3) + ).map(i => Row(i._1, i._2))) + } + test("window function: multiple window expressions in a single expression") { val nums = sparkContext.parallelize(1 to 10).map(x => (x, x % 2)).toDF("x", "y") nums.registerTempTable("nums") From 5ca463035bc6eaebd15e7cf332faeea157e5593e Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Sat, 9 Jan 2016 19:30:58 -0800 Subject: [PATCH 02/17] style fix. --- .../scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 87ab08d1f042b..79811f9f13604 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -527,7 +527,7 @@ class Analyzer( child match { case Project( projectListAboveWindow, - w @ Window( + Window( projectListInWindow, windowExpressions, partitionSpec, From da6baf25488767ce6e73538b03f9195bba92b84e Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Sat, 9 Jan 2016 22:23:48 -0800 Subject: [PATCH 03/17] code cleaning and address comments. 
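This revision keeps the two code paths of the first patch but names them explicitly. The distinction, sketched using the example queries from the new comments (table1 and its columns are the placeholder names used there, sqlContext is assumed):

    // Case 1: a window function in the SELECT clause; the sort column must be
    // propagated below the Window operator.
    sqlContext.sql("SELECT sum(col1) OVER() FROM table1 ORDER BY col2")
    // Case 2: an ordinary projection; the sort column is appended to the
    // Project and projected away after the Sort.
    sqlContext.sql("SELECT col1 FROM table1 ORDER BY col2")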
--- .../sql/catalyst/analysis/Analyzer.scala | 40 ++++++++----------- 1 file changed, 16 insertions(+), 24 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 79811f9f13604..268f4f5c6ca0c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -521,39 +521,31 @@ class Analyzer( */ object ResolveSortReferences extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { - case s @ Sort(ordering, global, p @ Project(projectList, child)) - if !s.resolved && p.resolved => + case s @ Sort(_, _, p @ Project(_, _)) if !s.resolved && p.resolved => val (newOrdering, missing, newChild): (Seq[SortOrder], Seq[Attribute], LogicalPlan) = - child match { - case Project( - projectListAboveWindow, - Window( - projectListInWindow, - windowExpressions, - partitionSpec, - orderSpec, - pW @ Project( - projectListBelowWindow, - childBelowWindow))) => - val (newOrdering, missingAttrs) = - resolveAndFindMissing(ordering, pW, childBelowWindow) + p.child match { + // Case 1: when WINDOW functions are used in the SELECT clause. + // Example: SELECT sum(col1) OVER() FROM table1 ORDER BY col2 + case p1 @ Project(_, w @ Window(_, _, _, _, p2: Project)) => + val (newOrdering, missingAttrs) = resolveAndFindMissing(s.order, p2, p2.child) (newOrdering, missingAttrs, - Project(projectListAboveWindow ++ missingAttrs, - Window( - projectListInWindow ++ missingAttrs, - windowExpressions, partitionSpec, orderSpec, - Project(projectListBelowWindow ++ missingAttrs, childBelowWindow)))) + Project(p1.projectList ++ missingAttrs, + Window(w.projectList ++ missingAttrs, + w.windowExpressions, w.partitionSpec, w.orderSpec, + Project(p2.projectList ++ missingAttrs, p2.child)))) + // Case 2: the other cases + // Example: SELECT col1 FROM table1 ORDER BY col2 case _ => - val (newOrdering, missingAttrs) = resolveAndFindMissing(ordering, p, child) - (newOrdering, missingAttrs, child) + val (newOrdering, missingAttrs) = resolveAndFindMissing(s.order, p, p.child) + (newOrdering, missingAttrs, p.child) } // If this rule was not a no-op, return the transformed plan, otherwise return the original. if (missing.nonEmpty) { // Add missing attributes and then project them away after the sort. Project(p.output, - Sort(newOrdering, global, - Project(projectList ++ missing, newChild))) + Sort(newOrdering, s.global, + Project(p.projectList ++ missing, newChild))) } else { logDebug(s"Failed to find $missing in ${p.output.mkString(", ")}") s // Nothing we can do here. Return original plan. From d164342747502b09686c1802cf9d24d8ed4c899e Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 12 Jan 2016 22:15:31 -0800 Subject: [PATCH 04/17] address comments. 
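The diff below replaces the hard-coded Project/Window/Project pattern with a tail-recursive collectResolvedMissingAttrs that walks Project, Window, and Aggregate nodes, and moves ResolveAggregateFunctions ahead of ResolveSortReferences in the batch. One subtlety in the Aggregate case: a "missing" sort attribute may already appear among the grouping expressions, so only genuinely absent ones are appended. A minimal sketch of that check, mirroring the added code (agg and missingAttrs are hypothetical values):

    // semanticEquals treats attributes that compute the same value but differ
    // cosmetically as equal, so duplicates are filtered out before extending
    // the GROUP BY list.
    val newGroupExpressions = agg.groupingExpressions ++ missingAttrs.filterNot(
      attr => agg.groupingExpressions.exists(_.semanticEquals(attr)))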
--- .../sql/catalyst/analysis/Analyzer.scala | 97 ++++++++++++------- .../sql/hive/execution/SQLQuerySuite.scala | 15 +++ 2 files changed, 78 insertions(+), 34 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 268f4f5c6ca0c..36bc7af385aa9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.analysis import scala.collection.mutable.ArrayBuffer +import scala.annotation.tailrec import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.{CatalystConf, ScalaReflection, SimpleCatalystConf} @@ -73,6 +74,7 @@ class Analyzer( ResolveGroupingAnalytics :: ResolvePivot :: ResolveUpCast :: + ResolveAggregateFunctions :: ResolveSortReferences :: ResolveGenerate :: ResolveFunctions :: @@ -81,7 +83,6 @@ class Analyzer( ResolveWindowFrame :: ExtractWindowExpressions :: GlobalAggregates :: - ResolveAggregateFunctions :: DistinctAggregationRewriter(conf) :: HiveTypeCoercion.typeCoercionRules ++ extendedResolutionRules : _*), @@ -440,7 +441,7 @@ class Analyzer( } // When resolve `SortOrder`s in Sort based on child, don't report errors as - // we still have chance to resolve it based on grandchild + // we still have chance to resolve it based on its descendants case s @ Sort(ordering, global, child) if child.resolved && !s.resolved => val newOrdering = resolveSortOrders(ordering, child, throws = false) Sort(newOrdering, global, child) @@ -521,53 +522,81 @@ class Analyzer( */ object ResolveSortReferences extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { - case s @ Sort(_, _, p @ Project(_, _)) if !s.resolved && p.resolved => - val (newOrdering, missing, newChild): (Seq[SortOrder], Seq[Attribute], LogicalPlan) = - p.child match { - // Case 1: when WINDOW functions are used in the SELECT clause. - // Example: SELECT sum(col1) OVER() FROM table1 ORDER BY col2 - case p1 @ Project(_, w @ Window(_, _, _, _, p2: Project)) => - val (newOrdering, missingAttrs) = resolveAndFindMissing(s.order, p2, p2.child) - (newOrdering, missingAttrs, - Project(p1.projectList ++ missingAttrs, - Window(w.projectList ++ missingAttrs, - w.windowExpressions, w.partitionSpec, w.orderSpec, - Project(p2.projectList ++ missingAttrs, p2.child)))) - // Case 2: the other cases - // Example: SELECT col1 FROM table1 ORDER BY col2 - case _ => - val (newOrdering, missingAttrs) = resolveAndFindMissing(s.order, p, p.child) - (newOrdering, missingAttrs, p.child) + case s @ Sort(_, _, child) if !s.resolved && child.resolved => + val (newOrdering, missingAttrs, missingPlans) = + collectResolvedMissingAttrs(s.order, child, Seq.empty[LogicalPlan]) + + if (missingAttrs.isEmpty) { + logDebug(s"Failed to find $missingAttrs in ${child.output.mkString(", ")}") + s // Nothing we can do here. Return original plan. } + else { + val missingPlanSet = missingPlans.toSet + val newChild = child transform { + case p: Project if missingPlanSet.contains(p) => + p.copy(projectList = p.projectList ++ missingAttrs) + case w: Window if missingPlanSet.contains(w) => + w.copy(projectList = w.projectList ++ missingAttrs) + case a: Aggregate if missingPlanSet.contains(a) => + // Grouping expressions could already have the missing attributes. + // Do not add the duplicate attributes. 
+ val newGroupExpressions = a.groupingExpressions ++ missingAttrs.filterNot( + attr => a.groupingExpressions.exists(_.semanticEquals(attr))) + a.copy(aggregateExpressions = a.aggregateExpressions ++ missingAttrs, + groupingExpressions = newGroupExpressions) + case o => o + } - // If this rule was not a no-op, return the transformed plan, otherwise return the original. - if (missing.nonEmpty) { // Add missing attributes and then project them away after the sort. - Project(p.output, - Sort(newOrdering, s.global, - Project(p.projectList ++ missing, newChild))) - } else { - logDebug(s"Failed to find $missing in ${p.output.mkString(", ")}") - s // Nothing we can do here. Return original plan. + Project(child.output, + Sort(newOrdering, s.global, newChild)) } } /** - * Given a child and a grandchild that are present beneath a sort operator, try to resolve - * the sort ordering and returns it with a list of attributes that are missing from the - * child but are present in the grandchild. + * Traverse the tree until resolving the sorting attributes and returns it + * with a list of traversed operators that miss the sorting attributes. + */ + @tailrec + private def collectResolvedMissingAttrs( + ordering: Seq[SortOrder], + plan: LogicalPlan, + missingPlans: Seq[LogicalPlan]): (Seq[SortOrder], Seq[Attribute], Seq[LogicalPlan]) = { + plan match { + // Subquery does nothing. We can simply skip it. + case s: Subquery => + collectResolvedMissingAttrs(ordering, s.child, missingPlans) + // Only Windows, Project and Aggregate have projectList-like attribute. + // TODO: when the other operators have it, we should add a support too. + case un: UnaryNode + if un.isInstanceOf[Project] || un.isInstanceOf[Window] || un.isInstanceOf[Aggregate] => + val (newOrdering, missingAttrs) = resolveAndFindMissing(ordering, un, un.child) + // If missingAttrs is non empty, that means we got it and return it; + // Otherwise, continue to traverse the tree. + if (missingAttrs.nonEmpty) (newOrdering, missingAttrs, missingPlans :+ un) + else collectResolvedMissingAttrs(ordering, un.child, missingPlans :+ un) + // If hitting the other unsupported operators, we are unable to resolve it + // and thus stop traversing the plan tree. + case other => + (Seq.empty[SortOrder], Seq.empty[Attribute], Seq.empty[LogicalPlan]) + } + } + + /** + * Try to resolve the sort ordering and returns it with a list of attributes that are missing + * from the child but are present in the grandchild. */ def resolveAndFindMissing( ordering: Seq[SortOrder], - child: LogicalPlan, - grandchild: LogicalPlan): (Seq[SortOrder], Seq[Attribute]) = { - val newOrdering = resolveSortOrders(ordering, grandchild, throws = true) + plan: LogicalPlan, + child: LogicalPlan): (Seq[SortOrder], Seq[Attribute]) = { + val newOrdering = resolveSortOrders(ordering, child, throws = false) // Construct a set that contains all of the attributes that we need to evaluate the // ordering. val requiredAttributes = AttributeSet(newOrdering).filter(_.resolved) // Figure out which ones are missing from the projection, so that we can add them and // remove them after the sort. - val missingInProject = requiredAttributes -- child.output + val missingInProject = requiredAttributes -- plan.outputSet // It is important to return the new SortOrders here, instead of waiting for the standard // resolving process as adding attributes to the project below can actually introduce // ambiguity that was not present before. 
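The test case added next exercises this new Aggregate path. As a usage sketch (assuming the windowData temp table from the suite), it orders a grouped window query by a grouping column that the SELECT list drops:

    sqlContext.sql(
      """select area, rank() over (partition by area order by month) as c1
        |from windowData group by product, area, month order by product, area
      """.stripMargin)

Because product is a grouping column, the rule can append it to the aggregate's output, sort on it, and project it away afterwards.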
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index cf93d3cd8e1b8..3f546a7753047 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -1000,6 +1000,21 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { ("d", 2), ("c", 3) ).map(i => Row(i._1, i._2))) + + checkAnswer( + sql( + """ + |select area, rank() over (partition by area order by month) as c1 + |from windowData group by product, area, month order by product, area + """.stripMargin), + Seq( + ("a", 1), + ("b", 1), + ("b", 2), + ("c", 1), + ("c", 2), + ("d", 1) + ).map(i => Row(i._1, i._2))) } test("window function: multiple window expressions in a single expression") { From 27fcaa5ad6a3b4228ef4fc46b963c1e818d2f5c4 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Wed, 13 Jan 2016 00:30:12 -0800 Subject: [PATCH 05/17] address comments. --- .../sql/catalyst/analysis/Analyzer.scala | 59 +++++++++++-------- 1 file changed, 34 insertions(+), 25 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 36bc7af385aa9..ba89413af135f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -523,28 +523,41 @@ class Analyzer( object ResolveSortReferences extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case s @ Sort(_, _, child) if !s.resolved && child.resolved => - val (newOrdering, missingAttrs, missingPlans) = - collectResolvedMissingAttrs(s.order, child, Seq.empty[LogicalPlan]) + val (newOrdering, missingResolvableAttrs) = collectResolvedMissingAttrs(s.order, child) - if (missingAttrs.isEmpty) { - logDebug(s"Failed to find $missingAttrs in ${child.output.mkString(", ")}") + if (missingResolvableAttrs.isEmpty) { + val unresolvableAttrs = s.order.filterNot(_.resolved) + logDebug(s"Failed to find $unresolvableAttrs in ${child.output.mkString(", ")}") s // Nothing we can do here. Return original plan. } else { - val missingPlanSet = missingPlans.toSet + var stop: Boolean = false + var missingAttrs: Seq[Attribute] = missingResolvableAttrs val newChild = child transform { - case p: Project if missingPlanSet.contains(p) => - p.copy(projectList = p.projectList ++ missingAttrs) - case w: Window if missingPlanSet.contains(w) => - w.copy(projectList = w.projectList ++ missingAttrs) - case a: Aggregate if missingPlanSet.contains(a) => + case p: Project if !stop && missingAttrs.nonEmpty => + val newList = p.projectList ++ missingAttrs + missingAttrs = missingAttrs.filterNot( + attr => p.child.outputSet.exists(_.semanticEquals(attr))) + p.copy(projectList = newList) + case w: Window if !stop && missingAttrs.nonEmpty => + val newList = w.projectList ++ missingAttrs + missingAttrs = missingAttrs.filterNot( + attr => w.child.outputSet.exists(_.semanticEquals(attr))) + w.copy(projectList = newList) + case a: Aggregate if !stop && missingAttrs.nonEmpty => // Grouping expressions could already have the missing attributes. // Do not add the duplicate attributes. 
+ val newGroupExpressions = a.groupingExpressions ++ missingAttrs.filterNot( attr => a.groupingExpressions.exists(_.semanticEquals(attr))) + val newAggregateExpressions = a.aggregateExpressions ++ missingAttrs + missingAttrs = missingAttrs.filterNot( attr => a.child.outputSet.exists(_.semanticEquals(attr))) + a.copy(groupingExpressions = newGroupExpressions, + aggregateExpressions = newAggregateExpressions) + case s: Subquery if !stop && missingAttrs.nonEmpty => s + case o => + stop = true + o } // Add missing attributes and then project them away after the sort. @@ -554,37 +567,33 @@ } /** - * Traverse the tree until resolving the sorting attributes and returns it - * with a list of traversed operators that miss the sorting attributes. + * Traverse the tree until resolving the sorting attributes. */ @tailrec private def collectResolvedMissingAttrs( ordering: Seq[SortOrder], - plan: LogicalPlan, - missingPlans: Seq[LogicalPlan]): (Seq[SortOrder], Seq[Attribute], Seq[LogicalPlan]) = { + plan: LogicalPlan): (Seq[SortOrder], Seq[Attribute]) = { plan match { // Subquery does nothing. We can simply skip it. - case s: Subquery => - collectResolvedMissingAttrs(ordering, s.child, missingPlans) + case s: Subquery => collectResolvedMissingAttrs(ordering, s.child) // Only Windows, Project and Aggregate have projectList-like attribute. // TODO: when the other operators have it, we should add a support too. case un: UnaryNode - if un.isInstanceOf[Project] || un.isInstanceOf[Window] || un.isInstanceOf[Aggregate] => + if un.isInstanceOf[Project] || un.isInstanceOf[Window] || un.isInstanceOf[Aggregate] => val (newOrdering, missingAttrs) = resolveAndFindMissing(ordering, un, un.child) // If missingAttrs is non empty, that means we got it and return it; // Otherwise, continue to traverse the tree. - if (missingAttrs.nonEmpty) (newOrdering, missingAttrs, missingPlans :+ un) - else collectResolvedMissingAttrs(ordering, un.child, missingPlans :+ un) + if (missingAttrs.nonEmpty) (newOrdering, missingAttrs) + else collectResolvedMissingAttrs(ordering, un.child) // If hitting the other unsupported operators, we are unable to resolve it // and thus stop traversing the plan tree. - case other => - (Seq.empty[SortOrder], Seq.empty[Attribute], Seq.empty[LogicalPlan]) + case other => (Seq.empty[SortOrder], Seq.empty[Attribute]) } } /** * Try to resolve the sort ordering and returns it with a list of attributes that are missing - * from the child but are present in the grandchild. + * from the plan but are present in the child. */ def resolveAndFindMissing( ordering: Seq[SortOrder], From 7fc98e49a26fd03f398b2241b4cfd19e969b770e Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Sat, 16 Jan 2016 21:03:23 -0800 Subject: [PATCH 06/17] added support for more operators.
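This revision extends the traversal beyond unary chains: pass-through operators whose output equals their child's output are skipped, and Join is handled by enqueueing both sides for inner and outer joins but only the left side for left semi joins. The DataFrameJoinSuite test added below sorts on columns that survive only beneath the join; a condensed sketch of it:

    val df = Seq((1, 2, "1"), (3, 4, "3")).toDF("int", "int2", "str_sort").as('df1)
    val df2 = Seq((1, 3, "1"), (5, 6, "5")).toDF("int", "int2", "str").as('df2)
    // str_sort and str are dropped by the select, so both join sides must be
    // widened with the missing attributes before the sort can resolve.
    df.join(df2, $"df1.int" === $"df2.int", "outer")
      .select($"df1.int", $"df2.int2")
      .orderBy('str_sort.asc, 'str.asc)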
--- .../sql/catalyst/analysis/Analyzer.scala | 96 +++++++++++++------ .../apache/spark/sql/DataFrameJoinSuite.scala | 10 ++ .../org/apache/spark/sql/DataFrameSuite.scala | 6 ++ 3 files changed, 81 insertions(+), 31 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index ba89413af135f..48ff1eddd320b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.analysis +import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.annotation.tailrec @@ -24,6 +25,7 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.{CatalystConf, ScalaReflection, SimpleCatalystConf} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.trees.TreeNodeRef @@ -523,7 +525,8 @@ class Analyzer( object ResolveSortReferences extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case s @ Sort(_, _, child) if !s.resolved && child.resolved => - val (newOrdering, missingResolvableAttrs) = collectResolvedMissingAttrs(s.order, child) + val (newOrdering, missingResolvableAttrs) = + collectResolvableMissingAttrs(s.order, plans = mutable.Queue(child)) if (missingResolvableAttrs.isEmpty) { val unresolvableAttrs = s.order.filterNot(_.resolved) @@ -531,33 +534,30 @@ class Analyzer( s // Nothing we can do here. Return original plan. } else { - var stop: Boolean = false - var missingAttrs: Seq[Attribute] = missingResolvableAttrs - val newChild = child transform { - case p: Project if !stop && missingAttrs.nonEmpty => - val newList = p.projectList ++ missingAttrs - missingAttrs = missingAttrs.filterNot( - attr => p.child.outputSet.exists(_.semanticEquals(attr))) - p.copy(projectList = newList) - case w: Window if !stop && missingAttrs.nonEmpty => - val newList = w.projectList ++ missingAttrs - missingAttrs = missingAttrs.filterNot( - attr => w.child.outputSet.exists(_.semanticEquals(attr))) - w.copy(projectList = newList) - case a: Aggregate if !stop && missingAttrs.nonEmpty => + // Transform the whole tree in post-order. Add into the self's outputSet + // all the children attributes that are part of missingResolvableAttrs. + // Assumption: all the conflicting attributes between left and right have been resolved + val newChild = child transformUp { + case p: Project => + val missingAttrs = + findNotResolvedMissingAttrs(p.outputSet, p.child.outputSet, missingResolvableAttrs) + p.copy(projectList = p.projectList ++ missingAttrs) + case w: Window => + val missingAttrs = + findNotResolvedMissingAttrs(w.outputSet, w.child.outputSet, missingResolvableAttrs) + w.copy(projectList = w.projectList ++ missingAttrs) + case a: Aggregate => // Grouping expressions could already have the missing attributes. // Do not add the duplicate attributes. 
- val newGroupExpressions = a.groupingExpressions ++ missingAttrs.filterNot( - attr => a.groupingExpressions.exists(_.semanticEquals(attr))) - val newAggregateExpressions = a.aggregateExpressions ++ missingAttrs - missingAttrs = missingAttrs.filterNot( - attr => a.child.outputSet.exists(_.semanticEquals(attr))) + val newGroupExpressions = a.groupingExpressions ++ + findNotResolvedMissingAttrs( + a.groupingExpressions, a.child.outputSet, missingResolvableAttrs) + val newAggregateExpressions = a.aggregateExpressions ++ + findNotResolvedMissingAttrs( + a.aggregateExpressions, a.child.outputSet, missingResolvableAttrs) a.copy(groupingExpressions = newGroupExpressions, aggregateExpressions = newAggregateExpressions) - case s: Subquery if !stop && missingAttrs.nonEmpty => s - case o => - stop = true - o + case o => o } // Add missing attributes and then project them away after the sort. @@ -566,16 +566,33 @@ class Analyzer( } } + private def findNotResolvedMissingAttrs( + outputSet: AttributeSet, + childOutputSet: AttributeSet, + missingAttrs: Seq[Attribute]): Seq[Attribute] = { + val resolvedAttrs = + missingAttrs.filter(attr => childOutputSet.exists(_.semanticEquals(attr))) + resolvedAttrs.filterNot(attr => outputSet.exists(_.semanticEquals(attr))) + } + + private def findNotResolvedMissingAttrs( + outputSet: Seq[Expression], + childOutputSet: AttributeSet, + missingAttrs: Seq[Attribute]): Seq[Attribute] = { + val resolvedAttrs = + missingAttrs.filter(attr => childOutputSet.exists(_.semanticEquals(attr))) + resolvedAttrs.filterNot(attr => outputSet.exists(_.semanticEquals(attr))) + } + /** - * Traverse the tree until resolving the sorting attributes. + * Traverse the tree until resolving the sorting attributes + * Return all the resolvable missing sorting attributes */ @tailrec - private def collectResolvedMissingAttrs( + private def collectResolvableMissingAttrs( ordering: Seq[SortOrder], - plan: LogicalPlan): (Seq[SortOrder], Seq[Attribute]) = { - plan match { - // Subquery does nothing. We can simply skip it. - case s: Subquery => collectResolvedMissingAttrs(ordering, s.child) + plans: mutable.Queue[LogicalPlan]): (Seq[SortOrder], Seq[Attribute]) = { + plans.dequeue() match { // Only Windows, Project and Aggregate have projectList-like attribute. // TODO: when the other operators have it, we should add a support too. case un: UnaryNode @@ -584,7 +601,24 @@ class Analyzer( // If missingAttrs is non empty, that means we got it and return it; // Otherwise, continue to traverse the tree. if (missingAttrs.nonEmpty) (newOrdering, missingAttrs) - else collectResolvedMissingAttrs(ordering, un.child) + else { + plans.enqueue(un.child) + collectResolvableMissingAttrs(ordering, plans) + } + // Skip the UnaryNode whose output is the same as their child's output + case un: UnaryNode if un.child.output == un.output => + plans.enqueue(un.child) + collectResolvableMissingAttrs(ordering, plans) + case join @ Join(left, right, joinType, _) => + joinType match { + case _ @ (Inner | LeftOuter | RightOuter | FullOuter) => + plans.enqueue(left, right) + collectResolvableMissingAttrs(ordering, plans) + // If we support LeftAnti, we should add it here + case _ @ LeftSemi => + plans.enqueue(left) + collectResolvableMissingAttrs(ordering, plans) + } // If hitting the other unsupported operators, we are unable to resolve it // and thus stop traversing the plan tree. 
case other => (Seq.empty[SortOrder], Seq.empty[Attribute]) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala index 39a65413bd592..11d14f71bb48d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala @@ -42,6 +42,16 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext { Row(1, 2, "1", "2") :: Row(2, 3, "2", "3") :: Row(3, 4, "3", "4") :: Nil) } + test("join - sorted columns not in join's outputSet") { + val df = Seq((1, 2, "1"), (3, 4, "3")).toDF("int", "int2", "str_sort").as('df1) + val df2 = Seq((1, 3, "1"), (5, 6, "5")).toDF("int", "int2", "str").as('df2) + + checkAnswer( + df.join(df2, $"df1.int" === $"df2.int", "outer").select($"df1.int", $"df2.int2") + .orderBy('str_sort.asc, 'str.asc), + Row(null, 6) :: Row(1, 3) :: Row(3, null) :: Nil) + } + test("join - join using multiple columns and specifying join type") { val df = Seq((1, 2, "1"), (3, 4, "3")).toDF("int", "int2", "str") val df2 = Seq((1, 3, "1"), (5, 6, "5")).toDF("int", "int2", "str") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 983dfbdedeefe..8937d9b87779f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -898,6 +898,12 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(expected === actual) } + test("Sorting columns are not in Filter and Project") { + checkAnswer( + upperCaseData.filter('N > 1).select('N).filter('N < 6).orderBy('L.asc), + Row(2) :: Row(3) :: Row(4) :: Row(5) :: Nil) + } + test("SPARK-9323: DataFrame.orderBy should support nested column name") { val df = sqlContext.read.json(sparkContext.makeRDD( """{"a": {"b": 1}}""" :: Nil)) From 522626bbd483054f441d2ca49bc06512901258ea Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Sat, 16 Jan 2016 21:25:56 -0800 Subject: [PATCH 07/17] style fix. 
--- .../scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 7adba18a19214..c471b63a01d86 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -17,9 +17,9 @@ package org.apache.spark.sql.catalyst.analysis +import scala.annotation.tailrec import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -import scala.annotation.tailrec import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.{CatalystConf, ScalaReflection, SimpleCatalystConf} From 26945fa63809a8671461404eb2e661e1605dc196 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Sat, 16 Jan 2016 23:14:38 -0800 Subject: [PATCH 08/17] fixed a test case that could fail intermittently because the sorted values contained duplicates --- .../sql/hive/execution/SQLQuerySuite.scala | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 0cdbc1df8a11c..0572635cc04b6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -973,19 +973,19 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { WindowData(3, "b", 7), WindowData(4, "b", 8), WindowData(5, "c", 9), - WindowData(6, "c", 10) + WindowData(6, "c", 11) ) sparkContext.parallelize(data).toDF().registerTempTable("windowData") checkAnswer( sql("select month, product, sum(product + 1) over() from windowData order by area"), Seq( - (2, 6, 56), - (3, 7, 56), - (4, 8, 56), - (5, 9, 56), - (6, 10, 56), - (1, 10, 56) + (2, 6, 57), + (3, 7, 57), + (4, 8, 57), + (5, 9, 57), + (6, 11, 57), + (1, 10, 57) ).map(i => Row(i._1, i._2, i._3))) checkAnswer( @@ -1014,8 +1014,8 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { ("b", 1), ("b", 2), ("c", 1), - ("c", 2), - ("d", 1) + ("d", 1), + ("c", 2) ).map(i => Row(i._1, i._2))) } From bd3ed13b9e78d59274cda6c243acc5e704bb2821 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Sun, 17 Jan 2016 10:21:54 -0800 Subject: [PATCH 09/17] added test cases --- .../sql/catalyst/analysis/Analyzer.scala | 2 +- .../sql/catalyst/analysis/AnalysisSuite.scala | 64 +++++++++++++++++++ .../sql/catalyst/analysis/TestRelations.scala | 6 ++ 3 files changed, 71 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index c471b63a01d86..0b8a3641d7e02 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -606,7 +606,7 @@ class Analyzer( collectResolvableMissingAttrs(ordering, plans) } // Skip the UnaryNode whose output is the same as their child's output - case un: UnaryNode if un.child.output == un.output => + case un: UnaryNode if un.inputSet == un.outputSet => plans.enqueue(un.child) collectResolvableMissingAttrs(ordering, plans)
case join @ Join(left, right, joinType, _) => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 975cd87d090e4..c1e216bafac74 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -76,6 +76,70 @@ class AnalysisSuite extends AnalysisTest { caseSensitive = false) } + test("resolve sort references - filter") { + val a = testRelation2.output(0) + val b = testRelation2.output(1) + val c = testRelation2.output(2) + + val plan1 = testRelation2 + .where('a > 0).select('a, 'b) + .where('b > 0).select('a) + .sortBy('b.asc, 'c.desc) + val expected1 = testRelation2 + .where(a.cast(DoubleType) > Literal(0).cast(DoubleType)).select(a, b, c) + .where(b.cast(DoubleType) > Literal(0).cast(DoubleType)).select(a, b, c) + .sortBy(b.asc, c.desc) + .select(a, b).select(a) + checkAnalysis(plan1, expected1) + + val plan2 = testRelation2 + .where('a > 0).select('a) + .where('a > 0).select('a) + .sortBy('b.asc, 'c.desc) + val expected2 = testRelation2 + .where(a.cast(DoubleType) > Literal(0).cast(DoubleType)).select(a, b, c) + .where(a.cast(DoubleType) > Literal(0).cast(DoubleType)).select(a, b, c) + .sortBy(b.asc, c.desc) + .select(a) + checkAnalysis(plan2, expected2) + } + + test("resolve sort references - join") { + val a = testRelation2.output(0) + val b = testRelation2.output(1) + val c = testRelation2.output(2) + + val f = testRelation3.output(1) + val h = testRelation3.output(3) + + val plan1 = testRelation2.join(testRelation3) + .where('a > 0).select('a, 'b) + .sortBy('c.desc, 'h.asc) + val expected1 = testRelation2.join(testRelation3) + .where(a.cast(DoubleType) > Literal(0).cast(DoubleType)).select(a, b, c, h) + .sortBy(c.desc, h.asc) + .select(a, b) + checkAnalysis(plan1, expected1) + + val plan2 = testRelation2.select('a, 'b).join(testRelation3) + .where('a > 0).select('a, 'b) + .sortBy('c.desc, 'h.asc) + val expected2 = testRelation2.select(a, b, c).join(testRelation3) + .where(a.cast(DoubleType) > Literal(0).cast(DoubleType)).select(a, b, h, c) + .sortBy(c.desc, h.asc) + .select(a, b, h).select(a, b) + checkAnalysis(plan2, expected2) + + val plan3 = testRelation2.select('a, 'b).join(testRelation3.select('f)) + .where('a > 0).select('a, 'b) + .sortBy('c.desc, 'h.asc) + val expected3 = testRelation2.select(a, b, c).join(testRelation3.select(f, h)) + .where(a.cast(DoubleType) > Literal(0).cast(DoubleType)).select(a, b, c, h) + .sortBy(c.desc, h.asc) + .select(a, b, c).select(a, b) + checkAnalysis(plan3, expected3) + } + test("resolve relations") { assertAnalysisError( UnresolvedRelation(TableIdentifier("tAbLe"), None), Seq("Table not found: tAbLe")) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TestRelations.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TestRelations.scala index bc07b609a3413..3741a6ba95a86 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TestRelations.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TestRelations.scala @@ -31,6 +31,12 @@ object TestRelations { AttributeReference("d", DecimalType(10, 2))(), AttributeReference("e", ShortType)()) + val testRelation3 = LocalRelation( + AttributeReference("e", ShortType)(), + AttributeReference("f", StringType)(), + AttributeReference("g", 
DoubleType)(), + AttributeReference("h", DecimalType(10, 2))()) + val nestedRelation = LocalRelation( AttributeReference("top", StructType( StructField("duplicateField", StringType) :: From 831baf515faae0f12fae0f8b50297c05292e9e16 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Sun, 17 Jan 2016 23:44:27 -0800 Subject: [PATCH 10/17] fixed bugs. --- .../sql/catalyst/analysis/Analyzer.scala | 108 ++++++++++-------- .../sql/catalyst/analysis/AnalysisSuite.scala | 50 +++++++- .../apache/spark/sql/DataFrameJoinSuite.scala | 8 +- 3 files changed, 118 insertions(+), 48 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 0b8a3641d7e02..78ea0bec5e24b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -539,24 +539,16 @@ class Analyzer( // Assumption: all the conflicting attributes between left and right have been resolved val newChild = child transformUp { case p: Project => - val missingAttrs = - findNotResolvedMissingAttrs(p.outputSet, p.child.outputSet, missingResolvableAttrs) - p.copy(projectList = p.projectList ++ missingAttrs) + p.copy(projectList = p.projectList ++ + findNotResolvedMissingAttrs(p.outputSet, p.inputSet, missingResolvableAttrs)) case w: Window => - val missingAttrs = - findNotResolvedMissingAttrs(w.outputSet, w.child.outputSet, missingResolvableAttrs) - w.copy(projectList = w.projectList ++ missingAttrs) + w.copy(projectList = w.projectList ++ + findNotResolvedMissingAttrs(w.outputSet, w.inputSet, missingResolvableAttrs)) case a: Aggregate => - // Grouping expressions could already have the missing attributes. - // Do not add the duplicate attributes. 
- val newGroupExpressions = a.groupingExpressions ++ - findNotResolvedMissingAttrs( - a.groupingExpressions, a.child.outputSet, missingResolvableAttrs) val newAggregateExpressions = a.aggregateExpressions ++ findNotResolvedMissingAttrs( - a.aggregateExpressions, a.child.outputSet, missingResolvableAttrs) - a.copy(groupingExpressions = newGroupExpressions, - aggregateExpressions = newAggregateExpressions) + a.aggregateExpressions, a.inputSet, missingResolvableAttrs) + a.copy(aggregateExpressions = newAggregateExpressions) case o => o } @@ -568,19 +560,19 @@ class Analyzer( private def findNotResolvedMissingAttrs( outputSet: AttributeSet, - childOutputSet: AttributeSet, + inputSet: AttributeSet, missingAttrs: Seq[Attribute]): Seq[Attribute] = { val resolvedAttrs = - missingAttrs.filter(attr => childOutputSet.exists(_.semanticEquals(attr))) + missingAttrs.filter(attr => inputSet.exists(_.semanticEquals(attr))) resolvedAttrs.filterNot(attr => outputSet.exists(_.semanticEquals(attr))) } private def findNotResolvedMissingAttrs( outputSet: Seq[Expression], - childOutputSet: AttributeSet, + inputSet: AttributeSet, missingAttrs: Seq[Attribute]): Seq[Attribute] = { val resolvedAttrs = - missingAttrs.filter(attr => childOutputSet.exists(_.semanticEquals(attr))) + missingAttrs.filter(attr => inputSet.exists(_.semanticEquals(attr))) resolvedAttrs.filterNot(attr => outputSet.exists(_.semanticEquals(attr))) } @@ -592,36 +584,60 @@ class Analyzer( private def collectResolvableMissingAttrs( ordering: Seq[SortOrder], plans: mutable.Queue[LogicalPlan]): (Seq[SortOrder], Seq[Attribute]) = { - plans.dequeue() match { - // Only Windows, Project and Aggregate have projectList-like attribute. - // TODO: when the other operators have it, we should add a support too. - case un: UnaryNode - if un.isInstanceOf[Project] || un.isInstanceOf[Window] || un.isInstanceOf[Aggregate] => - val (newOrdering, missingAttrs) = resolveAndFindMissing(ordering, un, un.child) - // If missingAttrs is non empty, that means we got it and return it; - // Otherwise, continue to traverse the tree. - if (missingAttrs.nonEmpty) (newOrdering, missingAttrs) - else { + if (plans.isEmpty) (Seq.empty[SortOrder], Seq.empty[Attribute]) + else { + plans.dequeue() match { + // Only Windows and Project have projectList-like attribute. + case un: UnaryNode if un.isInstanceOf[Project] || un.isInstanceOf[Window] => + val (newOrdering, missingAttrs) = resolveAndFindMissing(ordering, un, un.child) + // If missingAttrs is non empty, that means we got it and return it; + // Otherwise, continue to traverse the tree. 
+ if (missingAttrs.nonEmpty) (newOrdering, missingAttrs) + else { + plans.enqueue(un.child) + collectResolvableMissingAttrs(ordering, plans) + } + // Jump over the following UnaryNode types + // The output of these types is the same as their child's output + case un: UnaryNode + if un.isInstanceOf[Distinct] || + un.isInstanceOf[Filter] || + un.isInstanceOf[Limit] || + un.isInstanceOf[RedistributeData] || + un.isInstanceOf[Repartition] || + un.isInstanceOf[RepartitionByExpression] || + un.isInstanceOf[Sample] || + un.isInstanceOf[Sort] || + un.isInstanceOf[SortPartitions] || + un.isInstanceOf[Subquery] || + un.isInstanceOf[With] || + un.isInstanceOf[WithWindowDefinition] => + assert(un.inputSet == un.outputSet) plans.enqueue(un.child) collectResolvableMissingAttrs(ordering, plans) - } - // Skip the UnaryNode whose output is the same as their child's output - case un: UnaryNode if un.inputSet == un.outputSet => - plans.enqueue(un.child) - collectResolvableMissingAttrs(ordering, plans) - case join @ Join(left, right, joinType, _) => - joinType match { - case _ @ (Inner | LeftOuter | RightOuter | FullOuter) => - plans.enqueue(left, right) - collectResolvableMissingAttrs(ordering, plans) - // If we support LeftAnti, we should add it here - case _ @ LeftSemi => - plans.enqueue(left) - collectResolvableMissingAttrs(ordering, plans) - } - // If hitting the other unsupported operators, we are unable to resolve it - // and thus stop traversing the plan tree. - case other => (Seq.empty[SortOrder], Seq.empty[Attribute]) + case a: Aggregate => + val (newOrdering, missingAttrs) = resolveAndFindMissing(ordering, a, a.child) + // For Aggregate, all the order by columns must be specified in group by clauses + if (missingAttrs.nonEmpty && + missingAttrs.forall(ar => a.groupingExpressions.exists(_.semanticEquals(ar)))) { + (newOrdering, missingAttrs) + } + // If missingAttrs is empty, do not traverse its child and just try the next + else collectResolvableMissingAttrs(ordering, plans) + case join @ Join(left, right, joinType, _) => + joinType match { + case _ @ (Inner | LeftOuter | RightOuter | FullOuter) => + plans.enqueue(left, right) + collectResolvableMissingAttrs(ordering, plans) + // If we support LeftAnti, we should add it here + case _ @ LeftSemi => + plans.enqueue(left) + collectResolvableMissingAttrs(ordering, plans) + } + // If hitting the other unsupported operators, we are unable to resolve it, + // try the next until the queue is empty + case other => collectResolvableMissingAttrs(ordering, plans) + } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index c1e216bafac74..06d9da3a4701e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -76,22 +76,26 @@ class AnalysisSuite extends AnalysisTest { caseSensitive = false) } - test("resolve sort references - filter") { + test("resolve sort references - filter/limit") { val a = testRelation2.output(0) val b = testRelation2.output(1) val c = testRelation2.output(2) + // Case 1: one missing attribute is in the leaf node and another is in the unary node val plan1 = testRelation2 .where('a > 0).select('a, 'b) .where('b > 0).select('a) + .limit(4) .sortBy('b.asc, 'c.desc) val expected1 = testRelation2 .where(a.cast(DoubleType) > Literal(0).cast(DoubleType)).select(a, b, c) 
.where(b.cast(DoubleType) > Literal(0).cast(DoubleType)).select(a, b, c) + .limit(4) .sortBy(b.asc, c.desc) .select(a, b).select(a) checkAnalysis(plan1, expected1) + // Case 2: all the missing attributes are in the leaf node val plan2 = testRelation2 .where('a > 0).select('a) .where('a > 0).select('a) @@ -112,6 +116,7 @@ class AnalysisSuite extends AnalysisTest { val f = testRelation3.output(1) val h = testRelation3.output(3) + // Case 1: join itself can resolve all the missing attributes val plan1 = testRelation2.join(testRelation3) .where('a > 0).select('a, 'b) .sortBy('c.desc, 'h.asc) @@ -121,6 +126,8 @@ class AnalysisSuite extends AnalysisTest { .select(a, b) checkAnalysis(plan1, expected1) + // Case 2: join itself can resolve partial missing attributes + // and the remaining ones are resolvable in its left tree val plan2 = testRelation2.select('a, 'b).join(testRelation3) .where('a > 0).select('a, 'b) .sortBy('c.desc, 'h.asc) @@ -130,6 +137,7 @@ class AnalysisSuite extends AnalysisTest { .select(a, b, h).select(a, b) checkAnalysis(plan2, expected2) + // Case 3: both trees are needed to resolve all the missing attribute val plan3 = testRelation2.select('a, 'b).join(testRelation3.select('f)) .where('a > 0).select('a, 'b) .sortBy('c.desc, 'h.asc) @@ -138,6 +146,46 @@ class AnalysisSuite extends AnalysisTest { .sortBy(c.desc, h.asc) .select(a, b, c).select(a, b) checkAnalysis(plan3, expected3) + + // Case 4: right tree does not resolve any missing attribute + val plan4 = testRelation2.select('a, 'b).join(testRelation3.select('f)) + .where('a > 0).select('a, 'b) + .sortBy('h.asc) + val expected4 = testRelation2.select(a, b).join(testRelation3.select(f, h)) + .where(a.cast(DoubleType) > Literal(0).cast(DoubleType)).select(a, b, h) + .sortBy(h.asc) + .select(a, b) + checkAnalysis(plan4, expected4) + + // Case 5: left tree does not resolve any missing attribute + val plan5 = testRelation2.select('a, 'b).join(testRelation3.select('f)) + .where('a > 0).select('a, 'b) + .sortBy('c.desc) + val expected5 = testRelation2.select(a, b, c).join(testRelation3.select(f)) + .where(a.cast(DoubleType) > Literal(0).cast(DoubleType)).select(a, b, c) + .sortBy(c.desc) + .select(a, b) + checkAnalysis(plan5, expected5) + } + + test("resolve sort references - aggregate") { + val a = testRelation2.output(0) + val b = testRelation2.output(1) + val c = testRelation2.output(2) + val alias3 = count(a).as("a3") + + val plan = testRelation2 + .groupBy('a, 'c, 'b)('a, 'c, count('a).as("a3")) + .select('a, 'c, 'a3) + .orderBy('b.asc) + + val expected = testRelation2 + .groupBy(a, c, b)(a, c, alias3, b) + .select(a, c, alias3.toAttribute, b) + .orderBy(b.asc) + .select(a, c, alias3.toAttribute) + + checkAnalysis(plan, expected) } test("resolve relations") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala index 61ceb995e28bc..a5e5f156423cc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala @@ -45,11 +45,17 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext { test("join - sorted columns not in join's outputSet") { val df = Seq((1, 2, "1"), (3, 4, "3")).toDF("int", "int2", "str_sort").as('df1) val df2 = Seq((1, 3, "1"), (5, 6, "5")).toDF("int", "int2", "str").as('df2) + val df3 = Seq((1, 3, "1"), (5, 6, "5")).toDF("int", "int2", "str").as('df3) checkAnswer( df.join(df2, $"df1.int" === $"df2.int", 
"outer").select($"df1.int", $"df2.int2") - .orderBy('str_sort.asc, 'str.asc), + .orderBy('str_sort.asc, 'str.asc), Row(null, 6) :: Row(1, 3) :: Row(3, null) :: Nil) + + checkAnswer( + df2.join(df3, $"df2.int" === $"df3.int", "inner") + .select($"df2.int", $"df3.int").orderBy($"df2.str".desc), + Row(5, 5) :: Row(1, 1) :: Nil) } test("join - join using multiple columns and specifying join type") { From 7f113959034d5fa29040bce76c5a4366cff1fd42 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Mon, 18 Jan 2016 16:42:29 -0800 Subject: [PATCH 11/17] updated test case. --- .../sql/catalyst/analysis/AnalysisSuite.scala | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 06d9da3a4701e..d138f37ddc42c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -118,10 +118,10 @@ class AnalysisSuite extends AnalysisTest { // Case 1: join itself can resolve all the missing attributes val plan1 = testRelation2.join(testRelation3) - .where('a > 0).select('a, 'b) + .where('a > "str").select('a, 'b) .sortBy('c.desc, 'h.asc) val expected1 = testRelation2.join(testRelation3) - .where(a.cast(DoubleType) > Literal(0).cast(DoubleType)).select(a, b, c, h) + .where(a > "str").select(a, b, c, h) .sortBy(c.desc, h.asc) .select(a, b) checkAnalysis(plan1, expected1) @@ -129,40 +129,40 @@ class AnalysisSuite extends AnalysisTest { // Case 2: join itself can resolve partial missing attributes // and the remaining ones are resolvable in its left tree val plan2 = testRelation2.select('a, 'b).join(testRelation3) - .where('a > 0).select('a, 'b) + .where('a > "str").select('a, 'b) .sortBy('c.desc, 'h.asc) val expected2 = testRelation2.select(a, b, c).join(testRelation3) - .where(a.cast(DoubleType) > Literal(0).cast(DoubleType)).select(a, b, h, c) + .where(a > "str").select(a, b, h, c) .sortBy(c.desc, h.asc) .select(a, b, h).select(a, b) checkAnalysis(plan2, expected2) // Case 3: both trees are needed to resolve all the missing attribute val plan3 = testRelation2.select('a, 'b).join(testRelation3.select('f)) - .where('a > 0).select('a, 'b) + .where('a > "str").select('a, 'b) .sortBy('c.desc, 'h.asc) val expected3 = testRelation2.select(a, b, c).join(testRelation3.select(f, h)) - .where(a.cast(DoubleType) > Literal(0).cast(DoubleType)).select(a, b, c, h) + .where(a > "str").select(a, b, c, h) .sortBy(c.desc, h.asc) .select(a, b, c).select(a, b) checkAnalysis(plan3, expected3) // Case 4: right tree does not resolve any missing attribute val plan4 = testRelation2.select('a, 'b).join(testRelation3.select('f)) - .where('a > 0).select('a, 'b) + .where('a > "str").select('a, 'b) .sortBy('h.asc) val expected4 = testRelation2.select(a, b).join(testRelation3.select(f, h)) - .where(a.cast(DoubleType) > Literal(0).cast(DoubleType)).select(a, b, h) + .where(a > "str").select(a, b, h) .sortBy(h.asc) .select(a, b) checkAnalysis(plan4, expected4) // Case 5: left tree does not resolve any missing attribute val plan5 = testRelation2.select('a, 'b).join(testRelation3.select('f)) - .where('a > 0).select('a, 'b) + .where('a > "str").select('a, 'b) .sortBy('c.desc) val expected5 = testRelation2.select(a, b, c).join(testRelation3.select(f)) - .where(a.cast(DoubleType) > 
Literal(0).cast(DoubleType)).select(a, b, c) + .where(a > "str").select(a, b, c) .sortBy(c.desc) .select(a, b) checkAnalysis(plan5, expected5) From 598a673d1b01ebef6a89582f8e63ab487e465e71 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Mon, 18 Jan 2016 16:44:58 -0800 Subject: [PATCH 12/17] updated test case. --- .../sql/catalyst/analysis/AnalysisSuite.scala | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index d138f37ddc42c..32e63b606e087 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -83,13 +83,13 @@ class AnalysisSuite extends AnalysisTest { // Case 1: one missing attribute is in the leaf node and another is in the unary node val plan1 = testRelation2 - .where('a > 0).select('a, 'b) - .where('b > 0).select('a) + .where('a > "str").select('a, 'b) + .where('b > "str").select('a) .limit(4) .sortBy('b.asc, 'c.desc) val expected1 = testRelation2 - .where(a.cast(DoubleType) > Literal(0).cast(DoubleType)).select(a, b, c) - .where(b.cast(DoubleType) > Literal(0).cast(DoubleType)).select(a, b, c) + .where(a > "str").select(a, b, c) + .where(b > "str").select(a, b, c) .limit(4) .sortBy(b.asc, c.desc) .select(a, b).select(a) @@ -97,12 +97,12 @@ class AnalysisSuite extends AnalysisTest { // Case 2: all the missing attributes are in the leaf node val plan2 = testRelation2 - .where('a > 0).select('a) - .where('a > 0).select('a) + .where('a > "str").select('a) + .where('a > "str").select('a) .sortBy('b.asc, 'c.desc) val expected2 = testRelation2 - .where(a.cast(DoubleType) > Literal(0).cast(DoubleType)).select(a, b, c) - .where(a.cast(DoubleType) > Literal(0).cast(DoubleType)).select(a, b, c) + .where(a > "str").select(a, b, c) + .where(a > "str").select(a, b, c) .sortBy(b.asc, c.desc) .select(a) checkAnalysis(plan2, expected2) From 1964884afee035f1e13bd248073656ec43de2223 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 19 Jan 2016 22:28:16 -0800 Subject: [PATCH 13/17] address comments. 
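The revision below redraws the boundary between two rules: ResolveAggregateFunctions (moved back after GlobalAggregates in the batch) takes over sorts whose direct child is an Aggregate, while ResolveSortReferences now explicitly bails out on that case and keeps handling the rest. Roughly, assuming a windowData-style table:

    // Left to ResolveAggregateFunctions after this change:
    sqlContext.sql("select area from windowData group by area order by max(month)")
    // Still handled by ResolveSortReferences (child is not an Aggregate):
    sqlContext.sql("select area from windowData order by month")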
--- .../sql/catalyst/analysis/Analyzer.scala | 134 ++++++------------ .../sql/catalyst/analysis/AnalysisSuite.scala | 83 ++++------- 2 files changed, 72 insertions(+), 145 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 78ea0bec5e24b..3214e299f081e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -18,14 +18,12 @@ package org.apache.spark.sql.catalyst.analysis import scala.annotation.tailrec -import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.{CatalystConf, ScalaReflection, SimpleCatalystConf} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.trees.TreeNodeRef @@ -76,7 +74,6 @@ class Analyzer( ResolveGroupingAnalytics :: ResolvePivot :: ResolveUpCast :: - ResolveAggregateFunctions :: ResolveSortReferences :: ResolveGenerate :: ResolveFunctions :: @@ -85,6 +82,7 @@ class Analyzer( ResolveWindowFrame :: ExtractWindowExpressions :: GlobalAggregates :: + ResolveAggregateFunctions :: DistinctAggregationRewriter(conf) :: HiveTypeCoercion.typeCoercionRules ++ extendedResolutionRules : _*), @@ -524,30 +522,31 @@ class Analyzer( */ object ResolveSortReferences extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { - case s @ Sort(_, _, child) if !s.resolved && child.resolved => - val (newOrdering, missingResolvableAttrs) = - collectResolvableMissingAttrs(s.order, plans = mutable.Queue(child)) + // Here, this rule only resolves the missing sort references if the child is not Aggregate + // Another rule ResolveAggregateFunctions will resolve that case. + case s @ Sort(_, _, child) + if !s.resolved && child.resolved && !child.isInstanceOf[Aggregate] => + val (newOrdering, missingResolvableAttrs) = collectResolvableMissingAttrs(s.order, child) if (missingResolvableAttrs.isEmpty) { val unresolvableAttrs = s.order.filterNot(_.resolved) logDebug(s"Failed to find $unresolvableAttrs in ${child.output.mkString(", ")}") s // Nothing we can do here. Return original plan. - } - else { - // Transform the whole tree in post-order. Add into the self's outputSet - // all the children attributes that are part of missingResolvableAttrs. - // Assumption: all the conflicting attributes between left and right have been resolved + } else { + // Add the missing attributes into projectList of Project/Window or + // aggregateExpressions of Aggregate, if they are in the inputSet + // but not in the outputSet of the plan. 
           val newChild = child transformUp {
             case p: Project =>
               p.copy(projectList = p.projectList ++
-                findNotResolvedMissingAttrs(p.outputSet, p.inputSet, missingResolvableAttrs))
+                missingResolvableAttrs.filter((p.inputSet -- p.outputSet).contains))
             case w: Window =>
               w.copy(projectList = w.projectList ++
-                findNotResolvedMissingAttrs(w.outputSet, w.inputSet, missingResolvableAttrs))
+                missingResolvableAttrs.filter((w.inputSet -- w.outputSet).contains))
             case a: Aggregate =>
-              val newAggregateExpressions = a.aggregateExpressions ++
-                findNotResolvedMissingAttrs(
-                  a.aggregateExpressions, a.inputSet, missingResolvableAttrs)
+              val resolvableAttrs = missingResolvableAttrs.filter(a.groupingExpressions.contains)
+              val notResolvedAttrs = resolvableAttrs.filterNot(a.aggregateExpressions.contains)
+              val newAggregateExpressions = a.aggregateExpressions ++ notResolvedAttrs
               a.copy(aggregateExpressions = newAggregateExpressions)
             case o => o
           }
@@ -558,24 +557,6 @@ class Analyzer(
       }
     }
 
-    private def findNotResolvedMissingAttrs(
-        outputSet: AttributeSet,
-        inputSet: AttributeSet,
-        missingAttrs: Seq[Attribute]): Seq[Attribute] = {
-      val resolvedAttrs =
-        missingAttrs.filter(attr => inputSet.exists(_.semanticEquals(attr)))
-      resolvedAttrs.filterNot(attr => outputSet.exists(_.semanticEquals(attr)))
-    }
-
-    private def findNotResolvedMissingAttrs(
-        outputSet: Seq[Expression],
-        inputSet: AttributeSet,
-        missingAttrs: Seq[Attribute]): Seq[Attribute] = {
-      val resolvedAttrs =
-        missingAttrs.filter(attr => inputSet.exists(_.semanticEquals(attr)))
-      resolvedAttrs.filterNot(attr => outputSet.exists(_.semanticEquals(attr)))
-    }
-
     /**
      * Traverse the tree until resolving the sorting attributes
     * Return all the resolvable missing sorting attributes
      */
@@ -583,61 +564,36 @@ class Analyzer(
     @tailrec
     private def collectResolvableMissingAttrs(
         ordering: Seq[SortOrder],
-        plans: mutable.Queue[LogicalPlan]): (Seq[SortOrder], Seq[Attribute]) = {
-      if (plans.isEmpty) (Seq.empty[SortOrder], Seq.empty[Attribute])
-      else {
-        plans.dequeue() match {
-          // Only Windows and Project have projectList-like attribute.
-          case un: UnaryNode if un.isInstanceOf[Project] || un.isInstanceOf[Window] =>
-            val (newOrdering, missingAttrs) = resolveAndFindMissing(ordering, un, un.child)
-            // If missingAttrs is non empty, that means we got it and return it;
-            // Otherwise, continue to traverse the tree.
-            if (missingAttrs.nonEmpty) (newOrdering, missingAttrs)
-            else {
-              plans.enqueue(un.child)
-              collectResolvableMissingAttrs(ordering, plans)
-            }
-          // Jump over the following UnaryNode types
-          // The output of these types is the same as their child's output
-          case un: UnaryNode
-            if un.isInstanceOf[Distinct] ||
-              un.isInstanceOf[Filter] ||
-              un.isInstanceOf[Limit] ||
-              un.isInstanceOf[RedistributeData] ||
-              un.isInstanceOf[Repartition] ||
-              un.isInstanceOf[RepartitionByExpression] ||
-              un.isInstanceOf[Sample] ||
-              un.isInstanceOf[Sort] ||
-              un.isInstanceOf[SortPartitions] ||
-              un.isInstanceOf[Subquery] ||
-              un.isInstanceOf[With] ||
-              un.isInstanceOf[WithWindowDefinition] =>
-            assert(un.inputSet == un.outputSet)
-            plans.enqueue(un.child)
-            collectResolvableMissingAttrs(ordering, plans)
-          case a: Aggregate =>
-            val (newOrdering, missingAttrs) = resolveAndFindMissing(ordering, a, a.child)
-            // For Aggregate, all the order by columns must be specified in group by clauses
-            if (missingAttrs.nonEmpty &&
-              missingAttrs.forall(ar => a.groupingExpressions.exists(_.semanticEquals(ar)))) {
-              (newOrdering, missingAttrs)
-            }
-            // If missingAttrs is empty, do not traverse its child and just try the next
-            else collectResolvableMissingAttrs(ordering, plans)
-          case join @ Join(left, right, joinType, _) =>
-            joinType match {
-              case _ @ (Inner | LeftOuter | RightOuter | FullOuter) =>
-                plans.enqueue(left, right)
-                collectResolvableMissingAttrs(ordering, plans)
-              // If we support LeftAnti, we should add it here
-              case _ @ LeftSemi =>
-                plans.enqueue(left)
-                collectResolvableMissingAttrs(ordering, plans)
-            }
-          // If hitting the other unsupported operators, we are unable to resolve it,
-          // try the next until the queue is empty
-          case other => collectResolvableMissingAttrs(ordering, plans)
-        }
+        plan: LogicalPlan): (Seq[SortOrder], Seq[Attribute]) = {
+      plan match {
+        // Only Windows and Project have projectList-like attribute.
+        case un: UnaryNode if un.isInstanceOf[Project] || un.isInstanceOf[Window] =>
+          val (newOrdering, missingAttrs) = resolveAndFindMissing(ordering, un, un.child)
+          // If missingAttrs is non empty, that means we got it and return it;
+          // Otherwise, continue to traverse the tree.
+          if (missingAttrs.nonEmpty) {
+            (newOrdering, missingAttrs)
+          } else {
+            collectResolvableMissingAttrs(ordering, un.child)
+          }
+        case a: Aggregate =>
+          val (newOrdering, missingAttrs) = resolveAndFindMissing(ordering, a, a.child)
+          // For Aggregate, all the order by columns must be specified in group by clauses
+          if (missingAttrs.nonEmpty &&
+            missingAttrs.forall(ar => a.groupingExpressions.exists(_.semanticEquals(ar)))) {
+            (newOrdering, missingAttrs)
+          } else {
+            // If missingAttrs is empty, we are unable to resolve any unresolved missing attributes
+            (Seq.empty[SortOrder], Seq.empty[Attribute])
+          }
+        // Jump over the following UnaryNode types
+        // The output of these types is the same as their child's output
+        case _: Distinct |
+             _: Filter |
+             _: RepartitionByExpression =>
+          collectResolvableMissingAttrs(ordering, plan.asInstanceOf[UnaryNode].child)
+        // If hitting the other unsupported operators, we are unable to resolve it.
+        case other => (Seq.empty[SortOrder], Seq.empty[Attribute])
       }
     }
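Before the test changes below, a condensed restatement of the traversal implemented above may help. The helper name findFirstResolver is hypothetical; it only sketches the recursion, under the assumption that Distinct, Filter, and RepartitionByExpression pass their child's output through unchanged:

    import scala.annotation.tailrec
    import org.apache.spark.sql.catalyst.plans.logical._

    @tailrec
    def findFirstResolver(plan: LogicalPlan): Option[LogicalPlan] = plan match {
      case p: Project => Some(p)      // can host extra entries in its projectList
      case w: Window => Some(w)       // likewise
      case a: Aggregate => Some(a)    // only for columns already in GROUP BY
      case d: Distinct => findFirstResolver(d.child)
      case f: Filter => findFirstResolver(f.child)
      case r: RepartitionByExpression => findFirstResolver(r.child)
      case _ => None                  // unsupported operator: give up
    }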
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index 32e63b606e087..d4e544508f3f5 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -85,12 +85,10 @@ class AnalysisSuite extends AnalysisTest {
     val plan1 = testRelation2
       .where('a > "str").select('a, 'b)
       .where('b > "str").select('a)
-      .limit(4)
       .sortBy('b.asc, 'c.desc)
     val expected1 = testRelation2
       .where(a > "str").select(a, b, c)
       .where(b > "str").select(a, b, c)
-      .limit(4)
       .sortBy(b.asc, c.desc)
       .select(a, b).select(a)
     checkAnalysis(plan1, expected1)
@@ -112,80 +110,53 @@ class AnalysisSuite extends AnalysisTest {
     val a = testRelation2.output(0)
     val b = testRelation2.output(1)
     val c = testRelation2.output(2)
-
-    val f = testRelation3.output(1)
     val h = testRelation3.output(3)
 
-    // Case 1: join itself can resolve all the missing attributes
-    val plan1 = testRelation2.join(testRelation3)
-      .where('a > "str").select('a, 'b)
-      .sortBy('c.desc, 'h.asc)
-    val expected1 = testRelation2.join(testRelation3)
-      .where(a > "str").select(a, b, c, h)
-      .sortBy(c.desc, h.asc)
-      .select(a, b)
-    checkAnalysis(plan1, expected1)
-
-    // Case 2: join itself can resolve partial missing attributes
-    // and the remaining ones are resolvable in its left tree
-    val plan2 = testRelation2.select('a, 'b).join(testRelation3)
-      .where('a > "str").select('a, 'b)
-      .sortBy('c.desc, 'h.asc)
-    val expected2 = testRelation2.select(a, b, c).join(testRelation3)
-      .where(a > "str").select(a, b, h, c)
-      .sortBy(c.desc, h.asc)
-      .select(a, b, h).select(a, b)
-    checkAnalysis(plan2, expected2)
-
-    // Case 3: both trees are needed to resolve all the missing attribute
-    val plan3 = testRelation2.select('a, 'b).join(testRelation3.select('f))
+    // Case: join itself can resolve all the missing attributes
+    val plan = testRelation2.join(testRelation3)
       .where('a > "str").select('a, 'b)
       .sortBy('c.desc, 'h.asc)
-    val expected3 = testRelation2.select(a, b, c).join(testRelation3.select(f, h))
+    val expected = testRelation2.join(testRelation3)
       .where(a > "str").select(a, b, c, h)
       .sortBy(c.desc, h.asc)
-      .select(a, b, c).select(a, b)
-    checkAnalysis(plan3, expected3)
-
-    // Case 4: right tree does not resolve any missing attribute
-    val plan4 = testRelation2.select('a, 'b).join(testRelation3.select('f))
-      .where('a > "str").select('a, 'b)
-      .sortBy('h.asc)
-    val expected4 = testRelation2.select(a, b).join(testRelation3.select(f, h))
-      .where(a > "str").select(a, b, h)
-      .sortBy(h.asc)
       .select(a, b)
-    checkAnalysis(plan4, expected4)
-
-    // Case 5: left tree does not resolve any missing attribute
-    val plan5 = testRelation2.select('a, 'b).join(testRelation3.select('f))
-      .where('a > "str").select('a, 'b)
-      .sortBy('c.desc)
-    val expected5 = testRelation2.select(a, b, c).join(testRelation3.select(f))
-      .where(a > "str").select(a, b, c)
-      .sortBy(c.desc)
-      .select(a, b)
-    checkAnalysis(plan5, expected5)
+    checkAnalysis(plan, expected)
   }
 
   test("resolve sort references - aggregate") {
     val a = testRelation2.output(0)
     val b = testRelation2.output(1)
     val c = testRelation2.output(2)
-    val alias3 = count(a).as("a3")
+    val alias_a3 = count(a).as("a3")
+    val alias_b = b.as("aggOrder")
 
-    val plan = testRelation2
+    // Case 1: when the child of Sort is not Aggregate,
+    // the sort reference is handled by the rule ResolveSortReferences
+    val plan1 = testRelation2
       .groupBy('a, 'c, 'b)('a, 'c, count('a).as("a3"))
       .select('a, 'c, 'a3)
       .orderBy('b.asc)
 
-    val expected = testRelation2
-      .groupBy(a, c, b)(a, c, alias3, b)
-      .select(a, c, alias3.toAttribute, b)
+    val expected1 = testRelation2
+      .groupBy(a, c, b)(a, c, alias_a3, b)
+      .select(a, c, alias_a3.toAttribute, b)
       .orderBy(b.asc)
-      .select(a, c, alias3.toAttribute)
+      .select(a, c, alias_a3.toAttribute)
 
-    checkAnalysis(plan, expected)
+    checkAnalysis(plan1, expected1)
+
+    // Case 2: when the child of Sort is Aggregate,
+    // the sort reference is handled by the rule ResolveAggregateFunctions
+    val plan2 = testRelation2
+      .groupBy('a, 'c, 'b)('a, 'c, count('a).as("a3"))
+      .orderBy('b.asc)
+
+    val expected2 = testRelation2
+      .groupBy(a, c, b)(a, c, alias_a3, alias_b)
+      .orderBy(alias_b.toAttribute.asc)
+      .select(a, c, alias_a3.toAttribute)
+
+    checkAnalysis(plan2, expected2)
   }
 
   test("resolve relations") {

From ba02f4695e4bfd07a9bef72f783bef3894d8191e Mon Sep 17 00:00:00 2001
From: gatorsmile
Date: Sat, 30 Jan 2016 15:12:01 -0800
Subject: [PATCH 14/17] addressed comments.
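The division of labor exercised by the two test cases above, stated as queries; the table t, its columns, and the surrounding sql harness are assumptions for illustration:

    // Case 1: the sort column is a grouping column that was projected away;
    // ResolveSortReferences widens the child plan and trims it again afterwards.
    sql("select a, count(b) from t group by a, b order by b")

    // Case 2: the sort key is an aggregate expression that is not in the SELECT
    // list; ResolveAggregateFunctions pushes it into the Aggregate under the
    // hidden alias aggOrder and projects it away on top.
    sql("select a, count(b) from t group by a order by max(b)")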
---
 .../sql/catalyst/analysis/Analyzer.scala | 122 +++++++++---------
 1 file changed, 62 insertions(+), 60 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 3214e299f081e..cb2b4d603bf1b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -522,10 +522,13 @@ class Analyzer(
    */
   object ResolveSortReferences extends Rule[LogicalPlan] {
     def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
-      // Here, this rule only resolves the missing sort references if the child is not Aggregate
-      // Another rule ResolveAggregateFunctions will resolve that case.
-      case s @ Sort(_, _, child)
-        if !s.resolved && child.resolved && !child.isInstanceOf[Aggregate] =>
+      case s @ Sort(_, _, a: Aggregate) if a.resolved =>
+        // Here, it finds aggregate expressions in ORDER BY clauses but these expressions are
+        // not in the aggregate operator. These expressions are pushed down to the underlying
+        // aggregate operator and then projected away after the original operator.
+        ResolveAggregateFunctions.resolveAggregateFunctionsInSort(sort = s, aggregate = a)
+
+      case s @ Sort(_, _, child) if !s.resolved && child.resolved =>
         val (newOrdering, missingResolvableAttrs) = collectResolvableMissingAttrs(s.order, child)
 
         if (missingResolvableAttrs.isEmpty) {
@@ -601,7 +604,7 @@ class Analyzer(
      * Try to resolve the sort ordering and returns it with a list of attributes that are missing
      * from the plan but are present in the child.
      */
-    def resolveAndFindMissing(
+    private def resolveAndFindMissing(
         ordering: Seq[SortOrder],
         plan: LogicalPlan,
         child: LogicalPlan): (Seq[SortOrder], Seq[Attribute]) = {
@@ -704,66 +707,65 @@ class Analyzer(
         } else {
           filter
         }
+      }
 
-      case sort @ Sort(sortOrder, global, aggregate: Aggregate)
-        if aggregate.resolved =>
-
-        // Try resolving the ordering as though it is in the aggregate clause.
-        try {
-          val unresolvedSortOrders = sortOrder.filter(s => !s.resolved || containsAggregate(s))
-          val aliasedOrdering = unresolvedSortOrders.map(o => Alias(o.child, "aggOrder")())
-          val aggregatedOrdering = aggregate.copy(aggregateExpressions = aliasedOrdering)
-          val resolvedAggregate: Aggregate = execute(aggregatedOrdering).asInstanceOf[Aggregate]
-          val resolvedAliasedOrdering: Seq[Alias] =
-            resolvedAggregate.aggregateExpressions.asInstanceOf[Seq[Alias]]
-
-          // If we pass the analysis check, then the ordering expressions should only reference to
-          // aggregate expressions or grouping expressions, and it's safe to push them down to
-          // Aggregate.
-          checkAnalysis(resolvedAggregate)
-
-          val originalAggExprs = aggregate.aggregateExpressions.map(
-            CleanupAliases.trimNonTopLevelAliases(_).asInstanceOf[NamedExpression])
-
-          // If the ordering expression is same with original aggregate expression, we don't need
-          // to push down this ordering expression and can reference the original aggregate
-          // expression instead.
+    // Note, aggregate expressions in ORDER BY clauses are resolved in ResolveSortReferences.
+    def resolveAggregateFunctionsInSort(sort: Sort, aggregate: Aggregate): LogicalPlan = {
+      try {
+        val unresolvedSortOrders = sort.order.filter(s => !s.resolved || containsAggregate(s))
+        val aliasedOrdering = unresolvedSortOrders.map(o => Alias(o.child, "aggOrder")())
+        val aggregatedOrdering = aggregate.copy(aggregateExpressions = aliasedOrdering)
+        val resolvedAggregate: Aggregate = execute(aggregatedOrdering).asInstanceOf[Aggregate]
+        val resolvedAliasedOrdering: Seq[Alias] =
+          resolvedAggregate.aggregateExpressions.asInstanceOf[Seq[Alias]]
+
+        // If we pass the analysis check, then the ordering expressions should only reference to
+        // aggregate expressions or grouping expressions, and it's safe to push them down to
+        // Aggregate.
+        checkAnalysis(resolvedAggregate)
+
+        val originalAggExprs = aggregate.aggregateExpressions.map(
+          CleanupAliases.trimNonTopLevelAliases(_).asInstanceOf[NamedExpression])
+
+        // If the ordering expression is same with original aggregate expression, we don't need
+        // to push down this ordering expression and can reference the original aggregate
+        // expression instead.
+        val needsPushDown = ArrayBuffer.empty[NamedExpression]
+        val evaluatedOrderings = resolvedAliasedOrdering.zip(sort.order).map {
+          case (evaluated, order) =>
+            val index = originalAggExprs.indexWhere {
+              case Alias(child, _) => child semanticEquals evaluated.child
+              case other => other semanticEquals evaluated.child
+            }
 
-          if (index == -1) {
-            needsPushDown += evaluated
-            order.copy(child = evaluated.toAttribute)
-          } else {
-            order.copy(child = originalAggExprs(index).toAttribute)
-          }
-        }
+            if (index == -1) {
+              needsPushDown += evaluated
+              order.copy(child = evaluated.toAttribute)
+            } else {
+              order.copy(child = originalAggExprs(index).toAttribute)
+            }
+        }
 
-          val sortOrdersMap = unresolvedSortOrders
-            .map(new TreeNodeRef(_))
-            .zip(evaluatedOrderings)
-            .toMap
-          val finalSortOrders = sortOrder.map(s => sortOrdersMap.getOrElse(new TreeNodeRef(s), s))
+        val sortOrdersMap = unresolvedSortOrders
+          .map(new TreeNodeRef(_))
+          .zip(evaluatedOrderings)
+          .toMap
+        val finalSortOrders = sort.order.map(s => sortOrdersMap.getOrElse(new TreeNodeRef(s), s))
 
-          // Since we don't rely on sort.resolved as the stop condition for this rule,
-          // we need to check this and prevent applying this rule multiple times
-          if (sortOrder == finalSortOrders) {
-            sort
-          } else {
-            Project(aggregate.output,
-              Sort(finalSortOrders, global,
-                aggregate.copy(aggregateExpressions = originalAggExprs ++ needsPushDown)))
+        // Since we don't rely on sort.resolved as the stop condition for this rule,
+        // we need to check this and prevent applying this rule multiple times
+        if (sort.order == finalSortOrders) {
+          sort
+        } else {
+          Project(aggregate.output,
+            Sort(finalSortOrders, sort.global,
+              aggregate.copy(aggregateExpressions = originalAggExprs ++ needsPushDown)))
         }
       } catch {
         // Attempting to resolve in the aggregate can result in ambiguity. When this happens,
         // just return the original plan.
         case ae: AnalysisException => sort
+      }
     }

From 5bfda35e95f7b189517f2098ea5a8fd35fd10ec6 Mon Sep 17 00:00:00 2001
From: gatorsmile
Date: Sat, 30 Jan 2016 22:56:25 -0800
Subject: [PATCH 15/17] address comments.
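To make the push-down concrete, a sketch in the Catalyst DSL of the rewrite that this patch moves back into ResolveAggregateFunctions. The relation and column names are illustrative; only the alias name aggOrder comes from the code itself:

    import org.apache.spark.sql.catalyst.dsl.expressions._
    import org.apache.spark.sql.catalyst.dsl.plans._
    import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

    val t = LocalRelation('area.string, 'month.int, 'product.int)

    // ORDER BY references avg('month), which is not among the aggregate expressions:
    val before = t
      .groupBy('area)('area, sum('product).as("s"))
      .orderBy(avg('month).asc)

    // The rule computes the ordering expression inside the Aggregate under the
    // alias aggOrder, sorts on the alias, and projects it away on top:
    val after = t
      .groupBy('area)('area, sum('product).as("s"), avg('month).as("aggOrder"))
      .orderBy('aggOrder.asc)
      .select('area, 's)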
---
 .../sql/catalyst/analysis/Analyzer.scala | 118 +++++++++---------
 1 file changed, 58 insertions(+), 60 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index cb2b4d603bf1b..20a5154c9f24d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -522,11 +522,8 @@ class Analyzer(
    */
   object ResolveSortReferences extends Rule[LogicalPlan] {
     def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
-      case s @ Sort(_, _, a: Aggregate) if a.resolved =>
-        // Here, it finds aggregate expressions in ORDER BY clauses but these expressions are
-        // not in the aggregate operator. These expressions are pushed down to the underlying
-        // aggregate operator and then projected away after the original operator.
-        ResolveAggregateFunctions.resolveAggregateFunctionsInSort(sort = s, aggregate = a)
+      // Skip sort with aggregate. This will be handled in ResolveAggregateFunctions
+      case sa: Sort if sa.order.exists(ResolveAggregateFunctions.containsAggregate) => sa
 
       case s @ Sort(_, _, child) if !s.resolved && child.resolved =>
         val (newOrdering, missingResolvableAttrs) = collectResolvableMissingAttrs(s.order, child)
@@ -707,68 +704,69 @@ class Analyzer(
       } else {
         filter
       }
-    }
 
-    // Note, aggregate expressions in ORDER BY clauses are resolved in ResolveSortReferences.
-    def resolveAggregateFunctionsInSort(sort: Sort, aggregate: Aggregate): LogicalPlan = {
-      try {
-        val unresolvedSortOrders = sort.order.filter(s => !s.resolved || containsAggregate(s))
-        val aliasedOrdering = unresolvedSortOrders.map(o => Alias(o.child, "aggOrder")())
-        val aggregatedOrdering = aggregate.copy(aggregateExpressions = aliasedOrdering)
-        val resolvedAggregate: Aggregate = execute(aggregatedOrdering).asInstanceOf[Aggregate]
-        val resolvedAliasedOrdering: Seq[Alias] =
-          resolvedAggregate.aggregateExpressions.asInstanceOf[Seq[Alias]]
-
-        // If we pass the analysis check, then the ordering expressions should only reference to
-        // aggregate expressions or grouping expressions, and it's safe to push them down to
-        // Aggregate.
-        checkAnalysis(resolvedAggregate)
-
-        val originalAggExprs = aggregate.aggregateExpressions.map(
-          CleanupAliases.trimNonTopLevelAliases(_).asInstanceOf[NamedExpression])
-
-        // If the ordering expression is same with original aggregate expression, we don't need
-        // to push down this ordering expression and can reference the original aggregate
-        // expression instead.
+      case sort @ Sort(sortOrder, global, aggregate: Aggregate)
+        if aggregate.resolved =>
+
+        // Try resolving the ordering as though it is in the aggregate clause.
+        try {
+          val unresolvedSortOrders = sortOrder.filter(s => !s.resolved || containsAggregate(s))
+          val aliasedOrdering = unresolvedSortOrders.map(o => Alias(o.child, "aggOrder")())
+          val aggregatedOrdering = aggregate.copy(aggregateExpressions = aliasedOrdering)
+          val resolvedAggregate: Aggregate = execute(aggregatedOrdering).asInstanceOf[Aggregate]
+          val resolvedAliasedOrdering: Seq[Alias] =
+            resolvedAggregate.aggregateExpressions.asInstanceOf[Seq[Alias]]
+
+          // If we pass the analysis check, then the ordering expressions should only reference to
+          // aggregate expressions or grouping expressions, and it's safe to push them down to
+          // Aggregate.
+          checkAnalysis(resolvedAggregate)
+
+          val originalAggExprs = aggregate.aggregateExpressions.map(
+            CleanupAliases.trimNonTopLevelAliases(_).asInstanceOf[NamedExpression])
+
+          // If the ordering expression is same with original aggregate expression, we don't need
+          // to push down this ordering expression and can reference the original aggregate
+          // expression instead.
+          val needsPushDown = ArrayBuffer.empty[NamedExpression]
+          val evaluatedOrderings = resolvedAliasedOrdering.zip(sortOrder).map {
+            case (evaluated, order) =>
+              val index = originalAggExprs.indexWhere {
+                case Alias(child, _) => child semanticEquals evaluated.child
+                case other => other semanticEquals evaluated.child
+              }
 
-        val sortOrdersMap = unresolvedSortOrders
-          .map(new TreeNodeRef(_))
-          .zip(evaluatedOrderings)
-          .toMap
-        val finalSortOrders = sort.order.map(s => sortOrdersMap.getOrElse(new TreeNodeRef(s), s))
+              if (index == -1) {
+                needsPushDown += evaluated
+                order.copy(child = evaluated.toAttribute)
+              } else {
+                order.copy(child = originalAggExprs(index).toAttribute)
+              }
+          }
 
-        // Since we don't rely on sort.resolved as the stop condition for this rule,
-        // we need to check this and prevent applying this rule multiple times
-        if (sort.order == finalSortOrders) {
-          sort
-        } else {
-          Project(aggregate.output,
-            Sort(finalSortOrders, sort.global,
-              aggregate.copy(aggregateExpressions = originalAggExprs ++ needsPushDown)))
+          val sortOrdersMap = unresolvedSortOrders
+            .map(new TreeNodeRef(_))
+            .zip(evaluatedOrderings)
+            .toMap
+          val finalSortOrders = sortOrder.map(s => sortOrdersMap.getOrElse(new TreeNodeRef(s), s))
+
+          // Since we don't rely on sort.resolved as the stop condition for this rule,
+          // we need to check this and prevent applying this rule multiple times
+          if (sortOrder == finalSortOrders) {
+            sort
+          } else {
+            Project(aggregate.output,
+              Sort(finalSortOrders, global,
+                aggregate.copy(aggregateExpressions = originalAggExprs ++ needsPushDown)))
+          }
+        } catch {
+          // Attempting to resolve in the aggregate can result in ambiguity. When this happens,
+          // just return the original plan.
+          case ae: AnalysisException => sort
         }
-      } catch {
-        // Attempting to resolve in the aggregate can result in ambiguity. When this happens,
-        // just return the original plan.
-        case ae: AnalysisException => sort
-      }
     }

-  protected def containsAggregate(condition: Expression): Boolean = {
+  def containsAggregate(condition: Expression): Boolean = {
     condition.find(_.isInstanceOf[AggregateExpression]).isDefined
   }
 }

From ddfebbf2bde7a30f68271fbc8ef11705a744f0c8 Mon Sep 17 00:00:00 2001
From: gatorsmile
Date: Sat, 30 Jan 2016 23:46:50 -0800
Subject: [PATCH 16/17] address comments.

---
 .../scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 20a5154c9f24d..406aaefb3646b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -523,7 +523,7 @@ class Analyzer(
   object ResolveSortReferences extends Rule[LogicalPlan] {
     def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
       // Skip sort with aggregate. This will be handled in ResolveAggregateFunctions
-      case sa: Sort if sa.order.exists(ResolveAggregateFunctions.containsAggregate) => sa
+      case sa @ Sort(_, _, child: Aggregate) => sa
 
       case s @ Sort(_, _, child) if !s.resolved && child.resolved =>
         val (newOrdering, missingResolvableAttrs) = collectResolvableMissingAttrs(s.order, child)

From c2964daad833f3a9b1c05aa5d082e356a776c70c Mon Sep 17 00:00:00 2001
From: gatorsmile
Date: Sun, 31 Jan 2016 23:26:34 -0800
Subject: [PATCH 17/17] Added a test case that we need to fix in the next PR.
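Until that reimplementation lands, the ignored query below can be written by hand in roughly the form the analyzer would produce. This rewrite is illustrative only, not part of the patch, and assumes the same windowData temp table as the test:

    // Materialize the ordering aggregate under an explicit alias in a subquery,
    // order by it in the outer query, and keep it out of the final output.
    sql("""
      |select area, c from (
      |  select area, product, sum(product) over () as c, avg(month) as ord
      |  from windowData
      |  where product > 3
      |  group by area, product
      |  having avg(month) > 0
      |) t
      |order by ord, product
    """.stripMargin)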
---
 .../sql/hive/execution/SQLQuerySuite.scala | 29 +++++++++++++++++++
 1 file changed, 29 insertions(+)

diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
index 0572635cc04b6..bd7ba22ec2471 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
@@ -1019,6 +1019,35 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
     ).map(i => Row(i._1, i._2)))
   }
 
+  // todo: fix this test case by reimplementing the function ResolveAggregateFunctions
+  ignore("window function: Pushing aggregate Expressions in Sort to Aggregate") {
+    val data = Seq(
+      WindowData(1, "d", 10),
+      WindowData(2, "a", 6),
+      WindowData(3, "b", 7),
+      WindowData(4, "b", 8),
+      WindowData(5, "c", 9),
+      WindowData(6, "c", 11)
+    )
+    sparkContext.parallelize(data).toDF().registerTempTable("windowData")
+
+    checkAnswer(
+      sql(
+        """
+          |select area, sum(product) over () as c from windowData
+          |where product > 3 group by area, product
+          |having avg(month) > 0 order by avg(month), product
+        """.stripMargin),
+      Seq(
+        ("a", 51),
+        ("b", 51),
+        ("b", 51),
+        ("c", 51),
+        ("c", 51),
+        ("d", 51)
+      ).map(i => Row(i._1, i._2)))
+  }
+
   test("window function: multiple window expressions in a single expression") {
     val nums = sparkContext.parallelize(1 to 10).map(x => (x, x % 2)).toDF("x", "y")
     nums.registerTempTable("nums")
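A note for readers checking the expected rows in the ignored test: the filter product > 3 keeps all six rows, and group by area, product yields six distinct groups, so sum(product) over () is the grand total 10 + 6 + 7 + 8 + 9 + 11 = 51 on every output row, which is why the second column is constant.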