[SPARK-12705][SPARK-10777][SQL] Analyzer Rule ResolveSortReferences
JIRA: https://issues.apache.org/jira/browse/SPARK-12705

**Scope:**
This PR is a general fix for sort reference resolution when the child's `outputSet` does not contain the order-by attributes (called *missing attributes* below); see the example after this list:
  - `UnaryNode` support is limited to `Project`, `Window`, `Aggregate`, `Distinct`, `Filter`, and `RepartitionByExpression`.
  - We will not try to resolve missing references inside a subquery unless the subquery's `outputSet` contains them.
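
For example, the following minimal sketch (illustrative data; it assumes a Spark 1.6-style `sqlContext` is in scope) sorts on a column that an intermediate `Project` has already dropped:

```scala
import sqlContext.implicits._

val df = Seq((1, "x"), (2, "y")).toDF("a", "b")

// The Sort references 'b, but 'b was dropped by the inner select, and a
// Filter sits between the Sort and that Project. With this rule, 'b is
// re-added to the projectList, the rows are sorted, and 'b is projected
// away again on top.
df.filter('a > 0).select('a).filter('a < 3).orderBy('b.asc)
```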

**General Reference Resolution Rules:**
  - Nodes of type `Distinct`, `Filter`, and `RepartitionByExpression` are simply skipped; no missing attributes need to be added to them, because their `outputSet` is determined by their `inputSet`, which is the `outputSet` of their child.
  - Group-by expressions in `Aggregate`: missing order-by attributes must not be added to the group-by expressions, since that would change the query result; RDBMSs disallow this for the same reason.
  - Aggregate expressions in `Aggregate`: if the group-by expressions contain the missing attributes but the aggregate expressions do not, add them to the aggregate expressions. This resolves the `AnalysisException`s thrown by the three TPC-DS queries (see the sketch after this list).
  - `Project` and `Window` are special: we just add the missing attributes to their `projectList`.
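
As an illustration of the `Aggregate` case, here is a sketch in the Catalyst test DSL, adapted from the new `AnalysisSuite` test added below (`testRelation2` comes from `TestRelations`):

```scala
import org.apache.spark.sql.catalyst.analysis.TestRelations.testRelation2
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._

// 'b is a grouping column, but neither the aggregate expressions nor the
// projectList output it, so the Sort on 'b is initially unresolvable.
val plan = testRelation2
  .groupBy('a, 'c, 'b)('a, 'c, count('a).as("a3"))
  .select('a, 'c, 'a3)
  .orderBy('b.asc)

// The analyzer resolves it as if it had been written as:
//   testRelation2
//     .groupBy('a, 'c, 'b)('a, 'c, count('a).as("a3"), 'b)  // 'b added
//     .select('a, 'c, 'a3, 'b)                              // 'b added
//     .orderBy('b.asc)
//     .select('a, 'c, 'a3)                                  // original output restored
```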

**Implementation:**
  1. Traverse the whole tree top-down (pre-order) to find all the resolvable missing order-by attributes.
  2. Traverse the whole tree bottom-up (post-order) and add the found missing order-by attributes to a node if its `inputSet` contains them.
  3. If the missing order-by attributes originate from different nodes, each pass resolves only the missing attributes that come from the same node. The plan shapes in the sketch below illustrate the result.
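
Concretely, for the first new `AnalysisSuite` test below, where the missing attributes must be carried through two `Project`/`Filter` pairs (same Catalyst test DSL and imports as the previous sketch):

```scala
import org.apache.spark.sql.catalyst.analysis.TestRelations.testRelation2
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._

// Input: 'b and 'c are sorted on, but both Projects have dropped them.
val plan1 = testRelation2
  .where('a > "str").select('a, 'b)
  .where('b > "str").select('a)
  .sortBy('b.asc, 'c.desc)

// Analyzed result: each Project re-exposes 'b and 'c, the Sort runs above
// them, and the original single-column output is restored at the top:
//   testRelation2
//     .where('a > "str").select('a, 'b, 'c)
//     .where('b > "str").select('a, 'b, 'c)
//     .sortBy('b.asc, 'c.desc)
//     .select('a, 'b).select('a)
```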

**Risk:**
Low. This rule is triggered iff `!s.resolved && child.resolved` is true, so very few cases are affected.

Author: gatorsmile <[email protected]>

Closes #10678 from gatorsmile/sortWindows.
gatorsmile authored and marmbrus committed Feb 1, 2016
1 parent 33c8a49 commit 8f26eb5
Showing 6 changed files with 274 additions and 22 deletions.
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -17,6 +17,7 @@

package org.apache.spark.sql.catalyst.analysis

import scala.annotation.tailrec
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.AnalysisException
@@ -452,7 +453,7 @@ class Analyzer(
i.copy(right = dedupRight(left, right))

      // When resolving `SortOrder`s in Sort based on child, don't report errors as
-     // we still have chance to resolve it based on grandchild
+     // we still have a chance to resolve it based on its descendants
case s @ Sort(ordering, global, child) if child.resolved && !s.resolved =>
val newOrdering = resolveSortOrders(ordering, child, throws = false)
Sort(newOrdering, global, child)
@@ -533,38 +534,96 @@
*/
object ResolveSortReferences extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
-       case s @ Sort(ordering, global, p @ Project(projectList, child))
-           if !s.resolved && p.resolved =>
-         val (newOrdering, missing) = resolveAndFindMissing(ordering, p, child)
-
-         // If this rule was not a no-op, return the transformed plan, otherwise return the original.
-         if (missing.nonEmpty) {
-           // Add missing attributes and then project them away after the sort.
-           Project(p.output,
-             Sort(newOrdering, global,
-               Project(projectList ++ missing, child)))
-         } else {
-           logDebug(s"Failed to find $missing in ${p.output.mkString(", ")}")
-           s // Nothing we can do here. Return original plan.
-         }
+       // Skip sort with aggregate. This will be handled in ResolveAggregateFunctions
+       case sa @ Sort(_, _, child: Aggregate) => sa
+
+       case s @ Sort(_, _, child) if !s.resolved && child.resolved =>
+         val (newOrdering, missingResolvableAttrs) = collectResolvableMissingAttrs(s.order, child)
+
+         if (missingResolvableAttrs.isEmpty) {
+           val unresolvableAttrs = s.order.filterNot(_.resolved)
+           logDebug(s"Failed to find $unresolvableAttrs in ${child.output.mkString(", ")}")
+           s // Nothing we can do here. Return original plan.
+         } else {
+           // Add the missing attributes into projectList of Project/Window or
+           // aggregateExpressions of Aggregate, if they are in the inputSet
+           // but not in the outputSet of the plan.
+           val newChild = child transformUp {
+             case p: Project =>
+               p.copy(projectList = p.projectList ++
+                 missingResolvableAttrs.filter((p.inputSet -- p.outputSet).contains))
+             case w: Window =>
+               w.copy(projectList = w.projectList ++
+                 missingResolvableAttrs.filter((w.inputSet -- w.outputSet).contains))
+             case a: Aggregate =>
+               val resolvableAttrs = missingResolvableAttrs.filter(a.groupingExpressions.contains)
+               val notResolvedAttrs = resolvableAttrs.filterNot(a.aggregateExpressions.contains)
+               val newAggregateExpressions = a.aggregateExpressions ++ notResolvedAttrs
+               a.copy(aggregateExpressions = newAggregateExpressions)
+             case o => o
+           }
+
+           // Add missing attributes and then project them away after the sort.
+           Project(child.output,
+             Sort(newOrdering, s.global, newChild))
+         }
}

/**
-      * Given a child and a grandchild that are present beneath a sort operator, try to resolve
-      * the sort ordering and returns it with a list of attributes that are missing from the
-      * child but are present in the grandchild.
+      * Traverse the tree until the sorting attributes are resolved; return all the
+      * resolvable missing sorting attributes.
*/
@tailrec
private def collectResolvableMissingAttrs(
ordering: Seq[SortOrder],
plan: LogicalPlan): (Seq[SortOrder], Seq[Attribute]) = {
plan match {
// Only Window and Project have a projectList-like field.
case un: UnaryNode if un.isInstanceOf[Project] || un.isInstanceOf[Window] =>
val (newOrdering, missingAttrs) = resolveAndFindMissing(ordering, un, un.child)
// If missingAttrs is non-empty, we found the missing attributes, so return them;
// otherwise, continue to traverse the tree.
if (missingAttrs.nonEmpty) {
(newOrdering, missingAttrs)
} else {
collectResolvableMissingAttrs(ordering, un.child)
}
case a: Aggregate =>
val (newOrdering, missingAttrs) = resolveAndFindMissing(ordering, a, a.child)
// For Aggregate, all the order-by columns must appear in the group-by clause
if (missingAttrs.nonEmpty &&
missingAttrs.forall(ar => a.groupingExpressions.exists(_.semanticEquals(ar)))) {
(newOrdering, missingAttrs)
} else {
// Otherwise, we are unable to resolve the missing attributes here
(Seq.empty[SortOrder], Seq.empty[Attribute])
}
// Jump over the following UnaryNode types
// The output of these types is the same as their child's output
case _: Distinct |
_: Filter |
_: RepartitionByExpression =>
collectResolvableMissingAttrs(ordering, plan.asInstanceOf[UnaryNode].child)
// For any other (unsupported) operator, we are unable to resolve the attributes.
case other => (Seq.empty[SortOrder], Seq.empty[Attribute])
}
}

/**
* Try to resolve the sort ordering and return it along with a list of attributes that are
* missing from the plan but are present in the child.
*/
-   def resolveAndFindMissing(
+   private def resolveAndFindMissing(
        ordering: Seq[SortOrder],
-       child: LogicalPlan,
-       grandchild: LogicalPlan): (Seq[SortOrder], Seq[Attribute]) = {
-     val newOrdering = resolveSortOrders(ordering, grandchild, throws = true)
+       plan: LogicalPlan,
+       child: LogicalPlan): (Seq[SortOrder], Seq[Attribute]) = {
+     val newOrdering = resolveSortOrders(ordering, child, throws = false)
// Construct a set that contains all of the attributes that we need to evaluate the
// ordering.
val requiredAttributes = AttributeSet(newOrdering).filter(_.resolved)
// Figure out which ones are missing from the projection, so that we can add them and
// remove them after the sort.
-     val missingInProject = requiredAttributes -- child.output
+     val missingInProject = requiredAttributes -- plan.outputSet
// It is important to return the new SortOrders here, instead of waiting for the standard
// resolving process as adding attributes to the project below can actually introduce
// ambiguity that was not present before.
@@ -719,7 +778,7 @@ class Analyzer(
}
}

-   protected def containsAggregate(condition: Expression): Boolean = {
+   def containsAggregate(condition: Expression): Boolean = {
condition.find(_.isInstanceOf[AggregateExpression]).isDefined
}
}
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -76,6 +76,89 @@ class AnalysisSuite extends AnalysisTest {
caseSensitive = false)
}

test("resolve sort references - filter/limit") {
val a = testRelation2.output(0)
val b = testRelation2.output(1)
val c = testRelation2.output(2)

// Case 1: one missing attribute is in the leaf node and another is in the unary node
val plan1 = testRelation2
.where('a > "str").select('a, 'b)
.where('b > "str").select('a)
.sortBy('b.asc, 'c.desc)
val expected1 = testRelation2
.where(a > "str").select(a, b, c)
.where(b > "str").select(a, b, c)
.sortBy(b.asc, c.desc)
.select(a, b).select(a)
checkAnalysis(plan1, expected1)

// Case 2: all the missing attributes are in the leaf node
val plan2 = testRelation2
.where('a > "str").select('a)
.where('a > "str").select('a)
.sortBy('b.asc, 'c.desc)
val expected2 = testRelation2
.where(a > "str").select(a, b, c)
.where(a > "str").select(a, b, c)
.sortBy(b.asc, c.desc)
.select(a)
checkAnalysis(plan2, expected2)
}

test("resolve sort references - join") {
val a = testRelation2.output(0)
val b = testRelation2.output(1)
val c = testRelation2.output(2)
val h = testRelation3.output(3)

// Case: join itself can resolve all the missing attributes
val plan = testRelation2.join(testRelation3)
.where('a > "str").select('a, 'b)
.sortBy('c.desc, 'h.asc)
val expected = testRelation2.join(testRelation3)
.where(a > "str").select(a, b, c, h)
.sortBy(c.desc, h.asc)
.select(a, b)
checkAnalysis(plan, expected)
}

test("resolve sort references - aggregate") {
val a = testRelation2.output(0)
val b = testRelation2.output(1)
val c = testRelation2.output(2)
val alias_a3 = count(a).as("a3")
val alias_b = b.as("aggOrder")

// Case 1: when the child of Sort is not Aggregate,
// the sort reference is handled by the rule ResolveSortReferences
val plan1 = testRelation2
.groupBy('a, 'c, 'b)('a, 'c, count('a).as("a3"))
.select('a, 'c, 'a3)
.orderBy('b.asc)

val expected1 = testRelation2
.groupBy(a, c, b)(a, c, alias_a3, b)
.select(a, c, alias_a3.toAttribute, b)
.orderBy(b.asc)
.select(a, c, alias_a3.toAttribute)

checkAnalysis(plan1, expected1)

// Case 2: when the child of Sort is Aggregate,
// the sort reference is handled by the rule ResolveAggregateFunctions
val plan2 = testRelation2
.groupBy('a, 'c, 'b)('a, 'c, count('a).as("a3"))
.orderBy('b.asc)

val expected2 = testRelation2
.groupBy(a, c, b)(a, c, alias_a3, alias_b)
.orderBy(alias_b.toAttribute.asc)
.select(a, c, alias_a3.toAttribute)

checkAnalysis(plan2, expected2)
}

test("resolve relations") {
assertAnalysisError(
UnresolvedRelation(TableIdentifier("tAbLe"), None), Seq("Table not found: tAbLe"))
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TestRelations.scala
@@ -31,6 +31,12 @@ object TestRelations {
AttributeReference("d", DecimalType(10, 2))(),
AttributeReference("e", ShortType)())

val testRelation3 = LocalRelation(
AttributeReference("e", ShortType)(),
AttributeReference("f", StringType)(),
AttributeReference("g", DoubleType)(),
AttributeReference("h", DecimalType(10, 2))())

val nestedRelation = LocalRelation(
AttributeReference("top", StructType(
StructField("duplicateField", StringType) ::
sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
@@ -42,6 +42,22 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext {
Row(1, 2, "1", "2") :: Row(2, 3, "2", "3") :: Row(3, 4, "3", "4") :: Nil)
}

test("join - sorted columns not in join's outputSet") {
val df = Seq((1, 2, "1"), (3, 4, "3")).toDF("int", "int2", "str_sort").as('df1)
val df2 = Seq((1, 3, "1"), (5, 6, "5")).toDF("int", "int2", "str").as('df2)
val df3 = Seq((1, 3, "1"), (5, 6, "5")).toDF("int", "int2", "str").as('df3)

checkAnswer(
df.join(df2, $"df1.int" === $"df2.int", "outer").select($"df1.int", $"df2.int2")
.orderBy('str_sort.asc, 'str.asc),
Row(null, 6) :: Row(1, 3) :: Row(3, null) :: Nil)

checkAnswer(
df2.join(df3, $"df2.int" === $"df3.int", "inner")
.select($"df2.int", $"df3.int").orderBy($"df2.str".desc),
Row(5, 5) :: Row(1, 1) :: Nil)
}

test("join - join using multiple columns and specifying join type") {
val df = Seq((1, 2, "1"), (3, 4, "3")).toDF("int", "int2", "str")
val df2 = Seq((1, 3, "1"), (5, 6, "5")).toDF("int", "int2", "str")
sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -954,6 +954,12 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
assert(expected === actual)
}

test("Sorting columns are not in Filter and Project") {
checkAnswer(
upperCaseData.filter('N > 1).select('N).filter('N < 6).orderBy('L.asc),
Row(2) :: Row(3) :: Row(4) :: Row(5) :: Nil)
}

test("SPARK-9323: DataFrame.orderBy should support nested column name") {
val df = sqlContext.read.json(sparkContext.makeRDD(
"""{"a": {"b": 1}}""" :: Nil))
sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
@@ -736,7 +736,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
""".stripMargin), (2 to 6).map(i => Row(i)))
}

test("window function: udaf with aggregate expressin") {
test("window function: udaf with aggregate expression") {
val data = Seq(
WindowData(1, "a", 5),
WindowData(2, "a", 6),
@@ -927,6 +927,88 @@
).map(i => Row(i._1, i._2, i._3, i._4)))
}

test("window function: Sorting columns are not in Project") {
val data = Seq(
WindowData(1, "d", 10),
WindowData(2, "a", 6),
WindowData(3, "b", 7),
WindowData(4, "b", 8),
WindowData(5, "c", 9),
WindowData(6, "c", 11)
)
sparkContext.parallelize(data).toDF().registerTempTable("windowData")

checkAnswer(
sql("select month, product, sum(product + 1) over() from windowData order by area"),
Seq(
(2, 6, 57),
(3, 7, 57),
(4, 8, 57),
(5, 9, 57),
(6, 11, 57),
(1, 10, 57)
).map(i => Row(i._1, i._2, i._3)))

checkAnswer(
sql(
"""
|select area, rank() over (partition by area order by tmp.month) + tmp.tmp1 as c1
|from (select month, area, product as p, 1 as tmp1 from windowData) tmp order by p
""".stripMargin),
Seq(
("a", 2),
("b", 2),
("b", 3),
("c", 2),
("d", 2),
("c", 3)
).map(i => Row(i._1, i._2)))

checkAnswer(
sql(
"""
|select area, rank() over (partition by area order by month) as c1
|from windowData group by product, area, month order by product, area
""".stripMargin),
Seq(
("a", 1),
("b", 1),
("b", 2),
("c", 1),
("d", 1),
("c", 2)
).map(i => Row(i._1, i._2)))
}

// TODO: fix this test case by reimplementing the rule ResolveAggregateFunctions
ignore("window function: Pushing aggregate Expressions in Sort to Aggregate") {
val data = Seq(
WindowData(1, "d", 10),
WindowData(2, "a", 6),
WindowData(3, "b", 7),
WindowData(4, "b", 8),
WindowData(5, "c", 9),
WindowData(6, "c", 11)
)
sparkContext.parallelize(data).toDF().registerTempTable("windowData")

checkAnswer(
sql(
"""
|select area, sum(product) over () as c from windowData
|where product > 3 group by area, product
|having avg(month) > 0 order by avg(month), product
""".stripMargin),
Seq(
("a", 51),
("b", 51),
("b", 51),
("c", 51),
("c", 51),
("d", 51)
).map(i => Row(i._1, i._2)))
}

test("window function: multiple window expressions in a single expression") {
val nums = sparkContext.parallelize(1 to 10).map(x => (x, x % 2)).toDF("x", "y")
nums.registerTempTable("nums")