diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala index 5fb41f7e4bf8c..35273c7e24ae4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala @@ -402,8 +402,8 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C overwrite) } - // If there are multiple INSERTS just UNION them together into on query. - val query = queries.reduceLeft(Union) + // If there are multiple INSERTS just UNION them together into one query. + val query = if (queries.length == 1) queries.head else Union(queries) // return With plan if there is CTE cteRelations.map(With(query, _)).getOrElse(query) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index d4b4bc88b3f2f..33d76eeb21287 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -66,7 +66,8 @@ class Analyzer( lazy val batches: Seq[Batch] = Seq( Batch("Substitution", fixedPoint, CTESubstitution, - WindowsSubstitution), + WindowsSubstitution, + EliminateUnions), Batch("Resolution", fixedPoint, ResolveRelations :: ResolveReferences :: @@ -1170,6 +1171,15 @@ object EliminateSubQueries extends Rule[LogicalPlan] { } } +/** + * Removes [[Union]] operators from the plan if it just has one child. + */ +object EliminateUnions extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case Union(children) if children.size == 1 => children.head + } +} + /** * Cleans up unnecessary Aliases inside the plan. 
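The new EliminateUnions rule above only fires on a Union with exactly one child, a degenerate shape that can appear now that Union is n-ary. A minimal sketch of the same idea on a toy plan type; the names Plan, Relation and transformUp are illustrative stand-ins, not Catalyst's real classes:

    object EliminateUnionsSketch extends App {
      // Toy stand-ins for Catalyst's LogicalPlan hierarchy (illustrative only).
      sealed trait Plan
      case class Relation(name: String) extends Plan
      case class Union(children: Seq[Plan]) extends Plan

      // Bottom-up rewrite, loosely mirroring TreeNode.transform.
      def transformUp(plan: Plan)(rule: PartialFunction[Plan, Plan]): Plan = {
        val rewritten = plan match {
          case Union(cs) => Union(cs.map(transformUp(_)(rule)))
          case leaf      => leaf
        }
        rule.applyOrElse(rewritten, identity[Plan])
      }

      // The EliminateUnions idea: a Union with a single child is a no-op.
      def eliminateUnions(plan: Plan): Plan = transformUp(plan) {
        case Union(Seq(only)) => only
      }

      // Mirrors the new AnalysisSuite test: Union(testRelation :: Nil) analyzes to testRelation.
      assert(eliminateUnions(Union(Seq(Relation("testRelation")))) == Relation("testRelation"))
    }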
Basically we only need Alias as a top level * expression in Project(project list) or Aggregate(aggregate expressions) or diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 2a2e0d27d9435..f2e78d97442e3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -189,6 +189,14 @@ trait CheckAnalysis { s"but the left table has ${left.output.length} columns and the right has " + s"${right.output.length}") + case s: Union if s.children.exists(_.output.length != s.children.head.output.length) => + val firstError = s.children.find(_.output.length != s.children.head.output.length).get + failAnalysis( + s""" + |Unions can only be performed on tables with the same number of columns, + | but one table has '${firstError.output.length}' columns and another table has + | '${s.children.head.output.length}' columns""".stripMargin) + case _ => // Fallbacks to the following checks } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 7df3787e6d2d3..c557c3231997a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -19,6 +19,9 @@ package org.apache.spark.sql.catalyst.analysis import javax.annotation.Nullable +import scala.annotation.tailrec +import scala.collection.mutable + import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical._ @@ -27,7 +30,7 @@ import org.apache.spark.sql.types._ /** - * A collection of [[Rule Rules]] that can be used to coerce differing types that participate in + * A collection of [[Rule]] that can be used to coerce differing types that participate in * operations into compatible ones. * * Most of these rules are based on Hive semantics, but they do not introduce any dependencies on @@ -219,31 +222,59 @@ object HiveTypeCoercion { def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case p if p.analyzed => p - case s @ SetOperation(left, right) if s.childrenResolved - && left.output.length == right.output.length && !s.resolved => + case s @ SetOperation(left, right) if s.childrenResolved && + left.output.length == right.output.length && !s.resolved => + val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(left :: right :: Nil) + assert(newChildren.length == 2) + s.makeCopy(Array(newChildren.head, newChildren.last)) - // Tracks the list of data types to widen. - // Some(dataType) means the right-hand side and the left-hand side have different types, - // and there is a target type to widen both sides to. 
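The CheckAnalysis case above generalizes the old binary arity check: every Union child must expose as many columns as the first child, and the first offender is reported. A hedged standalone sketch of that validation over plain column-name lists (not Catalyst attributes):

    object UnionArityCheckSketch extends App {
      // Each child plan is represented only by its output column names.
      def checkSameArity(childrenOutputs: Seq[Seq[String]]): Unit = {
        require(childrenOutputs.nonEmpty, "Union must have at least one child")
        val expected = childrenOutputs.head.length
        childrenOutputs.find(_.length != expected).foreach { offender =>
          // Roughly mirrors the failAnalysis message in the patch.
          sys.error(
            "Unions can only be performed on tables with the same number of columns, " +
            s"but one table has ${offender.length} columns and another table has $expected columns")
        }
      }

      checkSameArity(Seq(Seq("a", "b"), Seq("d", "e")))   // passes
      // checkSameArity(Seq(Seq("a", "b"), Seq("d")))     // would raise the analysis error
    }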
- val targetTypes: Seq[Option[DataType]] = left.output.zip(right.output).map { - case (lhs, rhs) if lhs.dataType != rhs.dataType => - findWiderTypeForTwo(lhs.dataType, rhs.dataType) - case other => None - } + case s: Union if s.childrenResolved && + s.children.forall(_.output.length == s.children.head.output.length) && !s.resolved => + val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(s.children) + s.makeCopy(Array(newChildren)) + } - if (targetTypes.exists(_.isDefined)) { - // There is at least one column to widen. - s.makeCopy(Array(widenTypes(left, targetTypes), widenTypes(right, targetTypes))) - } else { - // If we cannot find any column to widen, then just return the original set. - s - } + /** Build new children with the widest types for each attribute among all the children */ + private def buildNewChildrenWithWiderTypes(children: Seq[LogicalPlan]): Seq[LogicalPlan] = { + require(children.forall(_.output.length == children.head.output.length)) + + // Get a sequence of data types, each of which is the widest type of this specific attribute + // in all the children + val targetTypes: Seq[DataType] = + getWidestTypes(children, attrIndex = 0, mutable.Queue[DataType]()) + + if (targetTypes.nonEmpty) { + // Add an extra Project if the targetTypes are different from the original types. + children.map(widenTypes(_, targetTypes)) + } else { + // Unable to find a target type to widen, then just return the original set. + children + } + } + + /** Get the widest type for each attribute in all the children */ + @tailrec private def getWidestTypes( + children: Seq[LogicalPlan], + attrIndex: Int, + castedTypes: mutable.Queue[DataType]): Seq[DataType] = { + // Return the result after the widen data types have been found for all the children + if (attrIndex >= children.head.output.length) return castedTypes.toSeq + + // For the attrIndex-th attribute, find the widest type + findWiderCommonType(children.map(_.output(attrIndex).dataType)) match { + // If unable to find an appropriate widen type for this column, return an empty Seq + case None => Seq.empty[DataType] + // Otherwise, record the result in the queue and find the type for the next column + case Some(widenType) => + castedTypes.enqueue(widenType) + getWidestTypes(children, attrIndex + 1, castedTypes) + } } /** Given a plan, add an extra project on top to widen some columns' data types. 
*/ - private def widenTypes(plan: LogicalPlan, targetTypes: Seq[Option[DataType]]): LogicalPlan = { + private def widenTypes(plan: LogicalPlan, targetTypes: Seq[DataType]): LogicalPlan = { val casted = plan.output.zip(targetTypes).map { - case (e, Some(dt)) if e.dataType != dt => Alias(Cast(e, dt), e.name)() + case (e, dt) if e.dataType != dt => Alias(Cast(e, dt), e.name)() case (e, _) => e } Project(casted, plan) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 04643f0274bd4..44455b482074b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.analysis.{CleanupAliases, EliminateSubQueri import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.planning.ExtractFiltersAndInnerJoins +import org.apache.spark.sql.catalyst.planning.{ExtractFiltersAndInnerJoins, Unions} import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, LeftOuter, LeftSemi, RightOuter} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ @@ -45,6 +45,13 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] { ////////////////////////////////////////////////////////////////////////////////////////// // Optimizer rules start here ////////////////////////////////////////////////////////////////////////////////////////// + // - Do the first call of CombineUnions before starting the major Optimizer rules, + // since it can reduce the number of iteration and the other rules could add/move + // extra operators between two adjacent Union operators. + // - Call CombineUnions again in Batch("Operator Optimizations"), + // since the other rules might make two separate Unions operators adjacent. + Batch("Union", Once, + CombineUnions) :: Batch("Aggregate", FixedPoint(100), ReplaceDistinctWithAggregate, RemoveLiteralFromGroupExpressions) :: @@ -62,6 +69,7 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] { ProjectCollapsing, CombineFilters, CombineLimits, + CombineUnions, // Constant folding and strength reduction NullPropagation, OptimizeIn, @@ -138,11 +146,9 @@ object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper { /** * Maps Attributes from the left side to the corresponding Attribute on the right side. 
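buildNewChildrenWithWiderTypes walks the columns left to right, asks findWiderCommonType for a type that all children can be cast to, and produces no target types at all if any column lacks one (in which case the children are left unchanged); widenTypes then adds a Project that casts only the columns whose types actually differ. The sketch below reproduces that per-column logic over a simplified numeric widening chain; findWider is a stand-in for Catalyst's findWiderCommonType, not its real coercion rules:

    object WidestTypeSketch extends App {
      // Simplified stand-in for Catalyst's numeric widening order.
      val wideningOrder = List("ByteType", "IntegerType", "LongType", "DoubleType")

      // Stand-in for findWiderCommonType: widest of the given types, or None if
      // some type is outside the known chain.
      def findWider(types: Seq[String]): Option[String] =
        if (types.forall(wideningOrder.contains)) Some(types.maxBy(wideningOrder.indexOf))
        else None

      // Mirrors getWidestTypes: one target type per column, or Seq.empty if any
      // column has no common wider type.
      def widestTypes(childrenTypes: Seq[Seq[String]]): Seq[String] = {
        val numColumns = childrenTypes.head.length
        val perColumn = (0 until numColumns).map(i => findWider(childrenTypes.map(_(i))))
        if (perColumn.forall(_.isDefined)) perColumn.map(_.get) else Seq.empty
      }

      // Three Union children widen column-wise to (LongType, DoubleType).
      println(widestTypes(Seq(
        Seq("IntegerType", "DoubleType"),
        Seq("LongType",    "ByteType"),
        Seq("ByteType",    "IntegerType"))))
    }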
*/ - private def buildRewrites(bn: BinaryNode): AttributeMap[Attribute] = { - assert(bn.isInstanceOf[Union] || bn.isInstanceOf[Intersect] || bn.isInstanceOf[Except]) - assert(bn.left.output.size == bn.right.output.size) - - AttributeMap(bn.left.output.zip(bn.right.output)) + private def buildRewrites(left: LogicalPlan, right: LogicalPlan): AttributeMap[Attribute] = { + assert(left.output.size == right.output.size) + AttributeMap(left.output.zip(right.output)) } /** @@ -176,32 +182,38 @@ object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper { } def apply(plan: LogicalPlan): LogicalPlan = plan transform { - // Push down filter into union - case Filter(condition, u @ Union(left, right)) => - val (deterministic, nondeterministic) = partitionByDeterministic(condition) - val rewrites = buildRewrites(u) - Filter(nondeterministic, - Union( - Filter(deterministic, left), - Filter(pushToRight(deterministic, rewrites), right) - ) - ) // Push down deterministic projection through UNION ALL - case p @ Project(projectList, u @ Union(left, right)) => + case p @ Project(projectList, Union(children)) => + assert(children.nonEmpty) if (projectList.forall(_.deterministic)) { - val rewrites = buildRewrites(u) - Union( - Project(projectList, left), - Project(projectList.map(pushToRight(_, rewrites)), right)) + val newFirstChild = Project(projectList, children.head) + val newOtherChildren = children.tail.map ( child => { + val rewrites = buildRewrites(children.head, child) + Project(projectList.map(pushToRight(_, rewrites)), child) + } ) + Union(newFirstChild +: newOtherChildren) } else { p } + // Push down filter into union + case Filter(condition, Union(children)) => + assert(children.nonEmpty) + val (deterministic, nondeterministic) = partitionByDeterministic(condition) + val newFirstChild = Filter(deterministic, children.head) + val newOtherChildren = children.tail.map { + child => { + val rewrites = buildRewrites(children.head, child) + Filter(pushToRight(deterministic, rewrites), child) + } + } + Filter(nondeterministic, Union(newFirstChild +: newOtherChildren)) + // Push down filter through INTERSECT - case Filter(condition, i @ Intersect(left, right)) => + case Filter(condition, Intersect(left, right)) => val (deterministic, nondeterministic) = partitionByDeterministic(condition) - val rewrites = buildRewrites(i) + val rewrites = buildRewrites(left, right) Filter(nondeterministic, Intersect( Filter(deterministic, left), @@ -210,9 +222,9 @@ object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper { ) // Push down filter through EXCEPT - case Filter(condition, e @ Except(left, right)) => + case Filter(condition, Except(left, right)) => val (deterministic, nondeterministic) = partitionByDeterministic(condition) - val rewrites = buildRewrites(e) + val rewrites = buildRewrites(left, right) Filter(nondeterministic, Except( Filter(deterministic, left), @@ -662,6 +674,15 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper { } } +/** + * Combines all adjacent [[Union]] operators into a single [[Union]]. + */ +object CombineUnions extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case Unions(children) => Union(children) + } +} + /** * Combines two adjacent [[Filter]] operators into one, merging the * conditions into one conjunctive predicate. 
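With Union now n-ary, SetOperationPushDown rewrites a Filter or Project above a Union into one per child, using buildRewrites to translate the first child's attributes into each remaining child's attributes by position. A toy sketch of that positional rewrite for the filter case; Child and Predicate are illustrative names, not Catalyst classes:

    object UnionFilterPushdownSketch extends App {
      // A child plan is represented by a name and its output column names (positional).
      case class Child(name: String, output: Seq[String])

      // Mirrors buildRewrites: map the first child's columns to another child's by position.
      def rewrites(first: Child, other: Child): Map[String, String] =
        first.output.zip(other.output).toMap

      // A deterministic predicate kept symbolic: column = literal.
      case class Predicate(column: String, value: Int) {
        def rewrite(mapping: Map[String, String]): Predicate =
          copy(column = mapping.getOrElse(column, column))
      }

      val children = Seq(
        Child("testRelation",  Seq("a", "b", "c")),
        Child("testRelation2", Seq("d", "e", "f")),
        Child("testRelation3", Seq("g", "h", "i")))

      // Pushing the filter a = 1 below the Union yields a = 1, d = 1 and g = 1 on the
      // respective children, matching the expectation in SetOperationSuite.
      val pushed = children.map(c => c.name -> Predicate("a", 1).rewrite(rewrites(children.head, c)))
      println(pushed)
    }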
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index cd3f15cbe107b..f0ee124e88a9f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -17,6 +17,9 @@ package org.apache.spark.sql.catalyst.planning +import scala.annotation.tailrec +import scala.collection.mutable + import org.apache.spark.Logging import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ @@ -170,17 +173,29 @@ object ExtractFiltersAndInnerJoins extends PredicateHelper { } } + /** * A pattern that collects all adjacent unions and returns their children as a Seq. */ object Unions { def unapply(plan: LogicalPlan): Option[Seq[LogicalPlan]] = plan match { - case u: Union => Some(collectUnionChildren(u)) + case u: Union => Some(collectUnionChildren(mutable.Stack(u), Seq.empty[LogicalPlan])) case _ => None } - private def collectUnionChildren(plan: LogicalPlan): Seq[LogicalPlan] = plan match { - case Union(l, r) => collectUnionChildren(l) ++ collectUnionChildren(r) - case other => other :: Nil + // Doing a depth-first tree traversal to combine all the union children. + @tailrec + private def collectUnionChildren( + plans: mutable.Stack[LogicalPlan], + children: Seq[LogicalPlan]): Seq[LogicalPlan] = { + if (plans.isEmpty) children + else { + plans.pop match { + case Union(grandchildren) => + grandchildren.reverseMap(plans.push(_)) + collectUnionChildren(plans, children) + case other => collectUnionChildren(plans, children :+ other) + } + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index f4a3d85d2a8a4..e9c970cd08088 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -101,19 +101,6 @@ private[sql] object SetOperation { def unapply(p: SetOperation): Option[(LogicalPlan, LogicalPlan)] = Some((p.left, p.right)) } -case class Union(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) { - - override def output: Seq[Attribute] = - left.output.zip(right.output).map { case (leftAttr, rightAttr) => - leftAttr.withNullability(leftAttr.nullable || rightAttr.nullable) - } - - override def statistics: Statistics = { - val sizeInBytes = left.statistics.sizeInBytes + right.statistics.sizeInBytes - Statistics(sizeInBytes = sizeInBytes) - } -} - case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) { override def output: Seq[Attribute] = @@ -127,6 +114,40 @@ case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(le override def output: Seq[Attribute] = left.output } +/** Factory for constructing new `Union` nodes. 
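The Unions extractor above was rewritten from a non-tail-recursive tree walk into a tail-recursive loop over an explicit mutable.Stack, so that very deep chains of unions (for example thousands of unionAll calls) cannot overflow the JVM stack; grandchildren are pushed in reverse so they come out left to right. A standalone sketch of the same traversal on a toy plan type:

    import scala.annotation.tailrec
    import scala.collection.mutable

    object FlattenUnionsSketch extends App {
      sealed trait Plan
      case class Relation(name: String) extends Plan
      case class Union(children: Seq[Plan]) extends Plan

      // Depth-first, left-to-right collection of all non-Union descendants,
      // mirroring the rewritten collectUnionChildren.
      @tailrec
      def collect(stack: mutable.Stack[Plan], acc: Seq[Plan]): Seq[Plan] =
        if (stack.isEmpty) acc
        else stack.pop() match {
          case Union(grandchildren) =>
            grandchildren.reverse.foreach(stack.push)   // reverse keeps left-to-right order
            collect(stack, acc)
          case leaf => collect(stack, acc :+ leaf)
        }

      val nested = Union(Seq(Union(Seq(Relation("a"), Relation("b"))), Relation("c")))
      println(collect(mutable.Stack(nested), Seq.empty))   // List(Relation(a), Relation(b), Relation(c))
    }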
*/ +object Union { + def apply(left: LogicalPlan, right: LogicalPlan): Union = { + Union (left :: right :: Nil) + } +} + +case class Union(children: Seq[LogicalPlan]) extends LogicalPlan { + + // updating nullability to make all the children consistent + override def output: Seq[Attribute] = + children.map(_.output).transpose.map(attrs => + attrs.head.withNullability(attrs.exists(_.nullable))) + + override lazy val resolved: Boolean = { + // allChildrenCompatible needs to be evaluated after childrenResolved + def allChildrenCompatible: Boolean = + children.tail.forall( child => + // compare the attribute number with the first child + child.output.length == children.head.output.length && + // compare the data types with the first child + child.output.zip(children.head.output).forall { + case (l, r) => l.dataType == r.dataType } + ) + + children.length > 1 && childrenResolved && allChildrenCompatible + } + + override def statistics: Statistics = { + val sizeInBytes = children.map(_.statistics.sizeInBytes).sum + Statistics(sizeInBytes = sizeInBytes) + } +} + case class Join( left: LogicalPlan, right: LogicalPlan, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 975cd87d090e4..ab680282208c8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -237,6 +237,12 @@ class AnalysisSuite extends AnalysisTest { checkAnalysis(plan, expected) } + test("Eliminate the unnecessary union") { + val plan = Union(testRelation :: Nil) + val expected = testRelation + checkAnalysis(plan, expected) + } + test("SPARK-12102: Ignore nullablity when comparing two sides of case") { val relation = LocalRelation('a.struct('x.int), 'b.struct('x.int.withNullability(false))) val plan = relation.select(CaseWhen(Seq((Literal(true), 'a.attr)), 'b).as("val")) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala index 39c8f56c1bca6..24c608eaa5b39 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala @@ -70,7 +70,7 @@ class DecimalPrecisionSuite extends SparkFunSuite with BeforeAndAfter { Union(Project(Seq(Alias(left, "l")()), relation), Project(Seq(Alias(right, "r")()), relation)) val (l, r) = analyzer.execute(plan).collect { - case Union(left, right) => (left.output.head, right.output.head) + case Union(Seq(child1, child2)) => (child1.output.head, child2.output.head) }.head assert(l.dataType === expectedType) assert(r.dataType === expectedType) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala index b326aa9c55992..c30434a0063b0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala @@ -387,19 +387,19 @@ class HiveTypeCoercionSuite extends PlanTest { ) } - test("WidenSetOperationTypes for union, except, and intersect") { - def 
checkOutput(logical: LogicalPlan, expectTypes: Seq[DataType]): Unit = { - logical.output.zip(expectTypes).foreach { case (attr, dt) => - assert(attr.dataType === dt) - } + private def checkOutput(logical: LogicalPlan, expectTypes: Seq[DataType]): Unit = { + logical.output.zip(expectTypes).foreach { case (attr, dt) => + assert(attr.dataType === dt) } + } - val left = LocalRelation( + test("WidenSetOperationTypes for except and intersect") { + val firstTable = LocalRelation( AttributeReference("i", IntegerType)(), AttributeReference("u", DecimalType.SYSTEM_DEFAULT)(), AttributeReference("b", ByteType)(), AttributeReference("d", DoubleType)()) - val right = LocalRelation( + val secondTable = LocalRelation( AttributeReference("s", StringType)(), AttributeReference("d", DecimalType(2, 1))(), AttributeReference("f", FloatType)(), @@ -408,15 +408,65 @@ class HiveTypeCoercionSuite extends PlanTest { val wt = HiveTypeCoercion.WidenSetOperationTypes val expectedTypes = Seq(StringType, DecimalType.SYSTEM_DEFAULT, FloatType, DoubleType) - val r1 = wt(Union(left, right)).asInstanceOf[Union] - val r2 = wt(Except(left, right)).asInstanceOf[Except] - val r3 = wt(Intersect(left, right)).asInstanceOf[Intersect] + val r1 = wt(Except(firstTable, secondTable)).asInstanceOf[Except] + val r2 = wt(Intersect(firstTable, secondTable)).asInstanceOf[Intersect] checkOutput(r1.left, expectedTypes) checkOutput(r1.right, expectedTypes) checkOutput(r2.left, expectedTypes) checkOutput(r2.right, expectedTypes) - checkOutput(r3.left, expectedTypes) - checkOutput(r3.right, expectedTypes) + + // Check if a Project is added + assert(r1.left.isInstanceOf[Project]) + assert(r1.right.isInstanceOf[Project]) + assert(r2.left.isInstanceOf[Project]) + assert(r2.right.isInstanceOf[Project]) + + val r3 = wt(Except(firstTable, firstTable)).asInstanceOf[Except] + checkOutput(r3.left, Seq(IntegerType, DecimalType.SYSTEM_DEFAULT, ByteType, DoubleType)) + checkOutput(r3.right, Seq(IntegerType, DecimalType.SYSTEM_DEFAULT, ByteType, DoubleType)) + + // Check if no Project is added + assert(r3.left.isInstanceOf[LocalRelation]) + assert(r3.right.isInstanceOf[LocalRelation]) + } + + test("WidenSetOperationTypes for union") { + val firstTable = LocalRelation( + AttributeReference("i", IntegerType)(), + AttributeReference("u", DecimalType.SYSTEM_DEFAULT)(), + AttributeReference("b", ByteType)(), + AttributeReference("d", DoubleType)()) + val secondTable = LocalRelation( + AttributeReference("s", StringType)(), + AttributeReference("d", DecimalType(2, 1))(), + AttributeReference("f", FloatType)(), + AttributeReference("l", LongType)()) + val thirdTable = LocalRelation( + AttributeReference("m", StringType)(), + AttributeReference("n", DecimalType.SYSTEM_DEFAULT)(), + AttributeReference("p", FloatType)(), + AttributeReference("q", DoubleType)()) + val forthTable = LocalRelation( + AttributeReference("m", StringType)(), + AttributeReference("n", DecimalType.SYSTEM_DEFAULT)(), + AttributeReference("p", ByteType)(), + AttributeReference("q", DoubleType)()) + + val wt = HiveTypeCoercion.WidenSetOperationTypes + val expectedTypes = Seq(StringType, DecimalType.SYSTEM_DEFAULT, FloatType, DoubleType) + + val unionRelation = wt( + Union(firstTable :: secondTable :: thirdTable :: forthTable :: Nil)).asInstanceOf[Union] + assert(unionRelation.children.length == 4) + checkOutput(unionRelation.children.head, expectedTypes) + checkOutput(unionRelation.children(1), expectedTypes) + checkOutput(unionRelation.children(2), expectedTypes) + 
checkOutput(unionRelation.children(3), expectedTypes) + + assert(unionRelation.children.head.isInstanceOf[Project]) + assert(unionRelation.children(1).isInstanceOf[Project]) + assert(unionRelation.children(2).isInstanceOf[Project]) + assert(unionRelation.children(3).isInstanceOf[Project]) } test("Transform Decimal precision/scale for union except and intersect") { @@ -438,8 +488,8 @@ class HiveTypeCoercionSuite extends PlanTest { val r2 = dp(Except(left1, right1)).asInstanceOf[Except] val r3 = dp(Intersect(left1, right1)).asInstanceOf[Intersect] - checkOutput(r1.left, expectedType1) - checkOutput(r1.right, expectedType1) + checkOutput(r1.children.head, expectedType1) + checkOutput(r1.children.last, expectedType1) checkOutput(r2.left, expectedType1) checkOutput(r2.right, expectedType1) checkOutput(r3.left, expectedType1) @@ -459,7 +509,7 @@ class HiveTypeCoercionSuite extends PlanTest { val r2 = dp(Except(plan1, plan2)).asInstanceOf[Except] val r3 = dp(Intersect(plan1, plan2)).asInstanceOf[Intersect] - checkOutput(r1.right, Seq(expectedType)) + checkOutput(r1.children.last, Seq(expectedType)) checkOutput(r2.right, Seq(expectedType)) checkOutput(r3.right, Seq(expectedType)) @@ -467,7 +517,7 @@ class HiveTypeCoercionSuite extends PlanTest { val r5 = dp(Except(plan2, plan1)).asInstanceOf[Except] val r6 = dp(Intersect(plan2, plan1)).asInstanceOf[Intersect] - checkOutput(r4.left, Seq(expectedType)) + checkOutput(r4.children.last, Seq(expectedType)) checkOutput(r5.left, Seq(expectedType)) checkOutput(r6.left, Seq(expectedType)) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala similarity index 70% rename from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala rename to sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala index a498b463a69e9..2283f7c008ba2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala @@ -24,48 +24,73 @@ import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ -class SetOperationPushDownSuite extends PlanTest { +class SetOperationSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val batches = Batch("Subqueries", Once, EliminateSubQueries) :: Batch("Union Pushdown", Once, + CombineUnions, SetOperationPushDown, SimplifyFilters) :: Nil } val testRelation = LocalRelation('a.int, 'b.int, 'c.int) val testRelation2 = LocalRelation('d.int, 'e.int, 'f.int) - val testUnion = Union(testRelation, testRelation2) + val testRelation3 = LocalRelation('g.int, 'h.int, 'i.int) + val testUnion = Union(testRelation :: testRelation2 :: testRelation3 :: Nil) val testIntersect = Intersect(testRelation, testRelation2) val testExcept = Except(testRelation, testRelation2) - test("union/intersect/except: filter to each side") { - val unionQuery = testUnion.where('a === 1) + test("union: combine unions into one unions") { + val unionQuery1 = Union(Union(testRelation, testRelation2), testRelation) + val unionQuery2 = Union(testRelation, Union(testRelation2, testRelation)) + val unionOptimized1 = Optimize.execute(unionQuery1.analyze) + val unionOptimized2 = 
Optimize.execute(unionQuery2.analyze) + + comparePlans(unionOptimized1, unionOptimized2) + + val combinedUnions = Union(unionOptimized1 :: unionOptimized2 :: Nil) + val combinedUnionsOptimized = Optimize.execute(combinedUnions.analyze) + val unionQuery3 = Union(unionQuery1, unionQuery2) + val unionOptimized3 = Optimize.execute(unionQuery3.analyze) + comparePlans(combinedUnionsOptimized, unionOptimized3) + } + + test("intersect/except: filter to each side") { val intersectQuery = testIntersect.where('b < 10) val exceptQuery = testExcept.where('c >= 5) - val unionOptimized = Optimize.execute(unionQuery.analyze) val intersectOptimized = Optimize.execute(intersectQuery.analyze) val exceptOptimized = Optimize.execute(exceptQuery.analyze) - val unionCorrectAnswer = - Union(testRelation.where('a === 1), testRelation2.where('d === 1)).analyze val intersectCorrectAnswer = Intersect(testRelation.where('b < 10), testRelation2.where('e < 10)).analyze val exceptCorrectAnswer = Except(testRelation.where('c >= 5), testRelation2.where('f >= 5)).analyze - comparePlans(unionOptimized, unionCorrectAnswer) comparePlans(intersectOptimized, intersectCorrectAnswer) comparePlans(exceptOptimized, exceptCorrectAnswer) } + test("union: filter to each side") { + val unionQuery = testUnion.where('a === 1) + val unionOptimized = Optimize.execute(unionQuery.analyze) + val unionCorrectAnswer = + Union(testRelation.where('a === 1) :: + testRelation2.where('d === 1) :: + testRelation3.where('g === 1) :: Nil).analyze + + comparePlans(unionOptimized, unionCorrectAnswer) + } + test("union: project to each side") { val unionQuery = testUnion.select('a) val unionOptimized = Optimize.execute(unionQuery.analyze) val unionCorrectAnswer = - Union(testRelation.select('a), testRelation2.select('d)).analyze + Union(testRelation.select('a) :: + testRelation2.select('d) :: + testRelation3.select('g) :: Nil).analyze comparePlans(unionOptimized, unionCorrectAnswer) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 95e5fbb11900e..518f9dcf94a70 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst._ import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.optimizer.CombineUnions import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, FileRelation, LogicalRDD, Queryable, QueryExecution, SQLExecution} @@ -1002,7 +1003,9 @@ class DataFrame private[sql]( * @since 1.3.0 */ def unionAll(other: DataFrame): DataFrame = withPlan { - Union(logicalPlan, other.logicalPlan) + // This breaks caching, but it's usually ok because it addresses a very specific use case: + // using union to union many files or partitions. 
+ CombineUnions(Union(logicalPlan, other.logicalPlan)) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 9a9f7d111cf4b..bd99c399571c8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -19,13 +19,14 @@ package org.apache.spark.sql import scala.collection.JavaConverters._ -import org.apache.spark.Logging import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.function._ +import org.apache.spark.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.analysis.UnresolvedAlias import org.apache.spark.sql.catalyst.encoders._ import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.optimizer.CombineUnions import org.apache.spark.sql.catalyst.plans.JoinType import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.{Queryable, QueryExecution} @@ -603,7 +604,11 @@ class Dataset[T] private[sql]( * duplicate items. As such, it is analogous to `UNION ALL` in SQL. * @since 1.6.0 */ - def union(other: Dataset[T]): Dataset[T] = withPlan[T](other)(Union) + def union(other: Dataset[T]): Dataset[T] = withPlan[T](other){ (left, right) => + // This breaks caching, but it's usually ok because it addresses a very specific use case: + // using union to union many files or partitions. + CombineUnions(Union(left, right)) + } /** * Returns a new [[Dataset]] where any elements present in `other` have been removed. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index c4ddb6d76b2c6..60fbb595e5758 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -336,7 +336,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { LocalTableScan(output, data) :: Nil case logical.Limit(IntegerLiteral(limit), child) => execution.Limit(limit, planLater(child)) :: Nil - case Unions(unionChildren) => + case logical.Union(unionChildren) => execution.Union(unionChildren.map(planLater)) :: Nil case logical.Except(left, right) => execution.Except(planLater(left), planLater(right)) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 9e2e0357c65f0..6deb72adad5ec 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -281,13 +281,10 @@ case class Range( * Union two plans, without a distinct. This is UNION ALL in SQL. 
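Calling CombineUnions eagerly inside unionAll and union means that repeatedly unioning many DataFrames keeps one flat Union node instead of a deeply nested binary tree, which is the "many files or partitions" use case called out in the comment. A usage sketch against the 1.6-era API shown in this patch; the local-mode setup and the exact child count are illustrative assumptions:

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.sql.SQLContext
    import org.apache.spark.sql.catalyst.plans.logical.Union

    object FlatUnionAllSketch extends App {
      val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("union-sketch"))
      val sqlContext = new SQLContext(sc)

      // 100 small DataFrames standing in for many files or partitions.
      val parts = (1 to 100).map(_ => sqlContext.range(0, 10))

      // Each unionAll call combines adjacent Unions while building the plan, so the
      // analyzed plan should keep a single Union with 100 children rather than 99
      // nested binary Unions (mirroring the new DataFrameSuite "union all" test).
      val all = parts.reduce(_ unionAll _)
      val unions = all.queryExecution.analyzed.collect { case u: Union => u }
      assert(unions.size == 1 && unions.head.children.size == 100)

      sc.stop()
    }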
*/ case class Union(children: Seq[SparkPlan]) extends SparkPlan { - override def output: Seq[Attribute] = { - children.tail.foldLeft(children.head.output) { case (currentOutput, child) => - currentOutput.zip(child.output).map { case (a1, a2) => - a1.withNullability(a1.nullable || a2.nullable) - } - } - } + override def output: Seq[Attribute] = + children.map(_.output).transpose.map(attrs => + attrs.head.withNullability(attrs.exists(_.nullable))) + protected override def doExecute(): RDD[InternalRow] = sparkContext.union(children.map(_.execute())) } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index 1a3df1b117b68..3c0f25a5dc535 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -298,9 +298,9 @@ public void testSetOperation() { Dataset intersected = ds.intersect(ds2); Assert.assertEquals(Arrays.asList("xyz"), intersected.collectAsList()); - Dataset unioned = ds.union(ds2); + Dataset unioned = ds.union(ds2).union(ds); Assert.assertEquals( - Arrays.asList("abc", "abc", "xyz", "xyz", "foo", "foo"), + Arrays.asList("abc", "abc", "xyz", "xyz", "foo", "foo", "abc", "abc", "xyz"), unioned.collectAsList()); Dataset subtracted = ds.subtract(ds2); diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index bd11a387a1d5d..09bbe57a43ceb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -25,7 +25,7 @@ import scala.util.Random import org.scalatest.Matchers._ import org.apache.spark.SparkException -import org.apache.spark.sql.catalyst.plans.logical.OneRowRelation +import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Union} import org.apache.spark.sql.execution.Exchange import org.apache.spark.sql.execution.aggregate.TungstenAggregate import org.apache.spark.sql.functions._ @@ -98,6 +98,20 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { testData.collect().toSeq) } + test("union all") { + val unionDF = testData.unionAll(testData).unionAll(testData) + .unionAll(testData).unionAll(testData) + + // Before optimizer, Union should be combined. 
+ assert(unionDF.queryExecution.analyzed.collect { + case j: Union if j.children.size == 5 => j }.size === 1) + + checkAnswer( + unionDF.agg(avg('key), max('key), min('key), sum('key)), + Row(50.5, 100, 1, 25250) :: Nil + ) + } + test("empty data frame") { assert(sqlContext.emptyDataFrame.columns.toSeq === Seq.empty[String]) assert(sqlContext.emptyDataFrame.count() === 0) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 49feeaf17d68f..8fca5e2167d04 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -51,18 +51,6 @@ class PlannerSuite extends SharedSQLContext { s"The plan of query $query does not have partial aggregations.") } - test("unions are collapsed") { - val planner = sqlContext.planner - import planner._ - val query = testData.unionAll(testData).unionAll(testData).logicalPlan - val planned = BasicOperators(query).head - val logicalUnions = query collect { case u: logical.Union => u } - val physicalUnions = planned collect { case u: execution.Union => u } - - assert(logicalUnions.size === 2) - assert(physicalUnions.size === 1) - } - test("count is partially aggregated") { val query = testData.groupBy('value).agg(count('key)).queryExecution.analyzed testPartialAggregationPlan(query) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala index e83b4bffff857..1654594538366 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala @@ -129,11 +129,13 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi conditionSQL = condition.sql } yield s"$childSQL $whereOrHaving $conditionSQL" - case Union(left, right) => - for { - leftSQL <- toSQL(left) - rightSQL <- toSQL(right) - } yield s"$leftSQL UNION ALL $rightSQL" + case Union(children) if children.length > 1 => + val childrenSql = children.map(toSQL(_)) + if (childrenSql.exists(_.isEmpty)) { + None + } else { + Some(childrenSql.map(_.get).mkString(" UNION ALL ")) + } // Persisted data source relation case Subquery(alias, LogicalRelation(_, _, Some(TableIdentifier(table, Some(database))))) => diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala index 0604d9f47c5da..261a4746f4287 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala @@ -105,6 +105,10 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils { checkHiveQl("SELECT id FROM t0 UNION ALL SELECT CAST(id AS INT) AS id FROM t0") } + test("three-child union") { + checkHiveQl("SELECT id FROM t0 UNION ALL SELECT id FROM t0 UNION ALL SELECT id FROM t0") + } + test("case") { checkHiveQl("SELECT CASE WHEN id % 2 > 0 THEN 0 WHEN id % 2 = 0 THEN 1 END FROM t0") }
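The SQLBuilder change generalizes the old binary for-comprehension: convert every Union child to SQL and only succeed if all of them do, joining the results with UNION ALL. A minimal standalone sketch of that Option-sequencing step, with toSQL as a stand-in for SQLBuilder's real conversion:

    object UnionSqlSketch extends App {
      // Stand-in for SQLBuilder.toSQL: pretend only t0 and t1 are convertible.
      def toSQL(table: String): Option[String] =
        if (Set("t0", "t1").contains(table)) Some(s"SELECT id FROM $table") else None

      // Mirrors the patched Union case: every child must convert, then join with UNION ALL.
      def unionToSQL(children: Seq[String]): Option[String] = {
        val childrenSql = children.map(toSQL)
        if (childrenSql.exists(_.isEmpty)) None
        else Some(childrenSql.map(_.get).mkString(" UNION ALL "))
      }

      // Some(SELECT id FROM t0 UNION ALL SELECT id FROM t1 UNION ALL SELECT id FROM t0)
      println(unionToSQL(Seq("t0", "t1", "t0")))
      // None: t2 cannot be converted, so neither can the union as a whole.
      println(unionToSQL(Seq("t0", "t2")))
    }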