diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index dadea6b54a946..406aaefb3646b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.analysis +import scala.annotation.tailrec import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.AnalysisException @@ -440,7 +441,7 @@ class Analyzer( } // When resolve `SortOrder`s in Sort based on child, don't report errors as - // we still have chance to resolve it based on grandchild + // we still have chance to resolve it based on its descendants case s @ Sort(ordering, global, child) if child.resolved && !s.resolved => val newOrdering = resolveSortOrders(ordering, child, throws = false) Sort(newOrdering, global, child) @@ -521,38 +522,96 @@ class Analyzer( */ object ResolveSortReferences extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { - case s @ Sort(ordering, global, p @ Project(projectList, child)) - if !s.resolved && p.resolved => - val (newOrdering, missing) = resolveAndFindMissing(ordering, p, child) + // Skip sort with aggregate. This will be handled in ResolveAggregateFunctions + case sa @ Sort(_, _, child: Aggregate) => sa - // If this rule was not a no-op, return the transformed plan, otherwise return the original. - if (missing.nonEmpty) { - // Add missing attributes and then project them away after the sort. - Project(p.output, - Sort(newOrdering, global, - Project(projectList ++ missing, child))) - } else { - logDebug(s"Failed to find $missing in ${p.output.mkString(", ")}") + case s @ Sort(_, _, child) if !s.resolved && child.resolved => + val (newOrdering, missingResolvableAttrs) = collectResolvableMissingAttrs(s.order, child) + + if (missingResolvableAttrs.isEmpty) { + val unresolvableAttrs = s.order.filterNot(_.resolved) + logDebug(s"Failed to find $unresolvableAttrs in ${child.output.mkString(", ")}") s // Nothing we can do here. Return original plan. + } else { + // Add the missing attributes into projectList of Project/Window or + // aggregateExpressions of Aggregate, if they are in the inputSet + // but not in the outputSet of the plan. + val newChild = child transformUp { + case p: Project => + p.copy(projectList = p.projectList ++ + missingResolvableAttrs.filter((p.inputSet -- p.outputSet).contains)) + case w: Window => + w.copy(projectList = w.projectList ++ + missingResolvableAttrs.filter((w.inputSet -- w.outputSet).contains)) + case a: Aggregate => + val resolvableAttrs = missingResolvableAttrs.filter(a.groupingExpressions.contains) + val notResolvedAttrs = resolvableAttrs.filterNot(a.aggregateExpressions.contains) + val newAggregateExpressions = a.aggregateExpressions ++ notResolvedAttrs + a.copy(aggregateExpressions = newAggregateExpressions) + case o => o + } + + // Add missing attributes and then project them away after the sort. + Project(child.output, + Sort(newOrdering, s.global, newChild)) } } /** - * Given a child and a grandchild that are present beneath a sort operator, try to resolve - * the sort ordering and returns it with a list of attributes that are missing from the - * child but are present in the grandchild. + * Traverse the tree until resolving the sorting attributes + * Return all the resolvable missing sorting attributes + */ + @tailrec + private def collectResolvableMissingAttrs( + ordering: Seq[SortOrder], + plan: LogicalPlan): (Seq[SortOrder], Seq[Attribute]) = { + plan match { + // Only Windows and Project have projectList-like attribute. + case un: UnaryNode if un.isInstanceOf[Project] || un.isInstanceOf[Window] => + val (newOrdering, missingAttrs) = resolveAndFindMissing(ordering, un, un.child) + // If missingAttrs is non empty, that means we got it and return it; + // Otherwise, continue to traverse the tree. + if (missingAttrs.nonEmpty) { + (newOrdering, missingAttrs) + } else { + collectResolvableMissingAttrs(ordering, un.child) + } + case a: Aggregate => + val (newOrdering, missingAttrs) = resolveAndFindMissing(ordering, a, a.child) + // For Aggregate, all the order by columns must be specified in group by clauses + if (missingAttrs.nonEmpty && + missingAttrs.forall(ar => a.groupingExpressions.exists(_.semanticEquals(ar)))) { + (newOrdering, missingAttrs) + } else { + // If missingAttrs is empty, we are unable to resolve any unresolved missing attributes + (Seq.empty[SortOrder], Seq.empty[Attribute]) + } + // Jump over the following UnaryNode types + // The output of these types is the same as their child's output + case _: Distinct | + _: Filter | + _: RepartitionByExpression => + collectResolvableMissingAttrs(ordering, plan.asInstanceOf[UnaryNode].child) + // If hitting the other unsupported operators, we are unable to resolve it. + case other => (Seq.empty[SortOrder], Seq.empty[Attribute]) + } + } + + /** + * Try to resolve the sort ordering and returns it with a list of attributes that are missing + * from the plan but are present in the child. */ - def resolveAndFindMissing( + private def resolveAndFindMissing( ordering: Seq[SortOrder], - child: LogicalPlan, - grandchild: LogicalPlan): (Seq[SortOrder], Seq[Attribute]) = { - val newOrdering = resolveSortOrders(ordering, grandchild, throws = true) + plan: LogicalPlan, + child: LogicalPlan): (Seq[SortOrder], Seq[Attribute]) = { + val newOrdering = resolveSortOrders(ordering, child, throws = false) // Construct a set that contains all of the attributes that we need to evaluate the // ordering. val requiredAttributes = AttributeSet(newOrdering).filter(_.resolved) // Figure out which ones are missing from the projection, so that we can add them and // remove them after the sort. - val missingInProject = requiredAttributes -- child.output + val missingInProject = requiredAttributes -- plan.outputSet // It is important to return the new SortOrders here, instead of waiting for the standard // resolving process as adding attributes to the project below can actually introduce // ambiguity that was not present before. @@ -707,7 +766,7 @@ class Analyzer( } } - protected def containsAggregate(condition: Expression): Boolean = { + def containsAggregate(condition: Expression): Boolean = { condition.find(_.isInstanceOf[AggregateExpression]).isDefined } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 975cd87d090e4..d4e544508f3f5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -76,6 +76,89 @@ class AnalysisSuite extends AnalysisTest { caseSensitive = false) } + test("resolve sort references - filter/limit") { + val a = testRelation2.output(0) + val b = testRelation2.output(1) + val c = testRelation2.output(2) + + // Case 1: one missing attribute is in the leaf node and another is in the unary node + val plan1 = testRelation2 + .where('a > "str").select('a, 'b) + .where('b > "str").select('a) + .sortBy('b.asc, 'c.desc) + val expected1 = testRelation2 + .where(a > "str").select(a, b, c) + .where(b > "str").select(a, b, c) + .sortBy(b.asc, c.desc) + .select(a, b).select(a) + checkAnalysis(plan1, expected1) + + // Case 2: all the missing attributes are in the leaf node + val plan2 = testRelation2 + .where('a > "str").select('a) + .where('a > "str").select('a) + .sortBy('b.asc, 'c.desc) + val expected2 = testRelation2 + .where(a > "str").select(a, b, c) + .where(a > "str").select(a, b, c) + .sortBy(b.asc, c.desc) + .select(a) + checkAnalysis(plan2, expected2) + } + + test("resolve sort references - join") { + val a = testRelation2.output(0) + val b = testRelation2.output(1) + val c = testRelation2.output(2) + val h = testRelation3.output(3) + + // Case: join itself can resolve all the missing attributes + val plan = testRelation2.join(testRelation3) + .where('a > "str").select('a, 'b) + .sortBy('c.desc, 'h.asc) + val expected = testRelation2.join(testRelation3) + .where(a > "str").select(a, b, c, h) + .sortBy(c.desc, h.asc) + .select(a, b) + checkAnalysis(plan, expected) + } + + test("resolve sort references - aggregate") { + val a = testRelation2.output(0) + val b = testRelation2.output(1) + val c = testRelation2.output(2) + val alias_a3 = count(a).as("a3") + val alias_b = b.as("aggOrder") + + // Case 1: when the child of Sort is not Aggregate, + // the sort reference is handled by the rule ResolveSortReferences + val plan1 = testRelation2 + .groupBy('a, 'c, 'b)('a, 'c, count('a).as("a3")) + .select('a, 'c, 'a3) + .orderBy('b.asc) + + val expected1 = testRelation2 + .groupBy(a, c, b)(a, c, alias_a3, b) + .select(a, c, alias_a3.toAttribute, b) + .orderBy(b.asc) + .select(a, c, alias_a3.toAttribute) + + checkAnalysis(plan1, expected1) + + // Case 2: when the child of Sort is Aggregate, + // the sort reference is handled by the rule ResolveAggregateFunctions + val plan2 = testRelation2 + .groupBy('a, 'c, 'b)('a, 'c, count('a).as("a3")) + .orderBy('b.asc) + + val expected2 = testRelation2 + .groupBy(a, c, b)(a, c, alias_a3, alias_b) + .orderBy(alias_b.toAttribute.asc) + .select(a, c, alias_a3.toAttribute) + + checkAnalysis(plan2, expected2) + } + test("resolve relations") { assertAnalysisError( UnresolvedRelation(TableIdentifier("tAbLe"), None), Seq("Table not found: tAbLe")) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TestRelations.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TestRelations.scala index bc07b609a3413..3741a6ba95a86 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TestRelations.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TestRelations.scala @@ -31,6 +31,12 @@ object TestRelations { AttributeReference("d", DecimalType(10, 2))(), AttributeReference("e", ShortType)()) + val testRelation3 = LocalRelation( + AttributeReference("e", ShortType)(), + AttributeReference("f", StringType)(), + AttributeReference("g", DoubleType)(), + AttributeReference("h", DecimalType(10, 2))()) + val nestedRelation = LocalRelation( AttributeReference("top", StructType( StructField("duplicateField", StringType) :: diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala index c17be8ace9287..a5e5f156423cc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala @@ -42,6 +42,22 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext { Row(1, 2, "1", "2") :: Row(2, 3, "2", "3") :: Row(3, 4, "3", "4") :: Nil) } + test("join - sorted columns not in join's outputSet") { + val df = Seq((1, 2, "1"), (3, 4, "3")).toDF("int", "int2", "str_sort").as('df1) + val df2 = Seq((1, 3, "1"), (5, 6, "5")).toDF("int", "int2", "str").as('df2) + val df3 = Seq((1, 3, "1"), (5, 6, "5")).toDF("int", "int2", "str").as('df3) + + checkAnswer( + df.join(df2, $"df1.int" === $"df2.int", "outer").select($"df1.int", $"df2.int2") + .orderBy('str_sort.asc, 'str.asc), + Row(null, 6) :: Row(1, 3) :: Row(3, null) :: Nil) + + checkAnswer( + df2.join(df3, $"df2.int" === $"df3.int", "inner") + .select($"df2.int", $"df3.int").orderBy($"df2.str".desc), + Row(5, 5) :: Row(1, 1) :: Nil) + } + test("join - join using multiple columns and specifying join type") { val df = Seq((1, 2, "1"), (3, 4, "3")).toDF("int", "int2", "str") val df2 = Seq((1, 3, "1"), (5, 6, "5")).toDF("int", "int2", "str") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index d6c140dfea9ed..771f392bbb97d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -898,6 +898,12 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(expected === actual) } + test("Sorting columns are not in Filter and Project") { + checkAnswer( + upperCaseData.filter('N > 1).select('N).filter('N < 6).orderBy('L.asc), + Row(2) :: Row(3) :: Row(4) :: Row(5) :: Nil) + } + test("SPARK-9323: DataFrame.orderBy should support nested column name") { val df = sqlContext.read.json(sparkContext.makeRDD( """{"a": {"b": 1}}""" :: Nil)) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 61d5aa7ae6b31..bd7ba22ec2471 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -775,7 +775,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { """.stripMargin), (2 to 6).map(i => Row(i))) } - test("window function: udaf with aggregate expressin") { + test("window function: udaf with aggregate expression") { val data = Seq( WindowData(1, "a", 5), WindowData(2, "a", 6), @@ -966,6 +966,88 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { ).map(i => Row(i._1, i._2, i._3, i._4))) } + test("window function: Sorting columns are not in Project") { + val data = Seq( + WindowData(1, "d", 10), + WindowData(2, "a", 6), + WindowData(3, "b", 7), + WindowData(4, "b", 8), + WindowData(5, "c", 9), + WindowData(6, "c", 11) + ) + sparkContext.parallelize(data).toDF().registerTempTable("windowData") + + checkAnswer( + sql("select month, product, sum(product + 1) over() from windowData order by area"), + Seq( + (2, 6, 57), + (3, 7, 57), + (4, 8, 57), + (5, 9, 57), + (6, 11, 57), + (1, 10, 57) + ).map(i => Row(i._1, i._2, i._3))) + + checkAnswer( + sql( + """ + |select area, rank() over (partition by area order by tmp.month) + tmp.tmp1 as c1 + |from (select month, area, product as p, 1 as tmp1 from windowData) tmp order by p + """.stripMargin), + Seq( + ("a", 2), + ("b", 2), + ("b", 3), + ("c", 2), + ("d", 2), + ("c", 3) + ).map(i => Row(i._1, i._2))) + + checkAnswer( + sql( + """ + |select area, rank() over (partition by area order by month) as c1 + |from windowData group by product, area, month order by product, area + """.stripMargin), + Seq( + ("a", 1), + ("b", 1), + ("b", 2), + ("c", 1), + ("d", 1), + ("c", 2) + ).map(i => Row(i._1, i._2))) + } + + // todo: fix this test case by reimplementing the function ResolveAggregateFunctions + ignore("window function: Pushing aggregate Expressions in Sort to Aggregate") { + val data = Seq( + WindowData(1, "d", 10), + WindowData(2, "a", 6), + WindowData(3, "b", 7), + WindowData(4, "b", 8), + WindowData(5, "c", 9), + WindowData(6, "c", 11) + ) + sparkContext.parallelize(data).toDF().registerTempTable("windowData") + + checkAnswer( + sql( + """ + |select area, sum(product) over () as c from windowData + |where product > 3 group by area, product + |having avg(month) > 0 order by avg(month), product + """.stripMargin), + Seq( + ("a", 51), + ("b", 51), + ("b", 51), + ("c", 51), + ("c", 51), + ("d", 51) + ).map(i => Row(i._1, i._2))) + } + test("window function: multiple window expressions in a single expression") { val nums = sparkContext.parallelize(1 to 10).map(x => (x, x % 2)).toDF("x", "y") nums.registerTempTable("nums")