From b38a21ef6146784e4b93ef4ce8c899f1eee14572 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Mon, 16 Nov 2015 18:30:26 -0800 Subject: [PATCH 01/27] SPARK-11633 --- .../spark/sql/catalyst/analysis/Analyzer.scala | 3 ++- .../spark/sql/hive/execution/SQLQuerySuite.scala | 16 ++++++++++++++++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 2f4670b55bdba..5a5b71e52dd79 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -425,7 +425,8 @@ class Analyzer( */ j case Some((oldRelation, newRelation)) => - val attributeRewrites = AttributeMap(oldRelation.output.zip(newRelation.output)) + val attributeRewrites = + AttributeMap(oldRelation.output.zip(newRelation.output).filter(x => x._1 != x._2)) val newRight = right transformUp { case r if r == oldRelation => newRelation } transformUp { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 3427152b2da02..5e00546a74c00 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -51,6 +51,8 @@ case class Order( state: String, month: Int) +case class Individual(F1: Integer, F2: Integer) + case class WindowData( month: Int, area: String, @@ -1479,4 +1481,18 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { |FROM (SELECT '{"f1": "value1", "f2": 12}' json, 'hello' as str) test """.stripMargin), Row("value1", "12", 3.14, "hello")) } + + test ("SPARK-11633: HiveContext throws TreeNode Exception : Failed to Copy Node") { + val rdd1 = sparkContext.parallelize(Seq( Individual(1,3), Individual(2,1))) + val df = hiveContext.createDataFrame(rdd1) + df.registerTempTable("foo") + val df2 = sql("select f1, F2 as F2 from foo") + df2.registerTempTable("foo2") + df2.registerTempTable("foo3") + + checkAnswer(sql( + """ + SELECT a.F1 FROM foo2 a INNER JOIN foo3 b ON a.F2=b.F2 + """.stripMargin), Row(2) :: Row(1) :: Nil) + } } From 0546772f151f83d6d3cf4d000cbe341f52545007 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Fri, 20 Nov 2015 10:56:45 -0800 Subject: [PATCH 02/27] converge --- .../spark/sql/catalyst/analysis/Analyzer.scala | 3 +-- .../spark/sql/hive/execution/SQLQuerySuite.scala | 15 --------------- 2 files changed, 1 insertion(+), 17 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 7c9512fbd00aa..47962ebe6ef82 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -417,8 +417,7 @@ class Analyzer( */ j case Some((oldRelation, newRelation)) => - val attributeRewrites = - AttributeMap(oldRelation.output.zip(newRelation.output).filter(x => x._1 != x._2)) + val attributeRewrites = AttributeMap(oldRelation.output.zip(newRelation.output)) val newRight = right transformUp { case r if r == oldRelation => newRelation } transformUp { diff --git 
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 5e00546a74c00..61d9dcd37572b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -51,8 +51,6 @@ case class Order( state: String, month: Int) -case class Individual(F1: Integer, F2: Integer) - case class WindowData( month: Int, area: String, @@ -1481,18 +1479,5 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { |FROM (SELECT '{"f1": "value1", "f2": 12}' json, 'hello' as str) test """.stripMargin), Row("value1", "12", 3.14, "hello")) } - - test ("SPARK-11633: HiveContext throws TreeNode Exception : Failed to Copy Node") { - val rdd1 = sparkContext.parallelize(Seq( Individual(1,3), Individual(2,1))) - val df = hiveContext.createDataFrame(rdd1) - df.registerTempTable("foo") - val df2 = sql("select f1, F2 as F2 from foo") - df2.registerTempTable("foo2") - df2.registerTempTable("foo3") - - checkAnswer(sql( - """ - SELECT a.F1 FROM foo2 a INNER JOIN foo3 b ON a.F2=b.F2 - """.stripMargin), Row(2) :: Row(1) :: Nil) } } From b37a64f13956b6ddd0e38ddfd9fe1caee611f1a8 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Fri, 20 Nov 2015 10:58:37 -0800 Subject: [PATCH 03/27] converge --- .../org/apache/spark/sql/hive/execution/SQLQuerySuite.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 61d9dcd37572b..3427152b2da02 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -1479,5 +1479,4 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { |FROM (SELECT '{"f1": "value1", "f2": 12}' json, 'hello' as str) test """.stripMargin), Row("value1", "12", 3.14, "hello")) } - } } From 73270c8aa7b7e387e7b0e75369dfcbf8c554aa5e Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Mon, 4 Jan 2016 12:09:50 -0800 Subject: [PATCH 04/27] added a new logical operator UNIONS --- .../sql/catalyst/optimizer/Optimizer.scala | 19 +++++++++++++++ .../sql/catalyst/planning/patterns.scala | 15 ------------ .../plans/logical/basicOperators.scala | 24 +++++++++++++++---- .../optimizer/SetOperationPushDownSuite.scala | 22 ++++++++++++++--- .../spark/sql/execution/SparkStrategies.scala | 2 +- .../org/apache/spark/sql/DataFrameSuite.scala | 15 +++++++++++- .../spark/sql/execution/PlannerSuite.scala | 12 ---------- 7 files changed, 73 insertions(+), 36 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 0b1c74293bb8b..5538d5cffc7ae 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -40,6 +40,8 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] { Batch("Aggregate", FixedPoint(100), ReplaceDistinctWithAggregate, RemoveLiteralFromGroupExpressions) :: + Batch("Unions", FixedPoint(100), + CombineUnions) :: Batch("Operator Optimizations", FixedPoint(100), // Operator push down SetOperationPushDown, @@ -54,6 
+56,7 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] { ProjectCollapsing, CombineFilters, CombineLimits, + CombineUnions, // Constant folding NullPropagation, OptimizeIn, @@ -594,6 +597,22 @@ object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper { } } +/** + * Combines all adjacent [[Union]] and [[Unions]] operators into a single [[Unions]]. + */ +object CombineUnions extends Rule[LogicalPlan] { + private def collectUnionChildren(plan: LogicalPlan): Seq[LogicalPlan] = plan match { + case Union(l, r) => collectUnionChildren(l) ++ collectUnionChildren(r) + case Unions(children) => children.flatMap(collectUnionChildren) + case other => other :: Nil + } + + def apply(plan: LogicalPlan): LogicalPlan = plan transformDown { + case u: Union => Unions(collectUnionChildren(u)) + case u: Unions => Unions(collectUnionChildren(u)) + } +} + /** * Combines two adjacent [[Filter]] operators into one, merging the * conditions into one conjunctive predicate. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index cd3f15cbe107b..9be88ac1012b9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -169,18 +169,3 @@ object ExtractFiltersAndInnerJoins extends PredicateHelper { case _ => None } } - -/** - * A pattern that collects all adjacent unions and returns their children as a Seq. - */ -object Unions { - def unapply(plan: LogicalPlan): Option[Seq[LogicalPlan]] = plan match { - case u: Union => Some(collectUnionChildren(u)) - case _ => None - } - - private def collectUnionChildren(plan: LogicalPlan): Seq[LogicalPlan] = plan match { - case Union(l, r) => collectUnionChildren(l) ++ collectUnionChildren(r) - case other => other :: Nil - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 5f34d4a4eb73c..014a8fff9c1b3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -107,6 +107,13 @@ private[sql] object SetOperation { def unapply(p: SetOperation): Option[(LogicalPlan, LogicalPlan)] = Some((p.left, p.right)) } +case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) + +case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) { + /** We don't use right.output because those rows get excluded from the set. */ + override def output: Seq[Attribute] = left.output +} + case class Union(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) { override def statistics: Statistics = { @@ -115,11 +122,20 @@ case class Union(left: LogicalPlan, right: LogicalPlan) extends SetOperation(lef } } -case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) +case class Unions(children: Seq[LogicalPlan]) extends LogicalPlan { -case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) { - /** We don't use right.output because those rows get excluded from the set. 
*/ - override def output: Seq[Attribute] = left.output + override def output: Seq[Attribute] = { + children.tail.foldLeft(children.head.output) { case (currentOutput, child) => + currentOutput.zip(child.output).map { case (a1, a2) => + a1.withNullability(a1.nullable || a2.nullable) + } + } + } + + override def statistics: Statistics = { + val sizeInBytes = children.map(_.statistics.sizeInBytes).sum + Statistics(sizeInBytes = sizeInBytes) + } } case class Join( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala index 1595ad9327423..65c8ec52ad478 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala @@ -31,7 +31,9 @@ class SetOperationPushDownSuite extends PlanTest { EliminateSubQueries) :: Batch("Union Pushdown", Once, SetOperationPushDown, - SimplifyFilters) :: Nil + SimplifyFilters) :: + Batch("Unions", Once, + CombineUnions) :: Nil } val testRelation = LocalRelation('a.int, 'b.int, 'c.int) @@ -40,6 +42,20 @@ class SetOperationPushDownSuite extends PlanTest { val testIntersect = Intersect(testRelation, testRelation2) val testExcept = Except(testRelation, testRelation2) + test("union: combine unions into one unions") { + val unionQuery1 = Union(Union(testRelation, testRelation2), testRelation) + val unionQuery2 = Union(testRelation, Union(testRelation2, testRelation)) + val unionOptimized1 = Optimize.execute(unionQuery1.analyze) + val unionOptimized2 = Optimize.execute(unionQuery2.analyze) + comparePlans(unionOptimized1, unionOptimized2) + + val combinedUnions = Unions(unionOptimized1 :: unionOptimized2 :: Nil) + val combinedUnionsOptimized = Optimize.execute(combinedUnions.analyze) + val unionQuery3 = Union(unionQuery1, unionQuery2) + val unionOptimized3 = Optimize.execute(unionQuery3.analyze) + comparePlans(combinedUnionsOptimized, unionOptimized3) + } + test("union/intersect/except: filter to each side") { val unionQuery = testUnion.where('a === 1) val intersectQuery = testIntersect.where('b < 10) @@ -50,7 +66,7 @@ class SetOperationPushDownSuite extends PlanTest { val exceptOptimized = Optimize.execute(exceptQuery.analyze) val unionCorrectAnswer = - Union(testRelation.where('a === 1), testRelation2.where('d === 1)).analyze + Unions(testRelation.where('a === 1) :: testRelation2.where('d === 1) :: Nil).analyze val intersectCorrectAnswer = Intersect(testRelation.where('b < 10), testRelation2.where('e < 10)).analyze val exceptCorrectAnswer = @@ -65,7 +81,7 @@ class SetOperationPushDownSuite extends PlanTest { val unionQuery = testUnion.select('a) val unionOptimized = Optimize.execute(unionQuery.analyze) val unionCorrectAnswer = - Union(testRelation.select('a), testRelation2.select('d)).analyze + Unions(testRelation.select('a) :: testRelation2.select('d) :: Nil ).analyze comparePlans(unionOptimized, unionCorrectAnswer) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 183d9b65023b9..40437ea391bf6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -347,7 +347,7 @@ private[sql] abstract class SparkStrategies extends 
QueryPlanner[SparkPlan] { LocalTableScan(output, data) :: Nil case logical.Limit(IntegerLiteral(limit), child) => execution.Limit(limit, planLater(child)) :: Nil - case Unions(unionChildren) => + case logical.Unions(unionChildren) => execution.Union(unionChildren.map(planLater)) :: Nil case logical.Except(left, right) => execution.Except(planLater(left), planLater(right)) :: Nil diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index ad478b0511095..90f0d2418f325 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -25,7 +25,7 @@ import scala.util.Random import org.scalatest.Matchers._ import org.apache.spark.SparkException -import org.apache.spark.sql.catalyst.plans.logical.OneRowRelation +import org.apache.spark.sql.catalyst.plans.logical.{Unions, OneRowRelation} import org.apache.spark.sql.execution.Exchange import org.apache.spark.sql.execution.aggregate.TungstenAggregate import org.apache.spark.sql.functions._ @@ -98,6 +98,19 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { testData.collect().toSeq) } + test ("union all") { + val unionsDF = testData.unionAll(testData).unionAll(testData) + .unionAll(testData).unionAll(testData) + + assert(unionsDF.queryExecution.optimizedPlan.collect { + case j @ Unions(Seq(_, _, _, _, _)) => j }.size === 1) + + checkAnswer( + unionsDF.agg(avg('key), max('key), min('key), sum('key)), + Row(50.5, 100, 1, 25250) :: Nil + ) + } + test("empty data frame") { assert(sqlContext.emptyDataFrame.columns.toSeq === Seq.empty[String]) assert(sqlContext.emptyDataFrame.count() === 0) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 2fb439f50117a..8e29150e1b2b0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -51,18 +51,6 @@ class PlannerSuite extends SharedSQLContext { s"The plan of query $query does not have partial aggregations.") } - test("unions are collapsed") { - val planner = sqlContext.planner - import planner._ - val query = testData.unionAll(testData).unionAll(testData).logicalPlan - val planned = BasicOperators(query).head - val logicalUnions = query collect { case u: logical.Union => u } - val physicalUnions = planned collect { case u: execution.Union => u } - - assert(logicalUnions.size === 2) - assert(physicalUnions.size === 1) - } - test("count is partially aggregated") { val query = testData.groupBy('value).agg(count('key)).queryExecution.analyzed testPartialAggregationPlan(query) From 5d031a76f223c65f617a1c438d41d27f445d74cc Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Mon, 4 Jan 2016 13:44:21 -0800 Subject: [PATCH 05/27] remove the old operator union --- .../sql/catalyst/optimizer/Optimizer.scala | 33 +++---------------- .../plans/logical/basicOperators.scala | 9 +++-- .../optimizer/SetOperationPushDownSuite.scala | 8 ++--- .../org/apache/spark/sql/DataFrame.scala | 2 +- 4 files changed, 14 insertions(+), 38 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 5538d5cffc7ae..315a84d45e5ba 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -98,8 +98,8 @@ object SamplePushDown extends Rule[LogicalPlan] { /** * Pushes certain operations to both sides of a Union, Intersect or Except operator. * Operations that are safe to pushdown are listed as follows. - * Union: - * Right now, Union means UNION ALL, which does not de-duplicate rows. So, it is + * Unions: + * Right now, Unions means UNION ALL, which does not de-duplicate rows. So, it is * safe to pushdown Filters and Projections through it. Once we add UNION DISTINCT, * we will not be able to pushdown Projections. * @@ -119,7 +119,7 @@ object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper { * Maps Attributes from the left side to the corresponding Attribute on the right side. */ private def buildRewrites(bn: BinaryNode): AttributeMap[Attribute] = { - assert(bn.isInstanceOf[Union] || bn.isInstanceOf[Intersect] || bn.isInstanceOf[Except]) + (bn.isInstanceOf[Intersect] || bn.isInstanceOf[Except]) assert(bn.left.output.size == bn.right.output.size) AttributeMap(bn.left.output.zip(bn.right.output)) @@ -127,7 +127,7 @@ object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper { /** * Rewrites an expression so that it can be pushed to the right side of a - * Union, Intersect or Except operator. This method relies on the fact that the output attributes + * Unions, Intersect or Except operator. This method relies on the fact that the output attributes * of a union/intersect/except are always equal to the left child's output. */ private def pushToRight[A <: Expression](e: A, rewrites: AttributeMap[Attribute]) = { @@ -156,27 +156,6 @@ object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper { } def apply(plan: LogicalPlan): LogicalPlan = plan transform { - // Push down filter into union - case Filter(condition, u @ Union(left, right)) => - val (deterministic, nondeterministic) = partitionByDeterministic(condition) - val rewrites = buildRewrites(u) - Filter(nondeterministic, - Union( - Filter(deterministic, left), - Filter(pushToRight(deterministic, rewrites), right) - ) - ) - - // Push down deterministic projection through UNION ALL - case p @ Project(projectList, u @ Union(left, right)) => - if (projectList.forall(_.deterministic)) { - val rewrites = buildRewrites(u) - Union( - Project(projectList, left), - Project(projectList.map(pushToRight(_, rewrites)), right)) - } else { - p - } // Push down filter through INTERSECT case Filter(condition, i @ Intersect(left, right)) => @@ -598,17 +577,15 @@ object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper { } /** - * Combines all adjacent [[Union]] and [[Unions]] operators into a single [[Unions]]. + * Combines all adjacent [[Unions]] operators into a single [[Unions]]. 
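+ * For example (with illustrative child plans a, b, c), Unions(Unions(a, b) :: c :: Nil) collapses to Unions(a :: b :: c :: Nil).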
*/ object CombineUnions extends Rule[LogicalPlan] { private def collectUnionChildren(plan: LogicalPlan): Seq[LogicalPlan] = plan match { - case Union(l, r) => collectUnionChildren(l) ++ collectUnionChildren(r) case Unions(children) => children.flatMap(collectUnionChildren) case other => other :: Nil } def apply(plan: LogicalPlan): LogicalPlan = plan transformDown { - case u: Union => Unions(collectUnionChildren(u)) case u: Unions => Unions(collectUnionChildren(u)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 014a8fff9c1b3..3373c971bf0bd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -114,11 +114,10 @@ case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(le override def output: Seq[Attribute] = left.output } -case class Union(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) { - - override def statistics: Statistics = { - val sizeInBytes = left.statistics.sizeInBytes + right.statistics.sizeInBytes - Statistics(sizeInBytes = sizeInBytes) +/** Factory for constructing new `AppendColumn` nodes. */ +object Unions { + def apply(left: LogicalPlan, right: LogicalPlan): Unions = { + Unions (left :: right :: Nil) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala index 65c8ec52ad478..b859c51dcee0e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala @@ -38,20 +38,20 @@ class SetOperationPushDownSuite extends PlanTest { val testRelation = LocalRelation('a.int, 'b.int, 'c.int) val testRelation2 = LocalRelation('d.int, 'e.int, 'f.int) - val testUnion = Union(testRelation, testRelation2) + val testUnion = Unions(testRelation, testRelation2) val testIntersect = Intersect(testRelation, testRelation2) val testExcept = Except(testRelation, testRelation2) test("union: combine unions into one unions") { - val unionQuery1 = Union(Union(testRelation, testRelation2), testRelation) - val unionQuery2 = Union(testRelation, Union(testRelation2, testRelation)) + val unionQuery1 = Unions(Unions(testRelation, testRelation2), testRelation) + val unionQuery2 = Unions(testRelation, Unions(testRelation2, testRelation)) val unionOptimized1 = Optimize.execute(unionQuery1.analyze) val unionOptimized2 = Optimize.execute(unionQuery2.analyze) comparePlans(unionOptimized1, unionOptimized2) val combinedUnions = Unions(unionOptimized1 :: unionOptimized2 :: Nil) val combinedUnionsOptimized = Optimize.execute(combinedUnions.analyze) - val unionQuery3 = Union(unionQuery1, unionQuery2) + val unionQuery3 = Unions(unionQuery1, unionQuery2) val unionOptimized3 = Optimize.execute(unionQuery3.analyze) comparePlans(combinedUnionsOptimized, unionOptimized3) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 965eaa9efec41..9f8e7bc8a2e40 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -1005,7 +1005,7 @@ class DataFrame private[sql]( * @since 1.3.0 */ def unionAll(other: DataFrame): DataFrame = withPlan { - Union(logicalPlan, other.logicalPlan) + Unions(logicalPlan, other.logicalPlan) } /** From c1f66f744fce35eb657f9ec8a971dbd5449d0985 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 5 Jan 2016 01:18:54 -0800 Subject: [PATCH 06/27] remove the old operator union #2. --- .../sql/catalyst/analysis/CheckAnalysis.scala | 10 ++++ .../catalyst/analysis/HiveTypeCoercion.scala | 50 +++++++++++++++---- .../sql/catalyst/optimizer/Optimizer.scala | 14 +++--- .../plans/logical/basicOperators.scala | 10 +++- .../analysis/DecimalPrecisionSuite.scala | 6 +-- .../analysis/HiveTypeCoercionSuite.scala | 22 ++++---- .../optimizer/SetOperationPushDownSuite.scala | 27 ++++++---- .../org/apache/spark/sql/DataFrame.scala | 2 +- .../scala/org/apache/spark/sql/Dataset.scala | 4 +- .../org/apache/spark/sql/hive/HiveQl.scala | 7 ++- 10 files changed, 106 insertions(+), 46 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index a1be1473cc80b..c62ea6d3d16a4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -189,6 +189,16 @@ trait CheckAnalysis { s"but the left table has ${left.output.length} columns and the right has " + s"${right.output.length}") + case s: Unions if s.children.exists(_.output.length != s.children.head.output.length) => + s.children.filter(_.output.length != s.children.head.output.length).foreach { child => + failAnalysis( + s""" + |Unions can only be performed on tables with the same number of columns, + | but the table '${child.simpleString}' has '${child.output.length}' columns + | and the first table '${s.children.head.simpleString}' has + | '${s.children.head.output.length}' columns""".stripMargin) + } + case _ => // Fallbacks to the following checks } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index dbcbd6854b474..d5428f7c16c25 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -196,7 +196,7 @@ object HiveTypeCoercion { * - LongType to DoubleType * - DecimalType to Double * - * This rule is only applied to Union/Except/Intersect + * This rule is only applied to Unions/Except/Intersect */ object WidenSetOperationTypes extends Rule[LogicalPlan] { @@ -212,22 +212,47 @@ object HiveTypeCoercion { case other => None } - def castOutput(plan: LogicalPlan): LogicalPlan = { - val casted = plan.output.zip(castedTypes).map { - case (e, Some(dt)) if e.dataType != dt => - Alias(Cast(e, dt), e.name)() - case (e, _) => e - } - Project(casted, plan) + if (castedTypes.exists(_.isDefined)) { + (castOutput(left, castedTypes), castOutput(right, castedTypes)) + } else { + (left, right) } + } + + private[this] def widenOutputTypes( + planName: String, + children: Seq[LogicalPlan]): Seq[LogicalPlan] = { + require(children.forall(_.output.length == children.head.output.length)) + + val castedTypes: Seq[Option[DataType]] = + 
children.tail.foldLeft(children.head.output.map(a => Option(a.dataType))) { + case (currentOutputDataTypes, child) => { + currentOutputDataTypes.zip(child.output).map { + case (Some(dt), a2) if dt != a2.dataType => + findWiderTypeForTwo(dt, a2.dataType) + case other => None + } + } + } if (castedTypes.exists(_.isDefined)) { - (castOutput(left), castOutput(right)) + children.map(castOutput(_, castedTypes)) } else { - (left, right) + children } } + private[this] def castOutput( + plan: LogicalPlan, + castedTypes: Seq[Option[DataType]]): LogicalPlan = { + val casted = plan.output.zip(castedTypes).map { + case (e, Some(dt)) if e.dataType != dt => + Alias(Cast(e, dt), e.name)() + case (e, _) => e + } + Project(casted, plan) + } + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case p if p.analyzed => p @@ -235,6 +260,11 @@ object HiveTypeCoercion { && left.output.length == right.output.length && !s.resolved => val (newLeft, newRight) = widenOutputTypes(s.nodeName, left, right) s.makeCopy(Array(newLeft, newRight)) + + case s: Unions if s.childrenResolved && + s.children.forall(_.output.length == s.children.head.output.length) && !s.resolved => + val newChildren: Seq[LogicalPlan] = widenOutputTypes(s.nodeName, s.children) + s.makeCopy(Array(newChildren)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 315a84d45e5ba..1c3189653c26d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -119,7 +119,7 @@ object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper { * Maps Attributes from the left side to the corresponding Attribute on the right side. */ private def buildRewrites(bn: BinaryNode): AttributeMap[Attribute] = { - (bn.isInstanceOf[Intersect] || bn.isInstanceOf[Except]) + assert(bn.isInstanceOf[Intersect] || bn.isInstanceOf[Except]) assert(bn.left.output.size == bn.right.output.size) AttributeMap(bn.left.output.zip(bn.right.output)) @@ -580,13 +580,15 @@ object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper { * Combines all adjacent [[Unions]] operators into a single [[Unions]]. 
*/ object CombineUnions extends Rule[LogicalPlan] { - private def collectUnionChildren(plan: LogicalPlan): Seq[LogicalPlan] = plan match { - case Unions(children) => children.flatMap(collectUnionChildren) - case other => other :: Nil + private def buildUnionChildren(children: Seq[LogicalPlan]): Seq[LogicalPlan] = + children.foldLeft(Seq.empty[LogicalPlan]) { (newChildren, child) => child match { + case Unions(grandchildren) => newChildren ++ grandchildren + case other => newChildren ++ Seq(other) + } } - def apply(plan: LogicalPlan): LogicalPlan = plan transformDown { - case u: Unions => Unions(collectUnionChildren(u)) + def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + case u @ Unions(children) => Unions(buildUnionChildren(children)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 3373c971bf0bd..244e2fe3b3f48 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -114,8 +114,8 @@ case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(le override def output: Seq[Attribute] = left.output } -/** Factory for constructing new `AppendColumn` nodes. */ -object Unions { +/** Factory for constructing new `Unions` nodes. */ +object Union { def apply(left: LogicalPlan, right: LogicalPlan): Unions = { Unions (left :: right :: Nil) } @@ -131,6 +131,12 @@ case class Unions(children: Seq[LogicalPlan]) extends LogicalPlan { } } + override lazy val resolved: Boolean = + childrenResolved && + children.forall(_.output.length == children.head.output.length) && + children.forall(_.output.zip(children.head.output).forall { + case (l, r) => l.dataType == r.dataType }) + override def statistics: Statistics = { val sizeInBytes = children.map(_.statistics.sizeInBytes).sum Statistics(sizeInBytes = sizeInBytes) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala index fed591fd90a9a..cdcf97461387c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala @@ -22,9 +22,9 @@ import org.scalatest.BeforeAndAfter import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.plans.logical.{Union, Project, LocalRelation} +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project, Union, Unions} +import org.apache.spark.sql.catalyst.{SimpleCatalystConf, TableIdentifier} import org.apache.spark.sql.types._ -import org.apache.spark.sql.catalyst.{TableIdentifier, SimpleCatalystConf} class DecimalPrecisionSuite extends SparkFunSuite with BeforeAndAfter { val conf = new SimpleCatalystConf(true) @@ -70,7 +70,7 @@ class DecimalPrecisionSuite extends SparkFunSuite with BeforeAndAfter { Union(Project(Seq(Alias(left, "l")()), relation), Project(Seq(Alias(right, "r")()), relation)) val (l, r) = analyzer.execute(plan).collect { - case Union(left, right) => (left.output.head, right.output.head) + case Unions(Seq(child1, child2)) => (child1.output.head, 
child2.output.head) }.head assert(l.dataType === expectedType) assert(r.dataType === expectedType) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala index 142915056f451..439e9e74cc262 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala @@ -384,11 +384,11 @@ class HiveTypeCoercionSuite extends PlanTest { val wt = HiveTypeCoercion.WidenSetOperationTypes val expectedTypes = Seq(StringType, DecimalType.SYSTEM_DEFAULT, FloatType, DoubleType) - val r1 = wt(Union(left, right)).asInstanceOf[Union] + val r1 = wt(Union(left, right)).asInstanceOf[Unions] val r2 = wt(Except(left, right)).asInstanceOf[Except] val r3 = wt(Intersect(left, right)).asInstanceOf[Intersect] - checkOutput(r1.left, expectedTypes) - checkOutput(r1.right, expectedTypes) + checkOutput(r1.children.head, expectedTypes) + checkOutput(r1.children.last, expectedTypes) checkOutput(r2.left, expectedTypes) checkOutput(r2.right, expectedTypes) checkOutput(r3.left, expectedTypes) @@ -410,12 +410,12 @@ class HiveTypeCoercionSuite extends PlanTest { AttributeReference("r", DecimalType(5, 5))()) val expectedType1 = Seq(DecimalType(10, 8)) - val r1 = dp(Union(left1, right1)).asInstanceOf[Union] + val r1 = dp(Union(left1, right1)).asInstanceOf[Unions] val r2 = dp(Except(left1, right1)).asInstanceOf[Except] val r3 = dp(Intersect(left1, right1)).asInstanceOf[Intersect] - checkOutput(r1.left, expectedType1) - checkOutput(r1.right, expectedType1) + checkOutput(r1.children.head, expectedType1) + checkOutput(r1.children.last, expectedType1) checkOutput(r2.left, expectedType1) checkOutput(r2.right, expectedType1) checkOutput(r3.left, expectedType1) @@ -427,23 +427,23 @@ class HiveTypeCoercionSuite extends PlanTest { val expectedTypes = Seq(DecimalType(10, 5), DecimalType(10, 5), DecimalType(15, 5), DecimalType(25, 5), DoubleType, DoubleType) - rightTypes.zip(expectedTypes).map { case (rType, expectedType) => + rightTypes.zip(expectedTypes).foreach { case (rType, expectedType) => val plan2 = LocalRelation( AttributeReference("r", rType)()) - val r1 = dp(Union(plan1, plan2)).asInstanceOf[Union] + val r1 = dp(Union(plan1, plan2)).asInstanceOf[Unions] val r2 = dp(Except(plan1, plan2)).asInstanceOf[Except] val r3 = dp(Intersect(plan1, plan2)).asInstanceOf[Intersect] - checkOutput(r1.right, Seq(expectedType)) + checkOutput(r1.children.last, Seq(expectedType)) checkOutput(r2.right, Seq(expectedType)) checkOutput(r3.right, Seq(expectedType)) - val r4 = dp(Union(plan2, plan1)).asInstanceOf[Union] + val r4 = dp(Union(plan2, plan1)).asInstanceOf[Unions] val r5 = dp(Except(plan2, plan1)).asInstanceOf[Except] val r6 = dp(Intersect(plan2, plan1)).asInstanceOf[Intersect] - checkOutput(r4.left, Seq(expectedType)) + checkOutput(r4.children.last, Seq(expectedType)) checkOutput(r5.left, Seq(expectedType)) checkOutput(r6.left, Seq(expectedType)) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala index b859c51dcee0e..fbcc1ad6a715b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala +++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala @@ -38,46 +38,51 @@ class SetOperationPushDownSuite extends PlanTest { val testRelation = LocalRelation('a.int, 'b.int, 'c.int) val testRelation2 = LocalRelation('d.int, 'e.int, 'f.int) - val testUnion = Unions(testRelation, testRelation2) + val testUnion = Union(testRelation, testRelation2) val testIntersect = Intersect(testRelation, testRelation2) val testExcept = Except(testRelation, testRelation2) test("union: combine unions into one unions") { - val unionQuery1 = Unions(Unions(testRelation, testRelation2), testRelation) - val unionQuery2 = Unions(testRelation, Unions(testRelation2, testRelation)) + val unionQuery1 = Union(Union(testRelation, testRelation2), testRelation) + val unionQuery2 = Union(testRelation, Union(testRelation2, testRelation)) val unionOptimized1 = Optimize.execute(unionQuery1.analyze) val unionOptimized2 = Optimize.execute(unionQuery2.analyze) + comparePlans(unionOptimized1, unionOptimized2) val combinedUnions = Unions(unionOptimized1 :: unionOptimized2 :: Nil) val combinedUnionsOptimized = Optimize.execute(combinedUnions.analyze) - val unionQuery3 = Unions(unionQuery1, unionQuery2) + val unionQuery3 = Union(unionQuery1, unionQuery2) val unionOptimized3 = Optimize.execute(unionQuery3.analyze) comparePlans(combinedUnionsOptimized, unionOptimized3) } - test("union/intersect/except: filter to each side") { - val unionQuery = testUnion.where('a === 1) + test("intersect/except: filter to each side") { val intersectQuery = testIntersect.where('b < 10) val exceptQuery = testExcept.where('c >= 5) - val unionOptimized = Optimize.execute(unionQuery.analyze) val intersectOptimized = Optimize.execute(intersectQuery.analyze) val exceptOptimized = Optimize.execute(exceptQuery.analyze) - val unionCorrectAnswer = - Unions(testRelation.where('a === 1) :: testRelation2.where('d === 1) :: Nil).analyze val intersectCorrectAnswer = Intersect(testRelation.where('b < 10), testRelation2.where('e < 10)).analyze val exceptCorrectAnswer = Except(testRelation.where('c >= 5), testRelation2.where('f >= 5)).analyze - comparePlans(unionOptimized, unionCorrectAnswer) comparePlans(intersectOptimized, intersectCorrectAnswer) comparePlans(exceptOptimized, exceptCorrectAnswer) } - test("union: project to each side") { + ignore("union: filter to each side") { + val unionQuery = testUnion.where('a === 1) + val unionOptimized = Optimize.execute(unionQuery.analyze) + val unionCorrectAnswer = + Unions(testRelation.where('a === 1) :: testRelation2.where('d === 1) :: Nil).analyze + + comparePlans(unionOptimized, unionCorrectAnswer) + } + + ignore("union: project to each side") { val unionQuery = testUnion.select('a) val unionOptimized = Optimize.execute(unionQuery.analyze) val unionCorrectAnswer = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 9f8e7bc8a2e40..965eaa9efec41 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -1005,7 +1005,7 @@ class DataFrame private[sql]( * @since 1.3.0 */ def unionAll(other: DataFrame): DataFrame = withPlan { - Unions(logicalPlan, other.logicalPlan) + Union(logicalPlan, other.logicalPlan) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index a763a951440cc..0f2eaa2f8c9ca 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -608,7 +608,9 @@ class Dataset[T] private[sql]( * duplicate items. As such, it is analogous to `UNION ALL` in SQL. * @since 1.6.0 */ - def union(other: Dataset[T]): Dataset[T] = withPlan[T](other)(Union) + def union(other: Dataset[T]): Dataset[T] = withPlan[T](other){ (left, right) => + Unions(left :: right :: Nil) + } /** * Returns a new [[Dataset]] where any elements present in `other` have been removed. diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index b1d841d1b5543..9c4a671614f47 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -1211,7 +1211,12 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C } // If there are multiple INSERTS just UNION them together into on query. - val query = queries.reduceLeft(Union) + val query = + if (queries.length == 1) { + queries.head + } else { + Unions(queries) + } // return With plan if there is CTE cteRelations.map(With(query, _)).getOrElse(query) From c1dcd02c32278c47764b93e4ba990fa562107f74 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 5 Jan 2016 08:28:40 -0800 Subject: [PATCH 07/27] rename. --- .../sql/catalyst/analysis/CheckAnalysis.scala | 2 +- .../sql/catalyst/analysis/HiveTypeCoercion.scala | 4 ++-- .../spark/sql/catalyst/optimizer/Optimizer.scala | 14 +++++++------- .../catalyst/plans/logical/basicOperators.scala | 6 +++--- .../catalyst/analysis/DecimalPrecisionSuite.scala | 4 ++-- .../catalyst/analysis/HiveTypeCoercionSuite.scala | 8 ++++---- .../optimizer/SetOperationPushDownSuite.scala | 6 +++--- .../main/scala/org/apache/spark/sql/Dataset.scala | 2 +- .../spark/sql/execution/SparkStrategies.scala | 2 +- .../org/apache/spark/sql/DataFrameSuite.scala | 10 +++++----- .../scala/org/apache/spark/sql/hive/HiveQl.scala | 2 +- 11 files changed, 30 insertions(+), 30 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index c62ea6d3d16a4..77b2d4eda1eef 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -189,7 +189,7 @@ trait CheckAnalysis { s"but the left table has ${left.output.length} columns and the right has " + s"${right.output.length}") - case s: Unions if s.children.exists(_.output.length != s.children.head.output.length) => + case s: Union if s.children.exists(_.output.length != s.children.head.output.length) => s.children.filter(_.output.length != s.children.head.output.length).foreach { child => failAnalysis( s""" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index d5428f7c16c25..b9a482fcea0c2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -196,7 +196,7 @@ object HiveTypeCoercion { * - LongType to DoubleType * - DecimalType to Double * - * This rule is only applied to Unions/Except/Intersect + * 
This rule is only applied to Union/Except/Intersect */ object WidenSetOperationTypes extends Rule[LogicalPlan] { @@ -261,7 +261,7 @@ object HiveTypeCoercion { val (newLeft, newRight) = widenOutputTypes(s.nodeName, left, right) s.makeCopy(Array(newLeft, newRight)) - case s: Unions if s.childrenResolved && + case s: Union if s.childrenResolved && s.children.forall(_.output.length == s.children.head.output.length) && !s.resolved => val newChildren: Seq[LogicalPlan] = widenOutputTypes(s.nodeName, s.children) s.makeCopy(Array(newChildren)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 1c3189653c26d..d1c7250bb5520 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -40,7 +40,7 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] { Batch("Aggregate", FixedPoint(100), ReplaceDistinctWithAggregate, RemoveLiteralFromGroupExpressions) :: - Batch("Unions", FixedPoint(100), + Batch("Union", FixedPoint(100), CombineUnions) :: Batch("Operator Optimizations", FixedPoint(100), // Operator push down @@ -98,8 +98,8 @@ object SamplePushDown extends Rule[LogicalPlan] { /** * Pushes certain operations to both sides of a Union, Intersect or Except operator. * Operations that are safe to pushdown are listed as follows. - * Unions: - * Right now, Unions means UNION ALL, which does not de-duplicate rows. So, it is + * Union: + * Right now, Union means UNION ALL, which does not de-duplicate rows. So, it is * safe to pushdown Filters and Projections through it. Once we add UNION DISTINCT, * we will not be able to pushdown Projections. * @@ -127,7 +127,7 @@ object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper { /** * Rewrites an expression so that it can be pushed to the right side of a - * Unions, Intersect or Except operator. This method relies on the fact that the output attributes + * Union, Intersect or Except operator. This method relies on the fact that the output attributes * of a union/intersect/except are always equal to the left child's output. */ private def pushToRight[A <: Expression](e: A, rewrites: AttributeMap[Attribute]) = { @@ -577,18 +577,18 @@ object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper { } /** - * Combines all adjacent [[Unions]] operators into a single [[Unions]]. + * Combines all adjacent [[Union]] operators into a single [[Union]]. 
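+ * For example (with illustrative child plans a, b, c), Union(Union(a, b) :: c :: Nil) collapses to Union(a :: b :: c :: Nil).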
*/ object CombineUnions extends Rule[LogicalPlan] { private def buildUnionChildren(children: Seq[LogicalPlan]): Seq[LogicalPlan] = children.foldLeft(Seq.empty[LogicalPlan]) { (newChildren, child) => child match { - case Unions(grandchildren) => newChildren ++ grandchildren + case Union(grandchildren) => newChildren ++ grandchildren case other => newChildren ++ Seq(other) } } def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { - case u @ Unions(children) => Unions(buildUnionChildren(children)) + case u @ Union(children) => Union(buildUnionChildren(children)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 244e2fe3b3f48..68a6e7ee9bb80 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -116,12 +116,12 @@ case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(le /** Factory for constructing new `Unions` nodes. */ object Union { - def apply(left: LogicalPlan, right: LogicalPlan): Unions = { - Unions (left :: right :: Nil) + def apply(left: LogicalPlan, right: LogicalPlan): Union = { + Union (left :: right :: Nil) } } -case class Unions(children: Seq[LogicalPlan]) extends LogicalPlan { +case class Union(children: Seq[LogicalPlan]) extends LogicalPlan { override def output: Seq[Attribute] = { children.tail.foldLeft(children.head.output) { case (currentOutput, child) => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala index cdcf97461387c..c41a57fa918a9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala @@ -22,7 +22,7 @@ import org.scalatest.BeforeAndAfter import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project, Union, Unions} +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project, Union} import org.apache.spark.sql.catalyst.{SimpleCatalystConf, TableIdentifier} import org.apache.spark.sql.types._ @@ -70,7 +70,7 @@ class DecimalPrecisionSuite extends SparkFunSuite with BeforeAndAfter { Union(Project(Seq(Alias(left, "l")()), relation), Project(Seq(Alias(right, "r")()), relation)) val (l, r) = analyzer.execute(plan).collect { - case Unions(Seq(child1, child2)) => (child1.output.head, child2.output.head) + case Union(Seq(child1, child2)) => (child1.output.head, child2.output.head) }.head assert(l.dataType === expectedType) assert(r.dataType === expectedType) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala index 439e9e74cc262..9d2f41ff5adca 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala @@ -384,7 +384,7 @@ class HiveTypeCoercionSuite extends 
PlanTest { val wt = HiveTypeCoercion.WidenSetOperationTypes val expectedTypes = Seq(StringType, DecimalType.SYSTEM_DEFAULT, FloatType, DoubleType) - val r1 = wt(Union(left, right)).asInstanceOf[Unions] + val r1 = wt(Union(left, right)).asInstanceOf[Union] val r2 = wt(Except(left, right)).asInstanceOf[Except] val r3 = wt(Intersect(left, right)).asInstanceOf[Intersect] checkOutput(r1.children.head, expectedTypes) @@ -410,7 +410,7 @@ class HiveTypeCoercionSuite extends PlanTest { AttributeReference("r", DecimalType(5, 5))()) val expectedType1 = Seq(DecimalType(10, 8)) - val r1 = dp(Union(left1, right1)).asInstanceOf[Unions] + val r1 = dp(Union(left1, right1)).asInstanceOf[Union] val r2 = dp(Except(left1, right1)).asInstanceOf[Except] val r3 = dp(Intersect(left1, right1)).asInstanceOf[Intersect] @@ -431,7 +431,7 @@ class HiveTypeCoercionSuite extends PlanTest { val plan2 = LocalRelation( AttributeReference("r", rType)()) - val r1 = dp(Union(plan1, plan2)).asInstanceOf[Unions] + val r1 = dp(Union(plan1, plan2)).asInstanceOf[Union] val r2 = dp(Except(plan1, plan2)).asInstanceOf[Except] val r3 = dp(Intersect(plan1, plan2)).asInstanceOf[Intersect] @@ -439,7 +439,7 @@ class HiveTypeCoercionSuite extends PlanTest { checkOutput(r2.right, Seq(expectedType)) checkOutput(r3.right, Seq(expectedType)) - val r4 = dp(Union(plan2, plan1)).asInstanceOf[Unions] + val r4 = dp(Union(plan2, plan1)).asInstanceOf[Union] val r5 = dp(Except(plan2, plan1)).asInstanceOf[Except] val r6 = dp(Intersect(plan2, plan1)).asInstanceOf[Intersect] diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala index fbcc1ad6a715b..761d05830b5c2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala @@ -50,7 +50,7 @@ class SetOperationPushDownSuite extends PlanTest { comparePlans(unionOptimized1, unionOptimized2) - val combinedUnions = Unions(unionOptimized1 :: unionOptimized2 :: Nil) + val combinedUnions = Union(unionOptimized1 :: unionOptimized2 :: Nil) val combinedUnionsOptimized = Optimize.execute(combinedUnions.analyze) val unionQuery3 = Union(unionQuery1, unionQuery2) val unionOptimized3 = Optimize.execute(unionQuery3.analyze) @@ -77,7 +77,7 @@ class SetOperationPushDownSuite extends PlanTest { val unionQuery = testUnion.where('a === 1) val unionOptimized = Optimize.execute(unionQuery.analyze) val unionCorrectAnswer = - Unions(testRelation.where('a === 1) :: testRelation2.where('d === 1) :: Nil).analyze + Union(testRelation.where('a === 1), testRelation2.where('d === 1)).analyze comparePlans(unionOptimized, unionCorrectAnswer) } @@ -86,7 +86,7 @@ class SetOperationPushDownSuite extends PlanTest { val unionQuery = testUnion.select('a) val unionOptimized = Optimize.execute(unionQuery.analyze) val unionCorrectAnswer = - Unions(testRelation.select('a) :: testRelation2.select('d) :: Nil ).analyze + Union(testRelation.select('a), testRelation2.select('d)).analyze comparePlans(unionOptimized, unionCorrectAnswer) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 0f2eaa2f8c9ca..c76836c6bc25e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ 
-609,7 +609,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def union(other: Dataset[T]): Dataset[T] = withPlan[T](other){ (left, right) => - Unions(left :: right :: Nil) + Union(left, right) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 40437ea391bf6..e0060117684d2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -347,7 +347,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { LocalTableScan(output, data) :: Nil case logical.Limit(IntegerLiteral(limit), child) => execution.Limit(limit, planLater(child)) :: Nil - case logical.Unions(unionChildren) => + case logical.Union(unionChildren) => execution.Union(unionChildren.map(planLater)) :: Nil case logical.Except(left, right) => execution.Except(planLater(left), planLater(right)) :: Nil diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 90f0d2418f325..29ee89f46c78a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -25,7 +25,7 @@ import scala.util.Random import org.scalatest.Matchers._ import org.apache.spark.SparkException -import org.apache.spark.sql.catalyst.plans.logical.{Unions, OneRowRelation} +import org.apache.spark.sql.catalyst.plans.logical.{Union, OneRowRelation} import org.apache.spark.sql.execution.Exchange import org.apache.spark.sql.execution.aggregate.TungstenAggregate import org.apache.spark.sql.functions._ @@ -99,14 +99,14 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } test ("union all") { - val unionsDF = testData.unionAll(testData).unionAll(testData) + val unionDF = testData.unionAll(testData).unionAll(testData) .unionAll(testData).unionAll(testData) - assert(unionsDF.queryExecution.optimizedPlan.collect { - case j @ Unions(Seq(_, _, _, _, _)) => j }.size === 1) + assert(unionDF.queryExecution.optimizedPlan.collect { + case j @ Union(Seq(_, _, _, _, _)) => j }.size === 1) checkAnswer( - unionsDF.agg(avg('key), max('key), min('key), sum('key)), + unionDF.agg(avg('key), max('key), min('key), sum('key)), Row(50.5, 100, 1, 25250) :: Nil ) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 9c4a671614f47..ef993b13d95ad 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -1215,7 +1215,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C if (queries.length == 1) { queries.head } else { - Unions(queries) + Union(queries) } // return With plan if there is CTE From 51ad5b27da1733bccf1a978e6ec9bd43d0736896 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 5 Jan 2016 19:29:38 -0800 Subject: [PATCH 08/27] address the comments. 
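For reference, a self-contained sketch of the tail-recursive union flattening that the patterns.scala change below adopts; UnionNode and Leaf are illustrative stand-in types, not the catalyst classes:

    import scala.annotation.tailrec

    sealed trait Plan
    case class UnionNode(children: Seq[Plan]) extends Plan
    case class Leaf(name: String) extends Plan

    // Process a worklist of plans left to right. A nested union pushes its
    // children back onto the worklist instead of recursing into them, so
    // deeply nested unions cannot overflow the stack.
    @tailrec
    def collectUnionChildren(pending: List[Plan], acc: Seq[Plan]): Seq[Plan] =
      pending match {
        case Nil => acc
        case UnionNode(grandchildren) :: rest =>
          collectUnionChildren(grandchildren.toList ++ rest, acc)
        case other :: rest =>
          collectUnionChildren(rest, acc :+ other)
      }

    // collectUnionChildren(
    //   List(UnionNode(Seq(Leaf("a"), UnionNode(Seq(Leaf("b"), Leaf("c")))))), Nil)
    // yields Seq(Leaf("a"), Leaf("b"), Leaf("c")), preserving left-to-right order.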
--- .../catalyst/analysis/HiveTypeCoercion.scala | 7 +++--- .../sql/catalyst/optimizer/Optimizer.scala | 13 +++-------- .../sql/catalyst/planning/patterns.scala | 23 +++++++++++++++++++ .../analysis/HiveTypeCoercionSuite.scala | 16 +++++++++---- .../org/apache/spark/sql/DataFrame.scala | 3 ++- .../scala/org/apache/spark/sql/Dataset.scala | 7 +++--- .../org/apache/spark/sql/DataFrameSuite.scala | 3 ++- 7 files changed, 49 insertions(+), 23 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index b9a482fcea0c2..b5f8d5506abfc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -225,11 +225,12 @@ object HiveTypeCoercion { require(children.forall(_.output.length == children.head.output.length)) val castedTypes: Seq[Option[DataType]] = - children.tail.foldLeft(children.head.output.map(a => Option(a.dataType))) { + children.foldLeft(children.head.output.map(a => Option(a.dataType))) { case (currentOutputDataTypes, child) => { currentOutputDataTypes.zip(child.output).map { - case (Some(dt), a2) if dt != a2.dataType => - findWiderTypeForTwo(dt, a2.dataType) + case (Some(dt), ar) if dt != ar.dataType => + findWiderTypeForTwo(dt, ar.dataType) + case (Some(dt), ar) if dt == ar.dataType => Option(dt) case other => None } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index d1c7250bb5520..c5f6d48b5dcf0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -22,7 +22,7 @@ import scala.collection.immutable.HashSet import org.apache.spark.sql.catalyst.analysis.{CleanupAliases, EliminateSubQueries} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.planning.ExtractFiltersAndInnerJoins +import org.apache.spark.sql.catalyst.planning.{Unions, ExtractFiltersAndInnerJoins} import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, LeftOuter, LeftSemi, RightOuter} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ @@ -580,15 +580,8 @@ object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper { * Combines all adjacent [[Union]] operators into a single [[Union]]. 
*/ object CombineUnions extends Rule[LogicalPlan] { - private def buildUnionChildren(children: Seq[LogicalPlan]): Seq[LogicalPlan] = - children.foldLeft(Seq.empty[LogicalPlan]) { (newChildren, child) => child match { - case Union(grandchildren) => newChildren ++ grandchildren - case other => newChildren ++ Seq(other) - } - } - - def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { - case u @ Union(children) => Union(buildUnionChildren(children)) + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case u @ Unions(children) => Union(children) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 9be88ac1012b9..5d1f3ac4cad5f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -22,6 +22,9 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ +import scala.annotation.tailrec +import scala.collection.immutable.Queue + /** * A pattern that matches any number of project or filter operations on top of another relational * operator. All filter operators are collected and their conditions are broken up and returned @@ -169,3 +172,23 @@ object ExtractFiltersAndInnerJoins extends PredicateHelper { case _ => None } } + + +/** + * A pattern that collects all adjacent unions and returns their children as a Seq. + */ +object Unions { + def unapply(plan: LogicalPlan): Option[Seq[LogicalPlan]] = plan match { + case u: Union => Some(collectUnionChildren(u :: Nil, Seq.empty[LogicalPlan])) + case _ => None + } + + @tailrec + private def collectUnionChildren( + plan: List[LogicalPlan], + children: Seq[LogicalPlan]): Seq[LogicalPlan] = plan match { + case Nil => children + case Union(grandchildren) :: res => collectUnionChildren(grandchildren.toList ++ res, children) + case other :: res => collectUnionChildren(res, children :+ other) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala index 9d2f41ff5adca..d97dae7995581 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala @@ -370,24 +370,30 @@ class HiveTypeCoercionSuite extends PlanTest { } } - val left = LocalRelation( + val firstTable = LocalRelation( AttributeReference("i", IntegerType)(), AttributeReference("u", DecimalType.SYSTEM_DEFAULT)(), AttributeReference("b", ByteType)(), AttributeReference("d", DoubleType)()) - val right = LocalRelation( + val secondTable = LocalRelation( AttributeReference("s", StringType)(), AttributeReference("d", DecimalType(2, 1))(), AttributeReference("f", FloatType)(), AttributeReference("l", LongType)()) + val thirdTable = LocalRelation( + AttributeReference("m", StringType)(), + AttributeReference("n", DecimalType.SYSTEM_DEFAULT)(), + AttributeReference("p", ByteType)(), + AttributeReference("q", DoubleType)()) val wt = HiveTypeCoercion.WidenSetOperationTypes val expectedTypes = Seq(StringType, DecimalType.SYSTEM_DEFAULT, FloatType, DoubleType) - val r1 = wt(Union(left, right)).asInstanceOf[Union] - val r2 = wt(Except(left, 
right)).asInstanceOf[Except] - val r3 = wt(Intersect(left, right)).asInstanceOf[Intersect] + val r1 = wt(Union(firstTable :: secondTable :: thirdTable :: Nil)).asInstanceOf[Union] + val r2 = wt(Except(firstTable, secondTable)).asInstanceOf[Except] + val r3 = wt(Intersect(firstTable, secondTable)).asInstanceOf[Intersect] checkOutput(r1.children.head, expectedTypes) + checkOutput(r1.children(1), expectedTypes) checkOutput(r1.children.last, expectedTypes) checkOutput(r2.left, expectedTypes) checkOutput(r2.right, expectedTypes) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 965eaa9efec41..0878e3a0adc06 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -33,6 +33,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.optimizer.CombineUnions import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, ScalaReflection, SqlParser} @@ -1005,7 +1006,7 @@ class DataFrame private[sql]( * @since 1.3.0 */ def unionAll(other: DataFrame): DataFrame = withPlan { - Union(logicalPlan, other.logicalPlan) + CombineUnions(Union(logicalPlan, other.logicalPlan)) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index c76836c6bc25e..d78e2fa542a1e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -19,13 +19,14 @@ package org.apache.spark.sql import scala.collection.JavaConverters._ -import org.apache.spark.Logging import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.function._ +import org.apache.spark.Logging import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.analysis.UnresolvedAlias import org.apache.spark.sql.catalyst.encoders._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.analysis.UnresolvedAlias +import org.apache.spark.sql.catalyst.optimizer.CombineUnions import org.apache.spark.sql.catalyst.plans.JoinType import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.{Queryable, QueryExecution} @@ -609,7 +610,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def union(other: Dataset[T]): Dataset[T] = withPlan[T](other){ (left, right) => - Union(left, right) + CombineUnions(Union(left, right)) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 29ee89f46c78a..04e6ea736a36d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -102,7 +102,8 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { val unionDF = testData.unionAll(testData).unionAll(testData) .unionAll(testData).unionAll(testData) - assert(unionDF.queryExecution.optimizedPlan.collect { + // Before optimizer, Union should be combined. 
+ assert(unionDF.queryExecution.analyzed.collect { case j @ Union(Seq(_, _, _, _, _)) => j }.size === 1) checkAnswer( From 7a54c8f6c22187f1bf2202937bc623c57f665bb1 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 5 Jan 2016 22:11:00 -0800 Subject: [PATCH 09/27] Change the optimizer rule for pushing Filter and Project through new Union. --- .../sql/catalyst/optimizer/Optimizer.scala | 52 +++++++++++++++---- .../sql/catalyst/planning/patterns.scala | 1 - .../optimizer/SetOperationPushDownSuite.scala | 20 ++++--- .../apache/spark/sql/JavaDatasetSuite.java | 4 +- 4 files changed, 56 insertions(+), 21 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index c5f6d48b5dcf0..e1b6d5d178424 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -118,11 +118,9 @@ object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper { /** * Maps Attributes from the left side to the corresponding Attribute on the right side. */ - private def buildRewrites(bn: BinaryNode): AttributeMap[Attribute] = { - assert(bn.isInstanceOf[Intersect] || bn.isInstanceOf[Except]) - assert(bn.left.output.size == bn.right.output.size) - - AttributeMap(bn.left.output.zip(bn.right.output)) + private def buildRewrites(left: LogicalPlan, right: LogicalPlan): AttributeMap[Attribute] = { + assert(left.output.size == right.output.size) + AttributeMap(left.output.zip(right.output)) } /** @@ -157,10 +155,44 @@ object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper { def apply(plan: LogicalPlan): LogicalPlan = plan transform { + // Push down deterministic projection through UNION ALL + case p @ Project(projectList, Union(children)) => + assert(children.nonEmpty) + if (projectList.forall(_.deterministic)) { + val newFirstChild = Project(projectList, children.head) + val newOtherChildren = children.tail.map { + child => { + val rewrites = buildRewrites(children.head, child) + Project(projectList.map(pushToRight(_, rewrites)), child) + } + } + Union( + newFirstChild +: newOtherChildren) + } else { + p + } + + // Push down filter into union + case Filter(condition, Union(children)) => + assert(children.nonEmpty) + val (deterministic, nondeterministic) = partitionByDeterministic(condition) + val newFirstChild = Filter(deterministic, children.head) + val newOtherChildren = children.tail.map { + child => { + val rewrites = buildRewrites(children.head, child) + Filter(pushToRight(deterministic, rewrites), child) + } + } + Filter(nondeterministic, + Union( + newFirstChild +: newOtherChildren + ) + ) + // Push down filter through INTERSECT - case Filter(condition, i @ Intersect(left, right)) => + case Filter(condition, Intersect(left, right)) => val (deterministic, nondeterministic) = partitionByDeterministic(condition) - val rewrites = buildRewrites(i) + val rewrites = buildRewrites(left, right) Filter(nondeterministic, Intersect( Filter(deterministic, left), @@ -169,9 +201,9 @@ object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper { ) // Push down filter through EXCEPT - case Filter(condition, e @ Except(left, right)) => + case Filter(condition, Except(left, right)) => val (deterministic, nondeterministic) = partitionByDeterministic(condition) - val rewrites = buildRewrites(e) + val rewrites = buildRewrites(left, 
right) Filter(nondeterministic, Except( Filter(deterministic, left), @@ -581,7 +613,7 @@ object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper { */ object CombineUnions extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case u @ Unions(children) => Union(children) + case Unions(children) => Union(children) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 5d1f3ac4cad5f..f77bce368e34c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -23,7 +23,6 @@ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import scala.annotation.tailrec -import scala.collection.immutable.Queue /** * A pattern that matches any number of project or filter operations on top of another relational diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala index c0b5c097848d9..3e8bb662625bd 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala @@ -30,15 +30,15 @@ class SetOperationPushDownSuite extends PlanTest { Batch("Subqueries", Once, EliminateSubQueries) :: Batch("Union Pushdown", Once, + CombineUnions, SetOperationPushDown, - SimplifyFilters) :: - Batch("Unions", Once, - CombineUnions) :: Nil + SimplifyFilters) :: Nil } val testRelation = LocalRelation('a.int, 'b.int, 'c.int) val testRelation2 = LocalRelation('d.int, 'e.int, 'f.int) - val testUnion = Union(testRelation, testRelation2) + val testRelation3 = LocalRelation('g.int, 'h.int, 'i.int) + val testUnion = Union(testRelation :: testRelation2 :: testRelation3 :: Nil) val testIntersect = Intersect(testRelation, testRelation2) val testExcept = Except(testRelation, testRelation2) @@ -73,20 +73,24 @@ class SetOperationPushDownSuite extends PlanTest { comparePlans(exceptOptimized, exceptCorrectAnswer) } - ignore("union: filter to each side") { + test("union: filter to each side") { val unionQuery = testUnion.where('a === 1) val unionOptimized = Optimize.execute(unionQuery.analyze) val unionCorrectAnswer = - Union(testRelation.where('a === 1), testRelation2.where('d === 1)).analyze + Union(testRelation.where('a === 1) :: + testRelation2.where('d === 1) :: + testRelation3.where('g === 1) :: Nil).analyze comparePlans(unionOptimized, unionCorrectAnswer) } - ignore("union: project to each side") { + test("union: project to each side") { val unionQuery = testUnion.select('a) val unionOptimized = Optimize.execute(unionQuery.analyze) val unionCorrectAnswer = - Union(testRelation.select('a), testRelation2.select('d)).analyze + Union(testRelation.select('a) :: + testRelation2.select('d) :: + testRelation3.select('g) :: Nil).analyze comparePlans(unionOptimized, unionCorrectAnswer) } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index 9f8db39e33d7e..5689e841daa74 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ 
b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -300,9 +300,9 @@ public void testSetOperation() { Dataset intersected = ds.intersect(ds2); Assert.assertEquals(Arrays.asList("xyz"), intersected.collectAsList()); - Dataset unioned = ds.union(ds2); + Dataset unioned = ds.union(ds2).union(ds); Assert.assertEquals( - Arrays.asList("abc", "abc", "foo", "foo", "xyz", "xyz"), + Arrays.asList("abc", "abc", "abc", "abc", "foo", "foo", "xyz", "xyz", "xyz"), sort(unioned.collectAsList().toArray(new String[0]))); Dataset subtracted = ds.subtract(ds2); From 95e234901c91cde3e5c3820055bcc0d913e80e7d Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Wed, 6 Jan 2016 11:03:57 -0800 Subject: [PATCH 10/27] refactored WidenSetOperationTypes and added test cases --- .../catalyst/analysis/HiveTypeCoercion.scala | 76 ++++++++----------- .../analysis/HiveTypeCoercionSuite.scala | 74 ++++++++++++++---- 2 files changed, 92 insertions(+), 58 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index b5f8d5506abfc..deabb2cb858e5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.types._ /** - * A collection of [[Rule Rules]] that can be used to coerce differing types that + * A collection of [[Rule]] that can be used to coerce differing types that * participate in operations into compatible ones. Most of these rules are based on Hive semantics, * but they do not introduce any dependencies on the hive codebase. For this reason they remain in * Catalyst until we have a more standard set of coercions. 
@@ -200,50 +200,28 @@ object HiveTypeCoercion { */ object WidenSetOperationTypes extends Rule[LogicalPlan] { - private[this] def widenOutputTypes( - planName: String, - left: LogicalPlan, - right: LogicalPlan): (LogicalPlan, LogicalPlan) = { - require(left.output.length == right.output.length) - - val castedTypes = left.output.zip(right.output).map { - case (lhs, rhs) if lhs.dataType != rhs.dataType => - findWiderTypeForTwo(lhs.dataType, rhs.dataType) - case other => None - } - - if (castedTypes.exists(_.isDefined)) { - (castOutput(left, castedTypes), castOutput(right, castedTypes)) - } else { - (left, right) - } - } - - private[this] def widenOutputTypes( + private def widenOutputTypes( planName: String, children: Seq[LogicalPlan]): Seq[LogicalPlan] = { require(children.forall(_.output.length == children.head.output.length)) - val castedTypes: Seq[Option[DataType]] = - children.foldLeft(children.head.output.map(a => Option(a.dataType))) { - case (currentOutputDataTypes, child) => { - currentOutputDataTypes.zip(child.output).map { - case (Some(dt), ar) if dt != ar.dataType => - findWiderTypeForTwo(dt, ar.dataType) - case (Some(dt), ar) if dt == ar.dataType => Option(dt) - case other => None - } - } + // Get a sequence of data types, each of which is the widest type of this specific attribute + // in all the children + val castedTypes: Seq[Option[DataType]] = { + val initialTypeSeq = children.head.output.map(a => Option(a.dataType)) + children.tail.foldLeft(initialTypeSeq) { (currentOutputDataTypes, child) => + // Find the wider type if the data type of this child do not match with + // the casted data types of the already processed children + getCastedTypes(currentOutputDataTypes, child.output) } - - if (castedTypes.exists(_.isDefined)) { - children.map(castOutput(_, castedTypes)) - } else { - children } + + // Add extra Project for type promotion if necessary + children.map(castOutput(_, castedTypes)) } - private[this] def castOutput( + // Add Project if the data types do not match + private def castOutput( plan: LogicalPlan, castedTypes: Seq[Option[DataType]]): LogicalPlan = { val casted = plan.output.zip(castedTypes).map { @@ -251,19 +229,31 @@ object HiveTypeCoercion { Alias(Cast(e, dt), e.name)() case (e, _) => e } - Project(casted, plan) + if (casted.exists(_.isInstanceOf[Alias])) Project(casted, plan) else plan + } + + private def getCastedTypes( + typeSeq: Seq[Option[DataType]], + attrSeq: Seq[Attribute]): Seq[Option[DataType]] = { + typeSeq.zip(attrSeq).map { + case (Some(dt), ar) if dt != ar.dataType => + findWiderTypeForTwo(dt, ar.dataType) + case (Some(dt), ar) if dt == ar.dataType => Option(dt) + case other => None + } } def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case p if p.analyzed => p - case s @ SetOperation(left, right) if s.childrenResolved - && left.output.length == right.output.length && !s.resolved => - val (newLeft, newRight) = widenOutputTypes(s.nodeName, left, right) - s.makeCopy(Array(newLeft, newRight)) + case s @ SetOperation(left, right) if s.childrenResolved && + left.output.length == right.output.length && !s.resolved => + val newChildren: Seq[LogicalPlan] = widenOutputTypes(s.nodeName, left :: right :: Nil) + assert(newChildren.length == 2) + s.makeCopy(Array(newChildren.head, newChildren.last)) case s: Union if s.childrenResolved && - s.children.forall(_.output.length == s.children.head.output.length) && !s.resolved => + s.children.forall(_.output.length == s.children.head.output.length) && !s.resolved => val newChildren: 
Seq[LogicalPlan] = widenOutputTypes(s.nodeName, s.children) s.makeCopy(Array(newChildren)) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala index 5ee556345232d..a99c86f001f35 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala @@ -362,13 +362,50 @@ class HiveTypeCoercionSuite extends PlanTest { ) } - test("WidenSetOperationTypes for union except and intersect") { - def checkOutput(logical: LogicalPlan, expectTypes: Seq[DataType]): Unit = { - logical.output.zip(expectTypes).foreach { case (attr, dt) => - assert(attr.dataType === dt) - } + private def checkOutput(logical: LogicalPlan, expectTypes: Seq[DataType]): Unit = { + logical.output.zip(expectTypes).foreach { case (attr, dt) => + assert(attr.dataType === dt) } + } + test("WidenSetOperationTypes for except and intersect") { + val firstTable = LocalRelation( + AttributeReference("i", IntegerType)(), + AttributeReference("u", DecimalType.SYSTEM_DEFAULT)(), + AttributeReference("b", ByteType)(), + AttributeReference("d", DoubleType)()) + val secondTable = LocalRelation( + AttributeReference("s", StringType)(), + AttributeReference("d", DecimalType(2, 1))(), + AttributeReference("f", FloatType)(), + AttributeReference("l", LongType)()) + + val wt = HiveTypeCoercion.WidenSetOperationTypes + val expectedTypes = Seq(StringType, DecimalType.SYSTEM_DEFAULT, FloatType, DoubleType) + + val r1 = wt(Except(firstTable, secondTable)).asInstanceOf[Except] + val r2 = wt(Intersect(firstTable, secondTable)).asInstanceOf[Intersect] + checkOutput(r1.left, expectedTypes) + checkOutput(r1.right, expectedTypes) + checkOutput(r2.left, expectedTypes) + checkOutput(r2.right, expectedTypes) + + // Check if a Project is added + assert(r1.left.isInstanceOf[Project]) + assert(r1.right.isInstanceOf[Project]) + assert(r2.left.isInstanceOf[Project]) + assert(r2.right.isInstanceOf[Project]) + + val r3 = wt(Except(firstTable, firstTable)).asInstanceOf[Except] + checkOutput(r3.left, Seq(IntegerType, DecimalType.SYSTEM_DEFAULT, ByteType, DoubleType)) + checkOutput(r3.right, Seq(IntegerType, DecimalType.SYSTEM_DEFAULT, ByteType, DoubleType)) + + // Check if no Project is added + assert(r3.left.isInstanceOf[LocalRelation]) + assert(r3.right.isInstanceOf[LocalRelation]) + } + + test("WidenSetOperationTypes for union") { val firstTable = LocalRelation( AttributeReference("i", IntegerType)(), AttributeReference("u", DecimalType.SYSTEM_DEFAULT)(), @@ -380,6 +417,11 @@ class HiveTypeCoercionSuite extends PlanTest { AttributeReference("f", FloatType)(), AttributeReference("l", LongType)()) val thirdTable = LocalRelation( + AttributeReference("m", StringType)(), + AttributeReference("n", DecimalType.SYSTEM_DEFAULT)(), + AttributeReference("p", FloatType)(), + AttributeReference("q", DoubleType)()) + val forthTable = LocalRelation( AttributeReference("m", StringType)(), AttributeReference("n", DecimalType.SYSTEM_DEFAULT)(), AttributeReference("p", ByteType)(), @@ -388,16 +430,18 @@ class HiveTypeCoercionSuite extends PlanTest { val wt = HiveTypeCoercion.WidenSetOperationTypes val expectedTypes = Seq(StringType, DecimalType.SYSTEM_DEFAULT, FloatType, DoubleType) - val r1 = wt(Union(firstTable :: secondTable :: thirdTable :: Nil)).asInstanceOf[Union] - val r2 = 
wt(Except(firstTable, secondTable)).asInstanceOf[Except] - val r3 = wt(Intersect(firstTable, secondTable)).asInstanceOf[Intersect] - checkOutput(r1.children.head, expectedTypes) - checkOutput(r1.children(1), expectedTypes) - checkOutput(r1.children.last, expectedTypes) - checkOutput(r2.left, expectedTypes) - checkOutput(r2.right, expectedTypes) - checkOutput(r3.left, expectedTypes) - checkOutput(r3.right, expectedTypes) + val unionRelation = wt( + Union(firstTable :: secondTable :: thirdTable :: forthTable :: Nil)).asInstanceOf[Union] + assert(unionRelation.children.length == 4) + checkOutput(unionRelation.children.head, expectedTypes) + checkOutput(unionRelation.children(1), expectedTypes) + checkOutput(unionRelation.children(2), expectedTypes) + checkOutput(unionRelation.children(3), expectedTypes) + + assert(unionRelation.children.head.isInstanceOf[Project]) + assert(unionRelation.children(1).isInstanceOf[Project]) + assert(unionRelation.children(2).isInstanceOf[LocalRelation]) + assert(unionRelation.children(3).isInstanceOf[Project]) } test("Transform Decimal precision/scale for union except and intersect") { From 6a6003e9fb97e610548295f3dc23cc23373f1e78 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Wed, 6 Jan 2016 11:04:14 -0800 Subject: [PATCH 11/27] addressed comments. --- .../sql/catalyst/optimizer/Optimizer.scala | 27 +++++++++---------- 1 file changed, 12 insertions(+), 15 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index e1b6d5d178424..8696a12ee30dd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -37,11 +37,15 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] { // SubQueries are only needed for analysis and can be removed before execution. Batch("Remove SubQueries", FixedPoint(100), EliminateSubQueries) :: + // - Do the first call of CombineUnions before starting the major Optimizer rules, + // since they could add/move extra operators between two adjacent Union operators. + // - Call CombineUnions again in Batch("Operator Optimizations"), + // since the other rules might make two separate Unions operators adjacent. 
+ Batch("Union", FixedPoint(100), + CombineUnions) :: Batch("Aggregate", FixedPoint(100), ReplaceDistinctWithAggregate, RemoveLiteralFromGroupExpressions) :: - Batch("Union", FixedPoint(100), - CombineUnions) :: Batch("Operator Optimizations", FixedPoint(100), // Operator push down SetOperationPushDown, @@ -160,14 +164,11 @@ object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper { assert(children.nonEmpty) if (projectList.forall(_.deterministic)) { val newFirstChild = Project(projectList, children.head) - val newOtherChildren = children.tail.map { - child => { - val rewrites = buildRewrites(children.head, child) - Project(projectList.map(pushToRight(_, rewrites)), child) - } - } - Union( - newFirstChild +: newOtherChildren) + val newOtherChildren = children.tail.map ( child => { + val rewrites = buildRewrites(children.head, child) + Project(projectList.map(pushToRight(_, rewrites)), child) + } ) + Union(newFirstChild +: newOtherChildren) } else { p } @@ -183,11 +184,7 @@ object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper { Filter(pushToRight(deterministic, rewrites), child) } } - Filter(nondeterministic, - Union( - newFirstChild +: newOtherChildren - ) - ) + Filter(nondeterministic, Union(newFirstChild +: newOtherChildren)) // Push down filter through INTERSECT case Filter(condition, Intersect(left, right)) => From 15ec058e00cd40140674cedb75802bc953dde26c Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Wed, 6 Jan 2016 11:19:27 -0800 Subject: [PATCH 12/27] replace list by arrayBuffer in combineUnions --- .../spark/sql/catalyst/planning/patterns.scala | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index f77bce368e34c..106eaa764c11c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import scala.annotation.tailrec +import scala.collection.mutable.ArrayBuffer /** * A pattern that matches any number of project or filter operations on top of another relational @@ -178,16 +179,21 @@ object ExtractFiltersAndInnerJoins extends PredicateHelper { */ object Unions { def unapply(plan: LogicalPlan): Option[Seq[LogicalPlan]] = plan match { - case u: Union => Some(collectUnionChildren(u :: Nil, Seq.empty[LogicalPlan])) + case u: Union => Some(collectUnionChildren(ArrayBuffer(u), Seq.empty[LogicalPlan])) case _ => None } @tailrec private def collectUnionChildren( - plan: List[LogicalPlan], - children: Seq[LogicalPlan]): Seq[LogicalPlan] = plan match { - case Nil => children - case Union(grandchildren) :: res => collectUnionChildren(grandchildren.toList ++ res, children) - case other :: res => collectUnionChildren(res, children :+ other) + plan: ArrayBuffer[LogicalPlan], + children: Seq[LogicalPlan]): Seq[LogicalPlan] = { + if (plan.isEmpty) children + else { + plan.head match { + case Union(grandchildren) => + collectUnionChildren(grandchildren.to[ArrayBuffer] ++ plan.tail, children) + case other => collectUnionChildren(plan.tail, children :+ other) + } + } } } From 5e06647708a3e47dab4be5c0074fb9712cbe3cb6 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Wed, 6 Jan 2016 11:51:27 -0800 Subject: [PATCH 13/27] address 
comments. --- .../plans/logical/basicOperators.scala | 20 +++++++++++++------ .../org/apache/spark/sql/hive/HiveQl.scala | 9 ++------- 2 files changed, 16 insertions(+), 13 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 7d3d9a244f380..c3db896083404 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -115,7 +115,7 @@ case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(le override def output: Seq[Attribute] = left.output } -/** Factory for constructing new `Unions` nodes. */ +/** Factory for constructing new `Union` nodes. */ object Union { def apply(left: LogicalPlan, right: LogicalPlan): Union = { Union (left :: right :: Nil) @@ -125,6 +125,7 @@ object Union { case class Union(children: Seq[LogicalPlan]) extends LogicalPlan { override def output: Seq[Attribute] = { + // updating nullability to make all the children consistent children.tail.foldLeft(children.head.output) { case (currentOutput, child) => currentOutput.zip(child.output).map { case (a1, a2) => a1.withNullability(a1.nullable || a2.nullable) @@ -132,11 +133,18 @@ case class Union(children: Seq[LogicalPlan]) extends LogicalPlan { } } - override lazy val resolved: Boolean = - childrenResolved && - children.forall(_.output.length == children.head.output.length) && - children.forall(_.output.zip(children.head.output).forall { - case (l, r) => l.dataType == r.dataType }) + override lazy val resolved: Boolean = { + val allChildrenCompatible: Boolean = + children.tail.forall( child => + // compare the attribute number with the first child + child.output.length == children.head.output.length && + // compare the data types with the first child + child.output.zip(children.head.output).forall { + case (l, r) => l.dataType == r.dataType } + ) + + childrenResolved && allChildrenCompatible + } override def statistics: Statistics = { val sizeInBytes = children.map(_.statistics.sizeInBytes).sum diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 51afa45120f98..8cd8e61b6c386 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -1211,13 +1211,8 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C overwrite) } - // If there are multiple INSERTS just UNION them together into on query. - val query = - if (queries.length == 1) { - queries.head - } else { - Union(queries) - } + // If there are multiple INSERTS just UNION them together into one query. 
+ val query = if (queries.length == 1) queries.head else Union(queries) // return With plan if there is CTE cteRelations.map(With(query, _)).getOrElse(query) From 2229932d18d9debe41180b2e3018a7ff8f2a4fc0 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Wed, 6 Jan 2016 12:07:46 -0800 Subject: [PATCH 14/27] move changes in HiveQl.scala to CatalystQl.scala --- .../main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala index 42bdf25b61ea5..fd1b96d136031 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala @@ -393,8 +393,8 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C overwrite) } - // If there are multiple INSERTS just UNION them together into on query. - val query = queries.reduceLeft(Union) + // If there are multiple INSERTS just UNION them together into one query. + val query = if (queries.length == 1) queries.head else Union(queries) // return With plan if there is CTE cteRelations.map(With(query, _)).getOrElse(query) From b3327b1bea5c60d8b0e94354692b94037976d350 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Wed, 6 Jan 2016 13:22:11 -0800 Subject: [PATCH 15/27] add lazy. --- .../spark/sql/catalyst/plans/logical/basicOperators.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index c3db896083404..e7ab5fb6e66fa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -134,7 +134,7 @@ case class Union(children: Seq[LogicalPlan]) extends LogicalPlan { } override lazy val resolved: Boolean = { - val allChildrenCompatible: Boolean = + lazy val allChildrenCompatible: Boolean = children.tail.forall( child => // compare the attribute number with the first child child.output.length == children.head.output.length && // compare the data types with the first child child.output.zip(children.head.output).forall { case (l, r) => l.dataType == r.dataType } ) childrenResolved && allChildrenCompatible } From 723c0da5036aaad109bfafd523e6e9762e6c0c21 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Thu, 7 Jan 2016 09:21:54 -0800 Subject: [PATCH 16/27] resolve comments.
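One of the comments resolved here concerns Union.output: instead of folding pairwise over children, the per-child outputs are transposed so that each column's nullability is the OR of the nullabilities across all children. A minimal self-contained sketch follows; the `Attr` record is a hypothetical stand-in for Catalyst's `Attribute`:

    // Hypothetical stand-in for Attribute: just name, type and nullability.
    case class Attr(name: String, dataType: String, nullable: Boolean) {
      def withNullability(n: Boolean): Attr = copy(nullable = n)
    }

    object NullabilityDemo extends App {
      val child1 = Seq(Attr("a", "int", nullable = false), Attr("b", "string", nullable = false))
      val child2 = Seq(Attr("x", "int", nullable = true), Attr("y", "string", nullable = false))

      // Transpose so each inner Seq holds one column across all children,
      // then take the head attribute and OR the nullabilities together.
      val output = Seq(child1, child2).transpose.map(attrs =>
        attrs.head.withNullability(attrs.exists(_.nullable)))

      assert(output.map(_.nullable) == Seq(true, false))
    }

If any child can produce a null in some column, the unioned column has to be nullable, which is exactly what the exists over the transposed attributes expresses.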
--- .../spark/sql/catalyst/analysis/CheckAnalysis.scala | 2 +- .../spark/sql/catalyst/optimizer/Optimizer.scala | 3 ++- .../sql/catalyst/plans/logical/basicOperators.scala | 12 ++++-------- .../main/scala/org/apache/spark/sql/DataFrame.scala | 2 ++ .../main/scala/org/apache/spark/sql/Dataset.scala | 2 ++ .../apache/spark/sql/execution/basicOperators.scala | 11 ++++------- 6 files changed, 15 insertions(+), 17 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 22705ef81c765..711420083c84f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -190,7 +190,7 @@ trait CheckAnalysis { s"${right.output.length}") case s: Union if s.children.exists(_.output.length != s.children.head.output.length) => - s.children.filter(_.output.length != s.children.head.output.length).foreach { child => + s.children.filter(_.output.length != s.children.head.output.length).exists { child => failAnalysis( s""" |Unions can only be performed on tables with the same number of columns, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 8696a12ee30dd..21addea40e5d4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -38,7 +38,8 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] { Batch("Remove SubQueries", FixedPoint(100), EliminateSubQueries) :: // - Do the first call of CombineUnions before starting the major Optimizer rules, - // since they could add/move extra operators between two adjacent Union operators. + // since it can reduce the number of iteration and the other rules could add/move + // extra operators between two adjacent Union operators. // - Call CombineUnions again in Batch("Operator Optimizations"), // since the other rules might make two separate Unions operators adjacent. 
Batch("Union", FixedPoint(100), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index e7ab5fb6e66fa..7b7c57ab97c9b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -124,14 +124,10 @@ object Union { case class Union(children: Seq[LogicalPlan]) extends LogicalPlan { - override def output: Seq[Attribute] = { - // updating nullability to make all the children consistent - children.tail.foldLeft(children.head.output) { case (currentOutput, child) => - currentOutput.zip(child.output).map { case (a1, a2) => - a1.withNullability(a1.nullable || a2.nullable) - } - } - } + // updating nullability to make all the children consistent + override def output: Seq[Attribute] = + children.map(_.output).transpose.map(attrs => + attrs.head.withNullability(attrs.exists(_.nullable))) override lazy val resolved: Boolean = { lazy val allChildrenCompatible: Boolean = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index cb65d9f87af3f..55e737a9d6d9c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -1006,6 +1006,8 @@ class DataFrame private[sql]( * @since 1.3.0 */ def unionAll(other: DataFrame): DataFrame = withPlan { + // This breaks caching, but it's usually ok because it addresses a very specific use case: + // using union to union many files or partitions. CombineUnions(Union(logicalPlan, other.logicalPlan)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index d78e2fa542a1e..8fbf5d3fbf9fb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -610,6 +610,8 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def union(other: Dataset[T]): Dataset[T] = withPlan[T](other){ (left, right) => + // This breaks caching, but it's usually ok because it addresses a very specific use case: + // using union to union many files or partitions. CombineUnions(Union(left, right)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 95bef683238a7..7dcff80262e08 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -175,13 +175,10 @@ case class Range( * Union two plans, without a distinct. This is UNION ALL in SQL. 
*/ case class Union(children: Seq[SparkPlan]) extends SparkPlan { - override def output: Seq[Attribute] = { - children.tail.foldLeft(children.head.output) { case (currentOutput, child) => - currentOutput.zip(child.output).map { case (a1, a2) => - a1.withNullability(a1.nullable || a2.nullable) - } - } - } + override def output: Seq[Attribute] = + children.map(_.output).transpose.map(attrs => + attrs.head.withNullability(attrs.exists(_.nullable))) + protected override def doExecute(): RDD[InternalRow] = sparkContext.union(children.map(_.execute())) } From 42b81a8b32dc1a3347c77171f347ddeef7bfad51 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Thu, 7 Jan 2016 21:30:52 -0800 Subject: [PATCH 17/27] resolve comments. --- .../org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala | 3 +-- .../org/apache/spark/sql/catalyst/optimizer/Optimizer.scala | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 711420083c84f..77d7dc91dccb9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -194,8 +194,7 @@ trait CheckAnalysis { failAnalysis( s""" |Unions can only be performed on tables with the same number of columns, - | but the table '${child.simpleString}' has '${child.output.length}' columns - | and the first table '${s.children.head.simpleString}' has + | but one table has '${child.output.length}' columns and another table has | '${s.children.head.output.length}' columns""".stripMargin) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 21addea40e5d4..8240309cc4ade 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -42,7 +42,7 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] { // extra operators between two adjacent Union operators. // - Call CombineUnions again in Batch("Operator Optimizations"), // since the other rules might make two separate Unions operators adjacent. - Batch("Union", FixedPoint(100), + Batch("Union", Once, CombineUnions) :: Batch("Aggregate", FixedPoint(100), ReplaceDistinctWithAggregate, RemoveLiteralFromGroupExpressions) :: From ab732c140664ce4b22d0e98b00aaf5c812875089 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Fri, 8 Jan 2016 09:28:31 -0800 Subject: [PATCH 18/27] Remove the unneeded param.
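The planName parameter of widenOutputTypes was unused: the rule only needs to compute, for each column position, the widest type shared by all children. The per-column fold can be sketched independently of Catalyst; the toy type lattice below is a hypothetical stand-in for findWiderTypeForTwo:

    object WidenDemo extends App {
      // Toy lattice: int < long < double; equal types widen to themselves.
      def widerOfTwo(a: String, b: String): Option[String] = (a, b) match {
        case _ if a == b => Some(a)
        case ("int", "long") | ("long", "int") => Some("long")
        case ("int", "double") | ("double", "int") => Some("double")
        case ("long", "double") | ("double", "long") => Some("double")
        case _ => None // no common wider type
      }

      // For each column, fold the children's types down to one widest type.
      def widenColumns(children: Seq[Seq[String]]): Seq[Option[String]] =
        children.transpose.map { col =>
          col.tail.foldLeft(Option(col.head))((acc, t) => acc.flatMap(widerOfTwo(_, t)))
        }

      assert(widenColumns(Seq(Seq("int", "string"), Seq("long", "string")))
        == Seq(Some("long"), Some("string")))
    }

If some column has no common wider type, the rule leaves the children unchanged and the plan simply stays unresolved, so the type mismatch surfaces during analysis instead of being papered over.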
--- .../spark/sql/catalyst/analysis/HiveTypeCoercion.scala | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index deabb2cb858e5..e689c97a1e81f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -200,9 +200,7 @@ object HiveTypeCoercion { */ object WidenSetOperationTypes extends Rule[LogicalPlan] { - private def widenOutputTypes( - planName: String, - children: Seq[LogicalPlan]): Seq[LogicalPlan] = { + private def widenOutputTypes(children: Seq[LogicalPlan]): Seq[LogicalPlan] = { require(children.forall(_.output.length == children.head.output.length)) // Get a sequence of data types, each of which is the widest type of this specific attribute @@ -248,13 +246,13 @@ object HiveTypeCoercion { case s @ SetOperation(left, right) if s.childrenResolved && left.output.length == right.output.length && !s.resolved => - val newChildren: Seq[LogicalPlan] = widenOutputTypes(s.nodeName, left :: right :: Nil) + val newChildren: Seq[LogicalPlan] = widenOutputTypes(left :: right :: Nil) assert(newChildren.length == 2) s.makeCopy(Array(newChildren.head, newChildren.last)) case s: Union if s.childrenResolved && s.children.forall(_.output.length == s.children.head.output.length) && !s.resolved => - val newChildren: Seq[LogicalPlan] = widenOutputTypes(s.nodeName, s.children) + val newChildren: Seq[LogicalPlan] = widenOutputTypes(s.children) s.makeCopy(Array(newChildren)) } } From 031a5d8c85c8cd29aaed7411d915eaade57688ea Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Sat, 9 Jan 2016 00:17:03 -0800 Subject: [PATCH 19/27] changed the implementation of Union in sql generation --- .../scala/org/apache/spark/sql/hive/SQLBuilder.scala | 11 ++++++----- .../apache/spark/sql/hive/LogicalPlanToSQLSuite.scala | 4 ++++ 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala index 1c910051faccf..0f70d31e7d0fb 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala @@ -129,11 +129,12 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi conditionSQL = condition.sql } yield s"$childSQL $whereOrHaving $conditionSQL" - case Union(left, right) => - for { - leftSQL <- toSQL(left) - rightSQL <- toSQL(right) - } yield s"$leftSQL UNION ALL $rightSQL" + case Union(children) if children.length > 1 => + val unionStmt: StringBuffer = new StringBuffer(s"${toSQL(children.head).getOrElse("")}") + children.tail.map{ case child => + unionStmt.append(s" UNION ALL ${toSQL(child).getOrElse("")}") + } + Some(s"${unionStmt.toString}") // ParquetRelation converted from Hive metastore table case Subquery(alias, LogicalRelation(r: ParquetRelation, _)) => diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala index 0e81acf532a03..5a3d652c4c362 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala @@ 
-102,6 +102,10 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils { checkHiveQl("SELECT id FROM t0 UNION ALL SELECT CAST(id AS INT) AS id FROM t0") } + test("three-child union") { + checkHiveQl("SELECT id FROM t0 UNION ALL SELECT id FROM t0 UNION ALL SELECT id FROM t0") + } + test("case") { checkHiveQl("SELECT CASE WHEN id % 2 > 0 THEN 0 WHEN id % 2 = 0 THEN 1 END FROM t0") } From f3d23dc452afa5d6848cf4469751f33cf3036f57 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Sat, 9 Jan 2016 13:17:27 -0800 Subject: [PATCH 20/27] fixed the implementation of Union in sql generation --- .../scala/org/apache/spark/sql/hive/SQLBuilder.scala | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala index 0f70d31e7d0fb..4823e74e20f1c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala @@ -129,12 +129,14 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi conditionSQL = condition.sql } yield s"$childSQL $whereOrHaving $conditionSQL" + case Union(children) if children.length == 1 => + toSQL(children.head) + case Union(children) if children.length > 1 => - val unionStmt: StringBuffer = new StringBuffer(s"${toSQL(children.head).getOrElse("")}") - children.tail.map{ case child => - unionStmt.append(s" UNION ALL ${toSQL(child).getOrElse("")}") - } - Some(s"${unionStmt.toString}") + for { + leftSQL <- toSQL(children.head) + rightSQL <- toSQL(Union(children.tail)) + } yield s"$leftSQL UNION ALL $rightSQL" // ParquetRelation converted from Hive metastore table case Subquery(alias, LogicalRelation(r: ParquetRelation, _)) => From abfcf93cd6217c9cca1b4f9d0ac3b8b2a189563e Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Wed, 13 Jan 2016 22:59:48 -0800 Subject: [PATCH 21/27] address comments. --- .../sql/catalyst/analysis/Analyzer.scala | 12 ++++- .../sql/catalyst/analysis/CheckAnalysis.scala | 13 +++-- .../catalyst/analysis/HiveTypeCoercion.scala | 51 ++++++++++++------- .../sql/catalyst/optimizer/Optimizer.scala | 2 +- .../sql/catalyst/planning/patterns.scala | 19 +++---- .../plans/logical/basicOperators.scala | 3 +- .../sql/catalyst/analysis/AnalysisSuite.scala | 6 +++ .../analysis/DecimalPrecisionSuite.scala | 1 - ...ownSuite.scala => SetOperationSuite.scala} | 2 +- .../org/apache/spark/sql/DataFrameSuite.scala | 4 +- 10 files changed, 71 insertions(+), 42 deletions(-) rename sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/{SetOperationPushDownSuite.scala => SetOperationSuite.scala} (98%) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 8a33af8207350..4ee26450e918a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -66,7 +66,8 @@ class Analyzer( lazy val batches: Seq[Batch] = Seq( Batch("Substitution", fixedPoint, CTESubstitution, - WindowsSubstitution), + WindowsSubstitution, + EliminateUnions), Batch("Resolution", fixedPoint, ResolveRelations :: ResolveReferences :: @@ -1170,6 +1171,15 @@ object EliminateSubQueries extends Rule[LogicalPlan] { } } +/** + * Removes [[Union]] operators from the plan if it just has one child. 
+ */ +object EliminateUnions extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case Union(children) if children.size == 1 => children.head + } +} + /** * Cleans up unnecessary Aliases inside the plan. Basically we only need Alias as a top level * expression in Project(project list) or Aggregate(aggregate expressions) or diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 77d7dc91dccb9..f2e78d97442e3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -190,13 +190,12 @@ trait CheckAnalysis { s"${right.output.length}") case s: Union if s.children.exists(_.output.length != s.children.head.output.length) => - s.children.filter(_.output.length != s.children.head.output.length).exists { child => - failAnalysis( - s""" - |Unions can only be performed on tables with the same number of columns, - | but one table has '${child.output.length}' columns and another table has - | '${s.children.head.output.length}' columns""".stripMargin) - } + val firstError = s.children.find(_.output.length != s.children.head.output.length).get + failAnalysis( + s""" + |Unions can only be performed on tables with the same number of columns, + | but one table has '${firstError.output.length}' columns and another table has + | '${s.children.head.output.length}' columns""".stripMargin) case _ => // Fallbacks to the following checks } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 5db4afba6ef54..061991be84ed4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -19,6 +19,9 @@ package org.apache.spark.sql.catalyst.analysis import javax.annotation.Nullable +import scala.annotation.tailrec +import scala.collection.mutable + import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical._ @@ -205,39 +208,49 @@ object HiveTypeCoercion { // Get a sequence of data types, each of which is the widest type of this specific attribute // in all the children - val castedTypes: Seq[Option[DataType]] = { - val initialTypeSeq = children.head.output.map(a => Option(a.dataType)) - children.tail.foldLeft(initialTypeSeq) { (currentOutputDataTypes, child) => - // Find the wider type if the data type of this child do not match with - // the casted data types of the already processed children - getCastedTypes(currentOutputDataTypes, child.output) - } - } + val castedTypes: Seq[DataType] = + getCastedTypes(children, attrIndex = 0, mutable.Queue[DataType]()) // Add extra Project for type promotion if necessary - children.map(castOutput(_, castedTypes)) + if (castedTypes.isEmpty) children else children.map(castOutput(_, castedTypes)) } // Add Project if the data types do not match private def castOutput( plan: LogicalPlan, - castedTypes: Seq[Option[DataType]]): LogicalPlan = { + castedTypes: Seq[DataType]): LogicalPlan = { val casted = plan.output.zip(castedTypes).map { - case (e, Some(dt)) if e.dataType != dt => + case (e, dt) if 
e.dataType != dt => Alias(Cast(e, dt), e.name)() case (e, _) => e } if (casted.exists(_.isInstanceOf[Alias])) Project(casted, plan) else plan } - private def getCastedTypes( - typeSeq: Seq[Option[DataType]], - attrSeq: Seq[Attribute]): Seq[Option[DataType]] = { - typeSeq.zip(attrSeq).map { - case (Some(dt), ar) if dt != ar.dataType => - findWiderTypeForTwo(dt, ar.dataType) - case (Some(dt), ar) if dt == ar.dataType => Option(dt) - case other => None + // Get the widest type for each attribute in all the children + @tailrec private def getCastedTypes( + children: Seq[LogicalPlan], + attrIndex: Int, + castedTypes: mutable.Queue[DataType]): Seq[DataType] = { + // Return the result after the widen data types have been found for all the children + if (attrIndex >= children.head.output.length) return castedTypes.toSeq + + // For the attrIndex-th attribute, find the widest type + val initialType = Option(children.head.output(attrIndex).dataType) + children.foldLeft(initialType) { (currentOutputDataTypes, child) => + (currentOutputDataTypes, child.output(attrIndex).dataType) match { + case (Some(dt1), dt2) if dt1 != dt2 => + findWiderTypeForTwo(dt1, dt2) + case (Some(dt1), dt2) if dt1 == dt2 => Option(dt1) + case other => None + } + } match { + // If unable to find an appropriate widen type for this column, return an empty Seq + case None => Seq.empty[DataType] + // Otherwise, record the result in the queue and find the type for the next column + case Some(widenType) => + castedTypes.enqueue(widenType) + getCastedTypes (children, attrIndex + 1, castedTypes) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 651e4bb8e1345..dd00c337ecd81 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.analysis.{CleanupAliases, EliminateSubQueri import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.planning.{Unions, ExtractFiltersAndInnerJoins} +import org.apache.spark.sql.catalyst.planning.{ExtractFiltersAndInnerJoins, Unions} import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, LeftOuter, LeftSemi, RightOuter} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 106eaa764c11c..68bbdca42ac17 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -17,14 +17,14 @@ package org.apache.spark.sql.catalyst.planning +import scala.annotation.tailrec +import scala.collection.mutable + import org.apache.spark.Logging import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ -import scala.annotation.tailrec -import scala.collection.mutable.ArrayBuffer - /** * A pattern that matches any number of project or filter operations on top of another relational * operator. 
 * operator. All filter operators are collected and their conditions are broken up and returned
@@ -179,20 +179,21 @@ object ExtractFiltersAndInnerJoins extends PredicateHelper {
  */
 object Unions {
   def unapply(plan: LogicalPlan): Option[Seq[LogicalPlan]] = plan match {
-    case u: Union => Some(collectUnionChildren(ArrayBuffer(u), Seq.empty[LogicalPlan]))
+    case u: Union => Some(collectUnionChildren(mutable.Stack(u), Seq.empty[LogicalPlan]))
     case _ => None
   }
 
   @tailrec
   private def collectUnionChildren(
-      plan: ArrayBuffer[LogicalPlan],
+      plans: mutable.Stack[LogicalPlan],
       children: Seq[LogicalPlan]): Seq[LogicalPlan] = {
-    if (plan.isEmpty) children
+    if (plans.isEmpty) children
     else {
-      plan.head match {
+      plans.pop match {
         case Union(grandchildren) =>
-          collectUnionChildren(grandchildren.to[ArrayBuffer] ++ plan.tail, children)
-        case other => collectUnionChildren(plan.tail, children :+ other)
+          grandchildren.reverseMap(plans.push(_))
+          collectUnionChildren(plans, children)
+        case other => collectUnionChildren(plans, children :+ other)
       }
     }
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 6d15b649d0834..730d3aa7aefcc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -130,6 +130,7 @@ case class Union(children: Seq[LogicalPlan]) extends LogicalPlan {
       attrs.head.withNullability(attrs.exists(_.nullable)))
 
   override lazy val resolved: Boolean = {
+    // allChildrenCompatible needs to be evaluated after childrenResolved
     lazy val allChildrenCompatible: Boolean =
       children.tail.forall( child =>
         // compare the attribute number with the first child
@@ -139,7 +140,7 @@ case class Union(children: Seq[LogicalPlan]) extends LogicalPlan {
         case (l, r) => l.dataType == r.dataType
       }
     )
-    childrenResolved && allChildrenCompatible
+    children.length > 1 && childrenResolved && allChildrenCompatible
   }
 
   override def statistics: Statistics = {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index cf84855885a37..a569cb8b7e4b8 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -237,6 +237,12 @@ class AnalysisSuite extends AnalysisTest {
     checkAnalysis(plan, expected)
   }
 
+  test("Eliminate the unnecessary union") {
+    val plan = Union(testRelation :: Nil)
+    val expected = testRelation
+    checkAnalysis(plan, expected)
+  }
+
   test("SPARK-12102: Ignore nullablity when comparing two sides of case") {
     val relation = LocalRelation('a.struct('x.int), 'b.struct('x.int.withNullability(false)))
     val plan = relation.select(CaseWhen(Seq(Literal(true), 'a, 'b)).as("val"))
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala
index 0ca19f6189d94..24c608eaa5b39 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala
@@ -24,7 +24,6 @@ import org.apache.spark.sql.catalyst.{SimpleCatalystConf, TableIdentifier}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project, Union}
-import org.apache.spark.sql.catalyst.{SimpleCatalystConf, TableIdentifier}
 import org.apache.spark.sql.types._
 
 class DecimalPrecisionSuite extends SparkFunSuite with BeforeAndAfter {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala
similarity index 98%
rename from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala
rename to sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala
index 3e8bb662625bd..2283f7c008ba2 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.plans.PlanTest
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
 
-class SetOperationPushDownSuite extends PlanTest {
+class SetOperationSuite extends PlanTest {
   object Optimize extends RuleExecutor[LogicalPlan] {
     val batches =
       Batch("Subqueries", Once,
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index dc6ad27472f2e..6a7709f178e26 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -25,7 +25,7 @@ import scala.util.Random
 import org.scalatest.Matchers._
 
 import org.apache.spark.SparkException
-import org.apache.spark.sql.catalyst.plans.logical.{Union, OneRowRelation}
+import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Union}
 import org.apache.spark.sql.execution.Exchange
 import org.apache.spark.sql.execution.aggregate.TungstenAggregate
 import org.apache.spark.sql.functions._
@@ -104,7 +104,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
 
     // Before optimizer, Union should be combined.
     assert(unionDF.queryExecution.analyzed.collect {
-      case j @ Union(Seq(_, _, _, _, _)) => j }.size === 1)
+      case j: Union if j.children.size == 5 => j }.size === 1)
 
     checkAnswer(
       unionDF.agg(avg('key), max('key), min('key), sum('key)),

From 3b13ddf859ddcc36a7823ff364250ff4089a8ac2 Mon Sep 17 00:00:00 2001
From: gatorsmile
Date: Fri, 15 Jan 2016 14:16:04 -0800
Subject: [PATCH 22/27] address comments.

---
 .../spark/sql/catalyst/analysis/HiveTypeCoercion.scala | 10 +---------
 .../sql/catalyst/plans/logical/basicOperators.scala    |  2 +-
 .../scala/org/apache/spark/sql/DataFrameSuite.scala    |  2 +-
 .../scala/org/apache/spark/sql/hive/SQLBuilder.scala   |  7 ++++---
 4 files changed, 7 insertions(+), 14 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 1dd70fd721cfe..5f725b08a700c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -236,15 +236,7 @@ object HiveTypeCoercion {
       if (attrIndex >= children.head.output.length) return castedTypes.toSeq
 
       // For the attrIndex-th attribute, find the widest type
-      val initialType = Option(children.head.output(attrIndex).dataType)
-      children.foldLeft(initialType) { (currentOutputDataTypes, child) =>
-        (currentOutputDataTypes, child.output(attrIndex).dataType) match {
-          case (Some(dt1), dt2) if dt1 != dt2 =>
-            findWiderTypeForTwo(dt1, dt2)
-          case (Some(dt1), dt2) if dt1 == dt2 => Option(dt1)
-          case other => None
-        }
-      } match {
+      findWiderCommonType(children.map(_.output(attrIndex).dataType)) match {
         // If unable to find an appropriate widen type for this column, return an empty Seq
         case None => Seq.empty[DataType]
         // Otherwise, record the result in the queue and find the type for the next column
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 730d3aa7aefcc..88b784695d693 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -131,7 +131,7 @@ case class Union(children: Seq[LogicalPlan]) extends LogicalPlan {
 
   override lazy val resolved: Boolean = {
     // allChildrenCompatible needs to be evaluated after childrenResolved
-    lazy val allChildrenCompatible: Boolean =
+    def allChildrenCompatible: Boolean =
       children.tail.forall( child =>
         // compare the attribute number with the first child
         child.output.length == children.head.output.length &&
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index ccb4097c17c72..741010bf3e653 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -98,7 +98,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
       testData.collect().toSeq)
   }
 
-  test ("union all") {
+  test("union all") {
     val unionDF = testData.unionAll(testData).unionAll(testData)
       .unionAll(testData).unionAll(testData)
 
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala
index d490fd4920792..ce009abeb7306 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala
@@ -129,15 +129,16 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi
       conditionSQL = condition.sql
     } yield s"$childSQL $whereOrHaving $conditionSQL"
 
-    case Union(children) if children.length == 1 =>
-      toSQL(children.head)
-
     case Union(children) if children.length > 1 =>
       for {
         leftSQL <- toSQL(children.head)
+        // When children.tail only has one child, we will go to the next case to get rid of Union.
         rightSQL <- toSQL(Union(children.tail))
       } yield s"$leftSQL UNION ALL $rightSQL"
 
+    case Union(children) if children.length == 1 =>
+      toSQL(children.head)
+
     // Persisted data source relation
     case Subquery(alias, LogicalRelation(_, _, Some(TableIdentifier(table, Some(database))))) =>
       Some(s"`$database`.`$table`")

From b88bdeb70e3f9a24f1b9136801ebf59a61093dd2 Mon Sep 17 00:00:00 2001
From: gatorsmile
Date: Fri, 15 Jan 2016 14:20:42 -0800
Subject: [PATCH 23/27] added a comment.

---
 .../scala/org/apache/spark/sql/catalyst/planning/patterns.scala | 1 +
 1 file changed, 1 insertion(+)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index 68bbdca42ac17..f0ee124e88a9f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -183,6 +183,7 @@ object Unions {
     case _ => None
   }
 
+  // Doing a depth-first tree traversal to combine all the union children.
   @tailrec
   private def collectUnionChildren(
       plans: mutable.Stack[LogicalPlan],

From 3041864da350b862bd9dc04565572e283d295a5b Mon Sep 17 00:00:00 2001
From: gatorsmile
Date: Sat, 16 Jan 2016 01:08:52 -0800
Subject: [PATCH 24/27] address comments.

---
 .../sql/catalyst/analysis/HiveTypeCoercion.scala |  2 +-
 .../org/apache/spark/sql/hive/SQLBuilder.scala   | 14 ++++++--------
 2 files changed, 7 insertions(+), 9 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 5f725b08a700c..8ff53ff86b1b6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -242,7 +242,7 @@ object HiveTypeCoercion {
       // Otherwise, record the result in the queue and find the type for the next column
       case Some(widenType) =>
         castedTypes.enqueue(widenType)
-        getCastedTypes (children, attrIndex + 1, castedTypes)
+        getCastedTypes(children, attrIndex + 1, castedTypes)
     }
   }
 
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala
index ce009abeb7306..1654594538366 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala
@@ -130,14 +130,12 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi
     } yield s"$childSQL $whereOrHaving $conditionSQL"
 
     case Union(children) if children.length > 1 =>
-      for {
-        leftSQL <- toSQL(children.head)
-        // When children.tail only has one child, we will go to the next case to get rid of Union.
-        rightSQL <- toSQL(Union(children.tail))
-      } yield s"$leftSQL UNION ALL $rightSQL"
-
-    case Union(children) if children.length == 1 =>
-      toSQL(children.head)
+      val childrenSql = children.map(toSQL(_))
+      if (childrenSql.exists(_.isEmpty)) {
+        None
+      } else {
+        Some(childrenSql.map(_.get).mkString(" UNION ALL "))
+      }
 
     // Persisted data source relation
     case Subquery(alias, LogicalRelation(_, _, Some(TableIdentifier(table, Some(database))))) =>
       Some(s"`$database`.`$table`")

From 6259fd90a24428a7805569dbe1fc6bea024526c3 Mon Sep 17 00:00:00 2001
From: gatorsmile
Date: Mon, 18 Jan 2016 14:12:58 -0800
Subject: [PATCH 25/27] reimplement it based on the latest change.

---
 .../catalyst/analysis/HiveTypeCoercion.scala  | 69 ++++++++++---------
 1 file changed, 36 insertions(+), 33 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 8ff53ff86b1b6..2ce15696d924c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -203,32 +203,41 @@ object HiveTypeCoercion {
    */
   object WidenSetOperationTypes extends Rule[LogicalPlan] {
 
-    private def widenOutputTypes(children: Seq[LogicalPlan]): Seq[LogicalPlan] = {
-      require(children.forall(_.output.length == children.head.output.length))
+    def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+      case p if p.analyzed => p
 
-      // Get a sequence of data types, each of which is the widest type of this specific attribute
-      // in all the children
-      val castedTypes: Seq[DataType] =
-        getCastedTypes(children, attrIndex = 0, mutable.Queue[DataType]())
+      case s @ SetOperation(left, right) if s.childrenResolved &&
+          left.output.length == right.output.length && !s.resolved =>
+        val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(left :: right :: Nil)
+        assert(newChildren.length == 2)
+        s.makeCopy(Array(newChildren.head, newChildren.last))
 
-      // Add extra Project for type promotion if necessary
-      if (castedTypes.isEmpty) children else children.map(castOutput(_, castedTypes))
+      case s: Union if s.childrenResolved &&
+          s.children.forall(_.output.length == s.children.head.output.length) && !s.resolved =>
+        val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(s.children)
+        s.makeCopy(Array(newChildren))
     }
 
-    // Add Project if the data types do not match
-    private def castOutput(
-        plan: LogicalPlan,
-        castedTypes: Seq[DataType]): LogicalPlan = {
-      val casted = plan.output.zip(castedTypes).map {
-        case (e, dt) if e.dataType != dt =>
-          Alias(Cast(e, dt), e.name)()
-        case (e, _) => e
+    /** Build new children with the widest types for each attribute among all the children */
+    private def buildNewChildrenWithWiderTypes(children: Seq[LogicalPlan]): Seq[LogicalPlan] = {
+      require(children.forall(_.output.length == children.head.output.length))
+
+      // Get a sequence of data types, each of which is the widest type of this specific attribute
+      // in all the children
+      val targetTypes: Seq[DataType] =
+        getWidestTypes(children, attrIndex = 0, mutable.Queue[DataType]())
+
+      if (targetTypes.nonEmpty) {
+        // Add an extra Project if the targetTypes are different from the original types.
+        children.map(widenTypes(_, targetTypes))
+      } else {
+        // Unable to find a target type to widen, then just return the original set.
+        children
       }
     }
 
-    // Get the widest type for each attribute in all the children
-    @tailrec private def getCastedTypes(
+    /** Get the widest type for each attribute in all the children */
+    @tailrec private def getWidestTypes(
         children: Seq[LogicalPlan],
         attrIndex: Int,
         castedTypes: mutable.Queue[DataType]): Seq[DataType] = {
@@ -242,23 +251,17 @@ object HiveTypeCoercion {
       // Otherwise, record the result in the queue and find the type for the next column
       case Some(widenType) =>
         castedTypes.enqueue(widenType)
-        getCastedTypes(children, attrIndex + 1, castedTypes)
+        getWidestTypes(children, attrIndex + 1, castedTypes)
       }
     }
 
-    def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
-      case p if p.analyzed => p
-
-      case s @ SetOperation(left, right) if s.childrenResolved &&
-          left.output.length == right.output.length && !s.resolved =>
-        val newChildren: Seq[LogicalPlan] = widenOutputTypes(left :: right :: Nil)
-        assert(newChildren.length == 2)
-        s.makeCopy(Array(newChildren.head, newChildren.last))
-
-      case s: Union if s.childrenResolved &&
-          s.children.forall(_.output.length == s.children.head.output.length) && !s.resolved =>
-        val newChildren: Seq[LogicalPlan] = widenOutputTypes(s.children)
-        s.makeCopy(Array(newChildren))
+    /** Given a plan, add an extra project on top to widen some columns' data types. */
+    private def widenTypes(plan: LogicalPlan, targetTypes: Seq[DataType]): LogicalPlan = {
+      val casted = plan.output.zip(targetTypes).map {
+        case (e, dt) if e.dataType != dt => Alias(Cast(e, dt), e.name)()
+        case (e, _) => e
+      }
+      if (casted.exists(_.isInstanceOf[Alias])) Project(casted, plan) else plan
     }
   }

From 4f717410527094cfe028492e82bcad2fe749ee8d Mon Sep 17 00:00:00 2001
From: gatorsmile
Date: Mon, 18 Jan 2016 16:19:37 -0800
Subject: [PATCH 26/27] address comments.

---
 .../spark/sql/catalyst/analysis/HiveTypeCoercion.scala | 9 +--------
 1 file changed, 1 insertion(+), 8 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 7fc0cf3e6a09d..437c509151485 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -30,8 +30,7 @@ import org.apache.spark.sql.types._
 
 /**
-<<<<<<< HEAD
- * A collection of [[Rule Rules]] that can be used to coerce differing types that participate in
+ * A collection of [[Rule]] that can be used to coerce differing types that participate in
  * operations into compatible ones.
  *
 * Most of these rules are based on Hive semantics, but they do not introduce any dependencies on
@@ -45,12 +44,6 @@ import org.apache.spark.sql.types._
 * some acceptable loss of precision (e.g. there is no common type for double and decimal because
 * double's range is larger than decimal, and yet decimal is more precise than double, but in
 * union we would cast the decimal into double).
-=======
- * A collection of [[Rule]] that can be used to coerce differing types that
- * participate in operations into compatible ones. Most of these rules are based on Hive semantics,
- * but they do not introduce any dependencies on the hive codebase. For this reason they remain in
- * Catalyst until we have a more standard set of coercions.
->>>>>>> unionAllMCMergedNewNew
  */
 object HiveTypeCoercion {

From c63f237b6de61c700cec3544820674c1dcb94334 Mon Sep 17 00:00:00 2001
From: gatorsmile
Date: Tue, 19 Jan 2016 13:41:24 -0800
Subject: [PATCH 27/27] address comments.

---
 .../apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala   | 2 +-
 .../spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala     | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 437c509151485..c557c3231997a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -277,7 +277,7 @@ object HiveTypeCoercion {
         case (e, dt) if e.dataType != dt => Alias(Cast(e, dt), e.name)()
         case (e, _) => e
       }
-      if (casted.exists(_.isInstanceOf[Alias])) Project(casted, plan) else plan
+      Project(casted, plan)
     }
   }
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
index 08da69297b159..c30434a0063b0 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
@@ -465,7 +465,7 @@ class HiveTypeCoercionSuite extends PlanTest {
 
     assert(unionRelation.children.head.isInstanceOf[Project])
     assert(unionRelation.children(1).isInstanceOf[Project])
-    assert(unionRelation.children(2).isInstanceOf[LocalRelation])
+    assert(unionRelation.children(2).isInstanceOf[Project])
    assert(unionRelation.children(3).isInstanceOf[Project])
   }
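
Note on the flattening pattern used by the Unions extractor in the patches
above: rather than recursing down the plan tree (which could overflow the JVM
stack on deeply nested unions), it drives an explicit stack and accumulates
non-Union nodes in left-to-right order. The following is a minimal,
self-contained Scala sketch of that idea; Plan, Leaf, UnionNode, and
FlattenDemo are hypothetical stand-ins written for illustration only, not the
Spark classes touched by these patches.

import scala.annotation.tailrec
import scala.collection.mutable

sealed trait Plan
case class Leaf(name: String) extends Plan
case class UnionNode(children: Seq[Plan]) extends Plan

object FlattenDemo {
  // Depth-first traversal with an explicit stack, mirroring the tail-recursive
  // collectUnionChildren in patterns.scala: pop a plan; if it is a union, push
  // its children in reverse so the leftmost one is popped first; otherwise
  // append it to the accumulated result.
  @tailrec
  private def collect(plans: mutable.Stack[Plan], acc: Seq[Plan]): Seq[Plan] =
    if (plans.isEmpty) acc
    else plans.pop() match {
      case UnionNode(grandchildren) =>
        grandchildren.reverseIterator.foreach(plans.push)
        collect(plans, acc)
      case other => collect(plans, acc :+ other)
    }

  def flatten(root: Plan): Seq[Plan] = collect(mutable.Stack(root), Seq.empty)

  def main(args: Array[String]): Unit = {
    // UnionNode(UnionNode(a, b), UnionNode(c, d)) flattens to Seq(a, b, c, d),
    // which is what allows nested binary unions to collapse into a single
    // n-ary Union operator before planning.
    val nested = UnionNode(Seq(
      UnionNode(Seq(Leaf("a"), Leaf("b"))),
      UnionNode(Seq(Leaf("c"), Leaf("d")))))
    assert(flatten(nested) == Seq(Leaf("a"), Leaf("b"), Leaf("c"), Leaf("d")))
  }
}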