Skip to content

Commit

Permalink
address comments.
Browse files Browse the repository at this point in the history
  • Loading branch information
gatorsmile committed Jan 20, 2016
1 parent 598a673 commit 1964884
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 145 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,12 @@
package org.apache.spark.sql.catalyst.analysis

import scala.annotation.tailrec
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.{CatalystConf, ScalaReflection, SimpleCatalystConf}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.catalyst.trees.TreeNodeRef
Expand Down Expand Up @@ -76,7 +74,6 @@ class Analyzer(
ResolveGroupingAnalytics ::
ResolvePivot ::
ResolveUpCast ::
ResolveAggregateFunctions ::
ResolveSortReferences ::
ResolveGenerate ::
ResolveFunctions ::
Expand All @@ -85,6 +82,7 @@ class Analyzer(
ResolveWindowFrame ::
ExtractWindowExpressions ::
GlobalAggregates ::
ResolveAggregateFunctions ::
DistinctAggregationRewriter(conf) ::
HiveTypeCoercion.typeCoercionRules ++
extendedResolutionRules : _*),
Expand Down Expand Up @@ -524,30 +522,31 @@ class Analyzer(
*/
object ResolveSortReferences extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
case s @ Sort(_, _, child) if !s.resolved && child.resolved =>
val (newOrdering, missingResolvableAttrs) =
collectResolvableMissingAttrs(s.order, plans = mutable.Queue(child))
// Here, this rule only resolves the missing sort references if the child is not Aggregate
// Another rule ResolveAggregateFunctions will resolve that case.
case s @ Sort(_, _, child)
if !s.resolved && child.resolved && !child.isInstanceOf[Aggregate] =>
val (newOrdering, missingResolvableAttrs) = collectResolvableMissingAttrs(s.order, child)

if (missingResolvableAttrs.isEmpty) {
val unresolvableAttrs = s.order.filterNot(_.resolved)
logDebug(s"Failed to find $unresolvableAttrs in ${child.output.mkString(", ")}")
s // Nothing we can do here. Return original plan.
}
else {
// Transform the whole tree in post-order. Add into the self's outputSet
// all the children attributes that are part of missingResolvableAttrs.
// Assumption: all the conflicting attributes between left and right have been resolved
} else {
// Add the missing attributes into projectList of Project/Window or
// aggregateExpressions of Aggregate, if they are in the inputSet
// but not in the outputSet of the plan.
val newChild = child transformUp {
case p: Project =>
p.copy(projectList = p.projectList ++
findNotResolvedMissingAttrs(p.outputSet, p.inputSet, missingResolvableAttrs))
missingResolvableAttrs.filter((p.inputSet -- p.outputSet).contains))
case w: Window =>
w.copy(projectList = w.projectList ++
findNotResolvedMissingAttrs(w.outputSet, w.inputSet, missingResolvableAttrs))
missingResolvableAttrs.filter((w.inputSet -- w.outputSet).contains))
case a: Aggregate =>
val newAggregateExpressions = a.aggregateExpressions ++
findNotResolvedMissingAttrs(
a.aggregateExpressions, a.inputSet, missingResolvableAttrs)
val resolvableAttrs = missingResolvableAttrs.filter(a.groupingExpressions.contains)
val notResolvedAttrs = resolvableAttrs.filterNot(a.aggregateExpressions.contains)
val newAggregateExpressions = a.aggregateExpressions ++ notResolvedAttrs
a.copy(aggregateExpressions = newAggregateExpressions)
case o => o
}
Expand All @@ -558,86 +557,43 @@ class Analyzer(
}
}

private def findNotResolvedMissingAttrs(
outputSet: AttributeSet,
inputSet: AttributeSet,
missingAttrs: Seq[Attribute]): Seq[Attribute] = {
val resolvedAttrs =
missingAttrs.filter(attr => inputSet.exists(_.semanticEquals(attr)))
resolvedAttrs.filterNot(attr => outputSet.exists(_.semanticEquals(attr)))
}

private def findNotResolvedMissingAttrs(
outputSet: Seq[Expression],
inputSet: AttributeSet,
missingAttrs: Seq[Attribute]): Seq[Attribute] = {
val resolvedAttrs =
missingAttrs.filter(attr => inputSet.exists(_.semanticEquals(attr)))
resolvedAttrs.filterNot(attr => outputSet.exists(_.semanticEquals(attr)))
}

/**
* Traverse the tree until resolving the sorting attributes
* Return all the resolvable missing sorting attributes
*/
@tailrec
private def collectResolvableMissingAttrs(
ordering: Seq[SortOrder],
plans: mutable.Queue[LogicalPlan]): (Seq[SortOrder], Seq[Attribute]) = {
if (plans.isEmpty) (Seq.empty[SortOrder], Seq.empty[Attribute])
else {
plans.dequeue() match {
// Only Windows and Project have projectList-like attribute.
case un: UnaryNode if un.isInstanceOf[Project] || un.isInstanceOf[Window] =>
val (newOrdering, missingAttrs) = resolveAndFindMissing(ordering, un, un.child)
// If missingAttrs is non empty, that means we got it and return it;
// Otherwise, continue to traverse the tree.
if (missingAttrs.nonEmpty) (newOrdering, missingAttrs)
else {
plans.enqueue(un.child)
collectResolvableMissingAttrs(ordering, plans)
}
// Jump over the following UnaryNode types
// The output of these types is the same as their child's output
case un: UnaryNode
if un.isInstanceOf[Distinct] ||
un.isInstanceOf[Filter] ||
un.isInstanceOf[Limit] ||
un.isInstanceOf[RedistributeData] ||
un.isInstanceOf[Repartition] ||
un.isInstanceOf[RepartitionByExpression] ||
un.isInstanceOf[Sample] ||
un.isInstanceOf[Sort] ||
un.isInstanceOf[SortPartitions] ||
un.isInstanceOf[Subquery] ||
un.isInstanceOf[With] ||
un.isInstanceOf[WithWindowDefinition] =>
assert(un.inputSet == un.outputSet)
plans.enqueue(un.child)
collectResolvableMissingAttrs(ordering, plans)
case a: Aggregate =>
val (newOrdering, missingAttrs) = resolveAndFindMissing(ordering, a, a.child)
// For Aggregate, all the order by columns must be specified in group by clauses
if (missingAttrs.nonEmpty &&
missingAttrs.forall(ar => a.groupingExpressions.exists(_.semanticEquals(ar)))) {
(newOrdering, missingAttrs)
}
// If missingAttrs is empty, do not traverse its child and just try the next
else collectResolvableMissingAttrs(ordering, plans)
case join @ Join(left, right, joinType, _) =>
joinType match {
case _ @ (Inner | LeftOuter | RightOuter | FullOuter) =>
plans.enqueue(left, right)
collectResolvableMissingAttrs(ordering, plans)
// If we support LeftAnti, we should add it here
case _ @ LeftSemi =>
plans.enqueue(left)
collectResolvableMissingAttrs(ordering, plans)
}
// If hitting the other unsupported operators, we are unable to resolve it,
// try the next until the queue is empty
case other => collectResolvableMissingAttrs(ordering, plans)
}
plan: LogicalPlan): (Seq[SortOrder], Seq[Attribute]) = {
plan match {
// Only Windows and Project have projectList-like attribute.
case un: UnaryNode if un.isInstanceOf[Project] || un.isInstanceOf[Window] =>
val (newOrdering, missingAttrs) = resolveAndFindMissing(ordering, un, un.child)
// If missingAttrs is non empty, that means we got it and return it;
// Otherwise, continue to traverse the tree.
if (missingAttrs.nonEmpty) {
(newOrdering, missingAttrs)
} else {
collectResolvableMissingAttrs(ordering, un.child)
}
case a: Aggregate =>
val (newOrdering, missingAttrs) = resolveAndFindMissing(ordering, a, a.child)
// For Aggregate, all the order by columns must be specified in group by clauses
if (missingAttrs.nonEmpty &&
missingAttrs.forall(ar => a.groupingExpressions.exists(_.semanticEquals(ar)))) {
(newOrdering, missingAttrs)
} else {
// If missingAttrs is empty, we are unable to resolve any unresolved missing attributes
(Seq.empty[SortOrder], Seq.empty[Attribute])
}
// Jump over the following UnaryNode types
// The output of these types is the same as their child's output
case _: Distinct |
_: Filter |
_: RepartitionByExpression =>
collectResolvableMissingAttrs(ordering, plan.asInstanceOf[UnaryNode].child)
// If hitting the other unsupported operators, we are unable to resolve it.
case other => (Seq.empty[SortOrder], Seq.empty[Attribute])
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,12 +85,10 @@ class AnalysisSuite extends AnalysisTest {
val plan1 = testRelation2
.where('a > "str").select('a, 'b)
.where('b > "str").select('a)
.limit(4)
.sortBy('b.asc, 'c.desc)
val expected1 = testRelation2
.where(a > "str").select(a, b, c)
.where(b > "str").select(a, b, c)
.limit(4)
.sortBy(b.asc, c.desc)
.select(a, b).select(a)
checkAnalysis(plan1, expected1)
Expand All @@ -112,80 +110,53 @@ class AnalysisSuite extends AnalysisTest {
val a = testRelation2.output(0)
val b = testRelation2.output(1)
val c = testRelation2.output(2)

val f = testRelation3.output(1)
val h = testRelation3.output(3)

// Case 1: join itself can resolve all the missing attributes
val plan1 = testRelation2.join(testRelation3)
.where('a > "str").select('a, 'b)
.sortBy('c.desc, 'h.asc)
val expected1 = testRelation2.join(testRelation3)
.where(a > "str").select(a, b, c, h)
.sortBy(c.desc, h.asc)
.select(a, b)
checkAnalysis(plan1, expected1)

// Case 2: join itself can resolve partial missing attributes
// and the remaining ones are resolvable in its left tree
val plan2 = testRelation2.select('a, 'b).join(testRelation3)
.where('a > "str").select('a, 'b)
.sortBy('c.desc, 'h.asc)
val expected2 = testRelation2.select(a, b, c).join(testRelation3)
.where(a > "str").select(a, b, h, c)
.sortBy(c.desc, h.asc)
.select(a, b, h).select(a, b)
checkAnalysis(plan2, expected2)

// Case 3: both trees are needed to resolve all the missing attribute
val plan3 = testRelation2.select('a, 'b).join(testRelation3.select('f))
// Case: join itself can resolve all the missing attributes
val plan = testRelation2.join(testRelation3)
.where('a > "str").select('a, 'b)
.sortBy('c.desc, 'h.asc)
val expected3 = testRelation2.select(a, b, c).join(testRelation3.select(f, h))
val expected = testRelation2.join(testRelation3)
.where(a > "str").select(a, b, c, h)
.sortBy(c.desc, h.asc)
.select(a, b, c).select(a, b)
checkAnalysis(plan3, expected3)

// Case 4: right tree does not resolve any missing attribute
val plan4 = testRelation2.select('a, 'b).join(testRelation3.select('f))
.where('a > "str").select('a, 'b)
.sortBy('h.asc)
val expected4 = testRelation2.select(a, b).join(testRelation3.select(f, h))
.where(a > "str").select(a, b, h)
.sortBy(h.asc)
.select(a, b)
checkAnalysis(plan4, expected4)

// Case 5: left tree does not resolve any missing attribute
val plan5 = testRelation2.select('a, 'b).join(testRelation3.select('f))
.where('a > "str").select('a, 'b)
.sortBy('c.desc)
val expected5 = testRelation2.select(a, b, c).join(testRelation3.select(f))
.where(a > "str").select(a, b, c)
.sortBy(c.desc)
.select(a, b)
checkAnalysis(plan5, expected5)
checkAnalysis(plan, expected)
}

test("resolve sort references - aggregate") {
val a = testRelation2.output(0)
val b = testRelation2.output(1)
val c = testRelation2.output(2)
val alias3 = count(a).as("a3")
val alias_a3 = count(a).as("a3")
val alias_b = b.as("aggOrder")

val plan = testRelation2
// Case 1: when the child of Sort is not Aggregate,
// the sort reference is handled by the rule ResolveSortReferences
val plan1 = testRelation2
.groupBy('a, 'c, 'b)('a, 'c, count('a).as("a3"))
.select('a, 'c, 'a3)
.orderBy('b.asc)

val expected = testRelation2
.groupBy(a, c, b)(a, c, alias3, b)
.select(a, c, alias3.toAttribute, b)
val expected1 = testRelation2
.groupBy(a, c, b)(a, c, alias_a3, b)
.select(a, c, alias_a3.toAttribute, b)
.orderBy(b.asc)
.select(a, c, alias3.toAttribute)
.select(a, c, alias_a3.toAttribute)

checkAnalysis(plan, expected)
checkAnalysis(plan1, expected1)

// Case 2: when the child of Sort is Aggregate,
// the sort reference is handled by the rule ResolveAggregateFunctions
val plan2 = testRelation2
.groupBy('a, 'c, 'b)('a, 'c, count('a).as("a3"))
.orderBy('b.asc)

val expected2 = testRelation2
.groupBy(a, c, b)(a, c, alias_a3, alias_b)
.orderBy(alias_b.toAttribute.asc)
.select(a, c, alias_a3.toAttribute)

checkAnalysis(plan2, expected2)
}

test("resolve relations") {
Expand Down

0 comments on commit 1964884

Please sign in to comment.