[SPARK-12616] [SQL] Making Logical Operator Union Support Arbitrary Number of Children #10577
Changes from 2 commits
Changes to the Catalyst optimizer: a new "Unions" batch and the CombineUnions rule are added, and CombineUnions is also included in the "Operator Optimizations" batch.

    @@ -40,6 +40,8 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] {
         Batch("Aggregate", FixedPoint(100),
           ReplaceDistinctWithAggregate,
           RemoveLiteralFromGroupExpressions) ::
    +    Batch("Unions", FixedPoint(100),
    +      CombineUnions) ::
         Batch("Operator Optimizations", FixedPoint(100),
           // Operator push down
           SetOperationPushDown,
    @@ -54,6 +56,7 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] {
           ProjectCollapsing,
           CombineFilters,
           CombineLimits,
    +      CombineUnions,
           // Constant folding
           NullPropagation,
           OptimizeIn,
    @@ -594,6 +597,22 @@ object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper {
       }
     }

    +/**
    + * Combines all adjacent [[Union]] and [[Unions]] operators into a single [[Unions]].
    + */
    +object CombineUnions extends Rule[LogicalPlan] {
    +  private def collectUnionChildren(plan: LogicalPlan): Seq[LogicalPlan] = plan match {
    +    case Union(l, r) => collectUnionChildren(l) ++ collectUnionChildren(r)
    +    case Unions(children) => children.flatMap(collectUnionChildren)
    +    case other => other :: Nil
    +  }
    +
    +  def apply(plan: LogicalPlan): LogicalPlan = plan transformDown {
    +    case u: Union => Unions(collectUnionChildren(u))
    +    case u: Unions => Unions(collectUnionChildren(u))
    +  }
    +}
    +
     /**
      * Combines two adjacent [[Filter]] operators into one, merging the
      * conditions into one conjunctive predicate.

Review thread on the recursive collectUnionChildren helper:

- you should write this without using recursion, to avoid stack overflow.
- I see. Removing the recursion.
- Another option would just be to do this at construction time; that way we can avoid paying the cost in the analyzer. This would still limit the cases we could cache (i.e. we'd miss cached data unioned with other data), but that doesn't seem like a huge deal. I'd leave this rule here either way.
- +1
- Hi @marmbrus, could I ask you a question regarding your comment here? I don't understand …
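A minimal sketch of the non-recursive variant the first comment asks for, reusing the Catalyst types (Rule, LogicalPlan, Union, Unions) already in scope in this file; it illustrates the suggestion and is not the code that was eventually committed:

    import scala.collection.mutable

    object CombineUnions extends Rule[LogicalPlan] {
      // Flatten nested Union/Unions nodes with an explicit stack instead of
      // recursion, so deeply nested union trees cannot overflow the call stack.
      private def collectUnionChildren(plan: LogicalPlan): Seq[LogicalPlan] = {
        val stack = mutable.Stack[LogicalPlan](plan)
        val flattened = mutable.ArrayBuffer.empty[LogicalPlan]
        while (stack.nonEmpty) {
          stack.pop() match {
            // Push the right child first so the left child is handled first and
            // the original left-to-right order of the inputs is preserved.
            case Union(left, right) => stack.push(right); stack.push(left)
            case Unions(children) => stack.pushAll(children.reverse)
            case other => flattened += other
          }
        }
        flattened.toSeq
      }

      def apply(plan: LogicalPlan): LogicalPlan = plan transformDown {
        case u: Union => Unions(collectUnionChildren(u))
        case u: Unions => Unions(collectUnionChildren(u))
      }
    }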
Changes to the logical plan operators: Intersect and Except are moved above Union, and a new n-ary Unions node is added.
    @@ -107,6 +107,13 @@ private[sql] object SetOperation {
       def unapply(p: SetOperation): Option[(LogicalPlan, LogicalPlan)] = Some((p.left, p.right))
     }

    +case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right)
    +
    +case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) {
    +  /** We don't use right.output because those rows get excluded from the set. */
    +  override def output: Seq[Attribute] = left.output
    +}
    +
     case class Union(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) {

       override def statistics: Statistics = {
    @@ -115,11 +122,20 @@ case class Union(left: LogicalPlan, right: LogicalPlan) extends SetOperation(lef
       }
     }

    -case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right)
    +case class Unions(children: Seq[LogicalPlan]) extends LogicalPlan {

    -case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) {
    -  /** We don't use right.output because those rows get excluded from the set. */
    -  override def output: Seq[Attribute] = left.output
    -}
    +  override def output: Seq[Attribute] = {
    +    children.tail.foldLeft(children.head.output) { case (currentOutput, child) =>
    +      currentOutput.zip(child.output).map { case (a1, a2) =>
    +        a1.withNullability(a1.nullable || a2.nullable)
    +      }
    +    }
    +  }
    +
    +  override def statistics: Statistics = {
    +    val sizeInBytes = children.map(_.statistics.sizeInBytes).sum
    +    Statistics(sizeInBytes = sizeInBytes)
    +  }
    +}

     case class Join(

Review thread on the Unions.output computation:

- I'd also add some comment here explaining what's going on. At a high level, you are just updating nullability, right?
- yeah
- how about …
- Thank you!
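One way to address the comment request above: the same output computation with an explanatory comment (the comment wording is an assumption, not taken from the final patch):

    // Take the schema of the first child and only adjust nullability: an output
    // attribute is nullable if the corresponding attribute of any child is
    // nullable, since a row coming from that child may hold a null there.
    override def output: Seq[Attribute] = {
      children.tail.foldLeft(children.head.output) { case (currentOutput, child) =>
        currentOutput.zip(child.output).map { case (a1, a2) =>
          a1.withNullability(a1.nullable || a2.nullable)
        }
      }
    }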
Changes to SetOperationPushDownSuite: the test optimizer gains a Unions batch, a new test covers combining unions, and the existing expected plans are updated to use Unions.
    @@ -31,7 +31,9 @@ class SetOperationPushDownSuite extends PlanTest {
           EliminateSubQueries) ::
         Batch("Union Pushdown", Once,
           SetOperationPushDown,
    -      SimplifyFilters) :: Nil
    +      SimplifyFilters) ::
    +    Batch("Unions", Once,
    +      CombineUnions) :: Nil
       }

       val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
    @@ -40,6 +42,20 @@ class SetOperationPushDownSuite extends PlanTest {
       val testIntersect = Intersect(testRelation, testRelation2)
       val testExcept = Except(testRelation, testRelation2)

    +  test("union: combine unions into one unions") {
    +    val unionQuery1 = Union(Union(testRelation, testRelation2), testRelation)
    +    val unionQuery2 = Union(testRelation, Union(testRelation2, testRelation))
    +    val unionOptimized1 = Optimize.execute(unionQuery1.analyze)
    +    val unionOptimized2 = Optimize.execute(unionQuery2.analyze)
    +    comparePlans(unionOptimized1, unionOptimized2)
    +
    +    val combinedUnions = Unions(unionOptimized1 :: unionOptimized2 :: Nil)
    +    val combinedUnionsOptimized = Optimize.execute(combinedUnions.analyze)
    +    val unionQuery3 = Union(unionQuery1, unionQuery2)
    +    val unionOptimized3 = Optimize.execute(unionQuery3.analyze)
    +    comparePlans(combinedUnionsOptimized, unionOptimized3)
    +  }
    +
       test("union/intersect/except: filter to each side") {
         val unionQuery = testUnion.where('a === 1)
         val intersectQuery = testIntersect.where('b < 10)
    @@ -50,7 +66,7 @@ class SetOperationPushDownSuite extends PlanTest {
         val exceptOptimized = Optimize.execute(exceptQuery.analyze)

         val unionCorrectAnswer =
    -      Union(testRelation.where('a === 1), testRelation2.where('d === 1)).analyze
    +      Unions(testRelation.where('a === 1) :: testRelation2.where('d === 1) :: Nil).analyze
         val intersectCorrectAnswer =
           Intersect(testRelation.where('b < 10), testRelation2.where('e < 10)).analyze
         val exceptCorrectAnswer =
    @@ -65,7 +81,7 @@ class SetOperationPushDownSuite extends PlanTest {
         val unionQuery = testUnion.select('a)
         val unionOptimized = Optimize.execute(unionQuery.analyze)
         val unionCorrectAnswer =
    -      Union(testRelation.select('a), testRelation2.select('d)).analyze
    +      Unions(testRelation.select('a) :: testRelation2.select('d) :: Nil).analyze
         comparePlans(unionOptimized, unionCorrectAnswer)
       }

Review thread on the new "union: combine unions into one unions" test:

- I think it's not really a push-down test; how about we create a new test suite for …
- oh, maybe just rename this test suite to …
- ok, sure, will rename it.
Changes to DataFrameSuite: a new "union all" test checks that chained unionAll calls are collapsed into a single Unions operator.
    @@ -25,7 +25,7 @@ import scala.util.Random
     import org.scalatest.Matchers._

     import org.apache.spark.SparkException
    -import org.apache.spark.sql.catalyst.plans.logical.OneRowRelation
    +import org.apache.spark.sql.catalyst.plans.logical.{Unions, OneRowRelation}
     import org.apache.spark.sql.execution.Exchange
     import org.apache.spark.sql.execution.aggregate.TungstenAggregate
     import org.apache.spark.sql.functions._
    @@ -98,6 +98,19 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
           testData.collect().toSeq)
       }

    +  test ("union all") {
    +    val unionsDF = testData.unionAll(testData).unionAll(testData)
    +      .unionAll(testData).unionAll(testData)
    +
    +    assert(unionsDF.queryExecution.optimizedPlan.collect {
    +      case j @ Unions(Seq(_, _, _, _, _)) => j }.size === 1)
    +
    +    checkAnswer(
    +      unionsDF.agg(avg('key), max('key), min('key), sum('key)),
    +      Row(50.5, 100, 1, 25250) :: Nil
    +    )
    +  }
    +
       test("empty data frame") {
         assert(sqlContext.emptyDataFrame.columns.toSeq === Seq.empty[String])
         assert(sqlContext.emptyDataFrame.count() === 0)

Review thread on the new "union all" test:

- remove the space after "test"
- Done. : )
Review thread on the Optimizer batch changes:

- you should explain in comments why CombineUnions appears twice.
- Sure, will do.
- and maybe move this before aggregate?
- Sure.
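A sketch of how the requested explanation might read in the batch list; the comment wording and the reasoning are assumptions for illustration, not the text that was merged:

    // Hypothetical comments explaining why CombineUnions appears in two batches.
    Batch("Unions", FixedPoint(100),
      // Flatten nested Union/Unions operators once, early, so the later and more
      // expensive fixed-point batches work on a single n-ary Unions instead of a
      // deep binary tree of Union nodes.
      CombineUnions) ::
    Batch("Operator Optimizations", FixedPoint(100),
      SetOperationPushDown,
      ProjectCollapsing,
      CombineFilters,
      CombineLimits,
      // CombineUnions appears again here because other rules in this batch can
      // make two Union/Unions operators adjacent (for example, by removing an
      // operator that sat between them), and those should be collapsed as well.
      CombineUnions,
      // Constant folding
      NullPropagation,
      OptimizeIn) :: Nil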