[SPARK-12705] [SPARK-10777] [SQL] Analyzer Rule ResolveSortReferences #10678

Closed · wants to merge 19 commits
@@ -17,6 +17,7 @@

package org.apache.spark.sql.catalyst.analysis

import scala.annotation.tailrec
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.AnalysisException
@@ -440,7 +441,7 @@ class Analyzer(
}

// When resolving `SortOrder`s in Sort based on the child, don't report errors, as
// we still have a chance to resolve it based on grandchild
// we still have a chance to resolve it based on its descendants
case s @ Sort(ordering, global, child) if child.resolved && !s.resolved =>
val newOrdering = resolveSortOrders(ordering, child, throws = false)
Sort(newOrdering, global, child)
@@ -521,38 +522,96 @@
*/
object ResolveSortReferences extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
case s @ Sort(ordering, global, p @ Project(projectList, child))
if !s.resolved && p.resolved =>
val (newOrdering, missing) = resolveAndFindMissing(ordering, p, child)
// Skip sort with aggregate. This will be handled in ResolveAggregateFunctions
case sa @ Sort(_, _, child: Aggregate) => sa

// If this rule was not a no-op, return the transformed plan, otherwise return the original.
if (missing.nonEmpty) {
// Add missing attributes and then project them away after the sort.
Project(p.output,
Sort(newOrdering, global,
Project(projectList ++ missing, child)))
} else {
logDebug(s"Failed to find $missing in ${p.output.mkString(", ")}")
case s @ Sort(_, _, child) if !s.resolved && child.resolved =>
val (newOrdering, missingResolvableAttrs) = collectResolvableMissingAttrs(s.order, child)

if (missingResolvableAttrs.isEmpty) {
val unresolvableAttrs = s.order.filterNot(_.resolved)
logDebug(s"Failed to find $unresolvableAttrs in ${child.output.mkString(", ")}")
s // Nothing we can do here. Return original plan.
} else {
// Add the missing attributes into projectList of Project/Window or
// aggregateExpressions of Aggregate, if they are in the inputSet
Contributor comment: nit: too many spaces

// but not in the outputSet of the plan.
val newChild = child transformUp {
case p: Project =>
p.copy(projectList = p.projectList ++
missingResolvableAttrs.filter((p.inputSet -- p.outputSet).contains))
case w: Window =>
w.copy(projectList = w.projectList ++
missingResolvableAttrs.filter((w.inputSet -- w.outputSet).contains))
case a: Aggregate =>
Contributor comment: This case can never happen right?

Member Author: In the following query, we can trigger this case. The query is based on one of the failing TPCDS queries, so we added it as a test case. The column product is in the group-by clause but does not appear in aggregateExpressions; thus, we hit this case if we want to sort the results by product.
select area, rank() over (partition by area order by month) as c1 from windowData group by product, area, month order by product, area

If we remove this case, we will get this error:

Failed to analyze query: org.apache.spark.sql.AnalysisException: resolved attribute(s) product#2 missing from area#1,c1#39 in operator !Sort [product#2 ASC,area#1 ASC], true;
Project [area#1,c1#48]
+- !Sort [product#2 ASC,area#1 ASC], true
   +- Project [area#1,c1#48]
      +- Project [area#1,month#0,c1#48,c1#48]
         +- Window [area#1,month#0], [rank(month#0) windowspecdefinition(area#1,month#0 ASC,ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS c1#48], [area#1], [month#0 ASC]
            +- Aggregate [product#2,area#1,month#0], [area#1,month#0]
               +- Subquery windowdata
                  +- LogicalRDD [month#0,area#1,product#2], MapPartitionsRDD[1] at apply at Transformer.scala:22

val resolvableAttrs = missingResolvableAttrs.filter(a.groupingExpressions.contains)
val notResolvedAttrs = resolvableAttrs.filterNot(a.aggregateExpressions.contains)
val newAggregateExpressions = a.aggregateExpressions ++ notResolvedAttrs
a.copy(aggregateExpressions = newAggregateExpressions)
case o => o
}

// Add missing attributes and then project them away after the sort.
Project(child.output,
Sort(newOrdering, s.global, newChild))
}
}
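
As a minimal sketch of when the Aggregate branch fires, take the query from the review thread above (it is also added as a test in SQLQuerySuite below; this assumes the windowData temp table registered there). The column product appears in the group-by clause but not in the select list, so the branch appends it to aggregateExpressions, the Sort resolves against it, and the enclosing Project removes it again:

// Sketch: assumes windowData is registered as a temp table, as in
// the SQLQuerySuite tests added by this PR.
sql(
  """
    |select area, rank() over (partition by area order by month) as c1
    |from windowData group by product, area, month order by product, area
  """.stripMargin)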

/**
* Given a child and a grandchild that are present beneath a sort operator, try to resolve
* the sort ordering and return it with a list of attributes that are missing from the
* child but are present in the grandchild.
* Traverse the tree until the sorting attributes are resolved.
* Return all the resolvable missing sorting attributes.
*/
@tailrec
private def collectResolvableMissingAttrs(
ordering: Seq[SortOrder],
plan: LogicalPlan): (Seq[SortOrder], Seq[Attribute]) = {
plan match {
// Only Window and Project have a projectList-like attribute.
case un: UnaryNode if un.isInstanceOf[Project] || un.isInstanceOf[Window] =>
val (newOrdering, missingAttrs) = resolveAndFindMissing(ordering, un, un.child)
// If missingAttrs is non-empty, we have found the resolvable missing attributes; return them.
// Otherwise, continue to traverse the tree.
if (missingAttrs.nonEmpty) {
(newOrdering, missingAttrs)
} else {
collectResolvableMissingAttrs(ordering, un.child)
}
case a: Aggregate =>
val (newOrdering, missingAttrs) = resolveAndFindMissing(ordering, a, a.child)
// For Aggregate, all the order-by columns must be specified in the group-by clause
if (missingAttrs.nonEmpty &&
missingAttrs.forall(ar => a.groupingExpressions.exists(_.semanticEquals(ar)))) {
(newOrdering, missingAttrs)
} else {
// Otherwise, we are unable to resolve these missing attributes here
(Seq.empty[SortOrder], Seq.empty[Attribute])
}
// Skip over the following UnaryNode types,
// whose output is the same as their child's output
case _: Distinct |
_: Filter |
_: RepartitionByExpression =>
collectResolvableMissingAttrs(ordering, plan.asInstanceOf[UnaryNode].child)
// When hitting any other unsupported operator, we are unable to resolve the attributes.
case other => (Seq.empty[SortOrder], Seq.empty[Attribute])
}
}
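
A minimal sketch of the traversal's contract, using the catalyst DSL exactly as the AnalysisSuite tests below do (testRelation2 has columns a, b, c):

// Sketch: testRelation2 comes from TestRelations in the test sources below.
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._

// 'b is sorted on but dropped by the Project, so the traversal resolves it
// one level down and reports it as a resolvable missing attribute.
val before = testRelation2.select('a).sortBy('b.asc)
// The rule then rewrites the plan to carry b through the Project and
// project it away again above the Sort, equivalent to:
val after = testRelation2.select('a, 'b).sortBy('b.asc).select('a)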

/**
* Try to resolve the sort ordering and return it with a list of attributes that are missing
* from the plan but are present in the child.
*/
def resolveAndFindMissing(
private def resolveAndFindMissing(
ordering: Seq[SortOrder],
child: LogicalPlan,
grandchild: LogicalPlan): (Seq[SortOrder], Seq[Attribute]) = {
val newOrdering = resolveSortOrders(ordering, grandchild, throws = true)
plan: LogicalPlan,
child: LogicalPlan): (Seq[SortOrder], Seq[Attribute]) = {
val newOrdering = resolveSortOrders(ordering, child, throws = false)
// Construct a set that contains all of the attributes that we need to evaluate the
// ordering.
val requiredAttributes = AttributeSet(newOrdering).filter(_.resolved)
// Figure out which ones are missing from the projection, so that we can add them and
// remove them after the sort.
val missingInProject = requiredAttributes -- child.output
val missingInProject = requiredAttributes -- plan.outputSet
// It is important to return the new SortOrders here, instead of waiting for the standard
// resolving process as adding attributes to the project below can actually introduce
// ambiguity that was not present before.
@@ -707,7 +766,7 @@ class Analyzer(
}
}

protected def containsAggregate(condition: Expression): Boolean = {
def containsAggregate(condition: Expression): Boolean = {
condition.find(_.isInstanceOf[AggregateExpression]).isDefined
}
}
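
Stepping back from the diff: the user-visible effect of the rule is that a query may now sort by a column its projection dropped. A minimal sketch mirroring the DataFrameSuite test added below (upperCaseData has columns N and L):

// Sketch: upperCaseData is the test table used by DataFrameSuite below.
// 'L is never selected, yet the query analyzes: the rule widens the inner
// Project to carry L, sorts, then projects L away again.
upperCaseData.filter('N > 1).select('N).filter('N < 6).orderBy('L.asc)
// Analyzed plan shape (schematic):
//   Project [N]
//   +- Sort [L ASC]
//      +- Filter (N < 6)
//         +- Project [N, L]
//            +- Filter (N > 1)
//               +- upperCaseData
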
@@ -76,6 +76,89 @@ class AnalysisSuite extends AnalysisTest {
caseSensitive = false)
}

test("resolve sort references - filter/limit") {
val a = testRelation2.output(0)
val b = testRelation2.output(1)
val c = testRelation2.output(2)

// Case 1: one missing attribute is in the leaf node and another is in the unary node
val plan1 = testRelation2
.where('a > "str").select('a, 'b)
.where('b > "str").select('a)
.sortBy('b.asc, 'c.desc)
val expected1 = testRelation2
.where(a > "str").select(a, b, c)
.where(b > "str").select(a, b, c)
.sortBy(b.asc, c.desc)
.select(a, b).select(a)
checkAnalysis(plan1, expected1)

// Case 2: all the missing attributes are in the leaf node
val plan2 = testRelation2
.where('a > "str").select('a)
.where('a > "str").select('a)
.sortBy('b.asc, 'c.desc)
val expected2 = testRelation2
.where(a > "str").select(a, b, c)
.where(a > "str").select(a, b, c)
.sortBy(b.asc, c.desc)
.select(a)
checkAnalysis(plan2, expected2)
}

test("resolve sort references - join") {
val a = testRelation2.output(0)
val b = testRelation2.output(1)
val c = testRelation2.output(2)
val h = testRelation3.output(3)

// Case: join itself can resolve all the missing attributes
val plan = testRelation2.join(testRelation3)
.where('a > "str").select('a, 'b)
.sortBy('c.desc, 'h.asc)
val expected = testRelation2.join(testRelation3)
.where(a > "str").select(a, b, c, h)
.sortBy(c.desc, h.asc)
.select(a, b)
checkAnalysis(plan, expected)
}

test("resolve sort references - aggregate") {
val a = testRelation2.output(0)
val b = testRelation2.output(1)
val c = testRelation2.output(2)
val alias_a3 = count(a).as("a3")
val alias_b = b.as("aggOrder")

// Case 1: when the child of Sort is not Aggregate,
// the sort reference is handled by the rule ResolveSortReferences
val plan1 = testRelation2
.groupBy('a, 'c, 'b)('a, 'c, count('a).as("a3"))
.select('a, 'c, 'a3)
.orderBy('b.asc)

val expected1 = testRelation2
.groupBy(a, c, b)(a, c, alias_a3, b)
.select(a, c, alias_a3.toAttribute, b)
.orderBy(b.asc)
.select(a, c, alias_a3.toAttribute)

checkAnalysis(plan1, expected1)

// Case 2: when the child of Sort is Aggregate,
// the sort reference is handled by the rule ResolveAggregateFunctions
val plan2 = testRelation2
.groupBy('a, 'c, 'b)('a, 'c, count('a).as("a3"))
.orderBy('b.asc)

val expected2 = testRelation2
.groupBy(a, c, b)(a, c, alias_a3, alias_b)
.orderBy(alias_b.toAttribute.asc)
.select(a, c, alias_a3.toAttribute)

checkAnalysis(plan2, expected2)
}

test("resolve relations") {
assertAnalysisError(
UnresolvedRelation(TableIdentifier("tAbLe"), None), Seq("Table not found: tAbLe"))
@@ -31,6 +31,12 @@ object TestRelations {
AttributeReference("d", DecimalType(10, 2))(),
AttributeReference("e", ShortType)())

val testRelation3 = LocalRelation(
AttributeReference("e", ShortType)(),
AttributeReference("f", StringType)(),
AttributeReference("g", DoubleType)(),
AttributeReference("h", DecimalType(10, 2))())

val nestedRelation = LocalRelation(
AttributeReference("top", StructType(
StructField("duplicateField", StringType) ::
@@ -42,6 +42,22 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext {
Row(1, 2, "1", "2") :: Row(2, 3, "2", "3") :: Row(3, 4, "3", "4") :: Nil)
}

test("join - sorted columns not in join's outputSet") {
val df = Seq((1, 2, "1"), (3, 4, "3")).toDF("int", "int2", "str_sort").as('df1)
val df2 = Seq((1, 3, "1"), (5, 6, "5")).toDF("int", "int2", "str").as('df2)
val df3 = Seq((1, 3, "1"), (5, 6, "5")).toDF("int", "int2", "str").as('df3)

checkAnswer(
df.join(df2, $"df1.int" === $"df2.int", "outer").select($"df1.int", $"df2.int2")
.orderBy('str_sort.asc, 'str.asc),
Row(null, 6) :: Row(1, 3) :: Row(3, null) :: Nil)

checkAnswer(
df2.join(df3, $"df2.int" === $"df3.int", "inner")
.select($"df2.int", $"df3.int").orderBy($"df2.str".desc),
Row(5, 5) :: Row(1, 1) :: Nil)
}

test("join - join using multiple columns and specifying join type") {
val df = Seq((1, 2, "1"), (3, 4, "3")).toDF("int", "int2", "str")
val df2 = Seq((1, 3, "1"), (5, 6, "5")).toDF("int", "int2", "str")
@@ -898,6 +898,12 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
assert(expected === actual)
}

test("Sorting columns are not in Filter and Project") {
checkAnswer(
upperCaseData.filter('N > 1).select('N).filter('N < 6).orderBy('L.asc),
Row(2) :: Row(3) :: Row(4) :: Row(5) :: Nil)
}

test("SPARK-9323: DataFrame.orderBy should support nested column name") {
val df = sqlContext.read.json(sparkContext.makeRDD(
"""{"a": {"b": 1}}""" :: Nil))
@@ -775,7 +775,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
""".stripMargin), (2 to 6).map(i => Row(i)))
}

test("window function: udaf with aggregate expressin") {
test("window function: udaf with aggregate expression") {
val data = Seq(
WindowData(1, "a", 5),
WindowData(2, "a", 6),
@@ -966,6 +966,88 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
).map(i => Row(i._1, i._2, i._3, i._4)))
}

test("window function: Sorting columns are not in Project") {
val data = Seq(
WindowData(1, "d", 10),
WindowData(2, "a", 6),
WindowData(3, "b", 7),
WindowData(4, "b", 8),
WindowData(5, "c", 9),
WindowData(6, "c", 11)
)
sparkContext.parallelize(data).toDF().registerTempTable("windowData")

checkAnswer(
sql("select month, product, sum(product + 1) over() from windowData order by area"),
Seq(
(2, 6, 57),
(3, 7, 57),
(4, 8, 57),
(5, 9, 57),
(6, 11, 57),
(1, 10, 57)
).map(i => Row(i._1, i._2, i._3)))

checkAnswer(
sql(
"""
|select area, rank() over (partition by area order by tmp.month) + tmp.tmp1 as c1
|from (select month, area, product as p, 1 as tmp1 from windowData) tmp order by p
""".stripMargin),
Seq(
("a", 2),
("b", 2),
("b", 3),
("c", 2),
("d", 2),
("c", 3)
).map(i => Row(i._1, i._2)))

checkAnswer(
sql(
"""
|select area, rank() over (partition by area order by month) as c1
|from windowData group by product, area, month order by product, area
""".stripMargin),
Seq(
("a", 1),
("b", 1),
("b", 2),
("c", 1),
("d", 1),
("c", 2)
).map(i => Row(i._1, i._2)))
}

// TODO: fix this test case by reimplementing the rule ResolveAggregateFunctions
ignore("window function: Pushing aggregate Expressions in Sort to Aggregate") {
val data = Seq(
WindowData(1, "d", 10),
WindowData(2, "a", 6),
WindowData(3, "b", 7),
WindowData(4, "b", 8),
WindowData(5, "c", 9),
WindowData(6, "c", 11)
)
sparkContext.parallelize(data).toDF().registerTempTable("windowData")

checkAnswer(
sql(
"""
|select area, sum(product) over () as c from windowData
|where product > 3 group by area, product
|having avg(month) > 0 order by avg(month), product
""".stripMargin),
Seq(
("a", 51),
("b", 51),
("b", 51),
("c", 51),
("c", 51),
("d", 51)
).map(i => Row(i._1, i._2)))
}

test("window function: multiple window expressions in a single expression") {
val nums = sparkContext.parallelize(1 to 10).map(x => (x, x % 2)).toDF("x", "y")
nums.registerTempTable("nums")