[SPARK-12616] [SQL] Making Logical Operator Union Support Arbitrary Number of Children #10577

Closed · wants to merge 58 commits

Commits
01e4cdf  Merge remote-tracking branch 'upstream/master' (gatorsmile, Nov 13, 2015)
6835704  Merge remote-tracking branch 'upstream/master' (gatorsmile, Nov 14, 2015)
9180687  Merge remote-tracking branch 'upstream/master' (gatorsmile, Nov 14, 2015)
b38a21e  SPARK-11633 (gatorsmile, Nov 17, 2015)
d2b84af  Merge remote-tracking branch 'upstream/master' into joinMakeCopy (gatorsmile, Nov 17, 2015)
fda8025  Merge remote-tracking branch 'upstream/master' (gatorspark, Nov 17, 2015)
ac0dccd  Merge branch 'master' of https://github.com/gatorsmile/spark (gatorspark, Nov 17, 2015)
6e0018b  Merge remote-tracking branch 'upstream/master' (Nov 20, 2015)
0546772  converge (gatorsmile, Nov 20, 2015)
b37a64f  converge (gatorsmile, Nov 20, 2015)
73270c8  added a new logical operator UNIONS (gatorsmile, Jan 4, 2016)
d9811c7  Merge remote-tracking branch 'upstream/master' into unionAllMultiChil… (gatorsmile, Jan 4, 2016)
5d031a7  remove the old operator union (gatorsmile, Jan 4, 2016)
c1f66f7  remove the old operator union #2. (gatorsmile, Jan 5, 2016)
c1dcd02  rename. (gatorsmile, Jan 5, 2016)
51ad5b2  address the comments. (gatorsmile, Jan 6, 2016)
c2a872c  Merge remote-tracking branch 'upstream/master' (gatorsmile, Jan 6, 2016)
5681ca8  Merge branch 'unionAllMultiChildren' into unionAllMC (gatorsmile, Jan 6, 2016)
7a54c8f  Change the optimizer rule for pushing Filter and Project through new … (gatorsmile, Jan 6, 2016)
95e2349  refactored WidenSetOperationTypes and added test cases (gatorsmile, Jan 6, 2016)
6a6003e  addressed comments. (gatorsmile, Jan 6, 2016)
15ec058  replace list by arrayBuffer in combineUnions (gatorsmile, Jan 6, 2016)
5e06647  address comments. (gatorsmile, Jan 6, 2016)
ab6dbd7  Merge remote-tracking branch 'upstream/master' (gatorsmile, Jan 6, 2016)
b821af0  Merge branch 'unionAllMC' into unionAllMCMerged (gatorsmile, Jan 6, 2016)
2229932  move changes in HiveQl.scala to CatalystQl.scala (gatorsmile, Jan 6, 2016)
b3327b1  add lazy. (gatorsmile, Jan 6, 2016)
4276356  Merge remote-tracking branch 'upstream/master' (gatorsmile, Jan 6, 2016)
2dab708  Merge remote-tracking branch 'upstream/master' (gatorsmile, Jan 7, 2016)
723c0da  resolve comments. (gatorsmile, Jan 7, 2016)
42b81a8  resolve comments. (gatorsmile, Jan 8, 2016)
4e0387f  Merge remote-tracking branch 'upstream/master' into unionAllMCMerged (gatorsmile, Jan 8, 2016)
b03d813  Merge remote-tracking branch 'upstream/master' into unionAllMCMerged (gatorsmile, Jan 8, 2016)
ab732c1  Remove the unneeded parm. (gatorsmile, Jan 8, 2016)
0458770  Merge remote-tracking branch 'upstream/master' (gatorsmile, Jan 8, 2016)
1debdfa  Merge remote-tracking branch 'upstream/master' (gatorsmile, Jan 9, 2016)
7320e21  Merge branch 'unionAllMCMerged' into unionAllMCMergedNew (gatorsmile, Jan 9, 2016)
741371a  Merge remote-tracking branch 'upstream/master' into unionAllMCMergedNew (gatorsmile, Jan 9, 2016)
031a5d8  changed the implementation of Union in sql generation (gatorsmile, Jan 9, 2016)
f3d23dc  fixed the implementation of Union in sql generation (gatorsmile, Jan 9, 2016)
a56e595  Merge remote-tracking branch 'upstream/master' into unionAllMCMergedNew (gatorsmile, Jan 13, 2016)
abfcf93  address comments. (gatorsmile, Jan 14, 2016)
763706d  Merge remote-tracking branch 'upstream/master' (gatorsmile, Jan 14, 2016)
e8e19a1  Merge branch 'unionAllMCMergedNew' into unionAllMCMergedNewNew (gatorsmile, Jan 14, 2016)
3b13ddf  address comments. (gatorsmile, Jan 15, 2016)
b88bdeb  added a comment. (gatorsmile, Jan 15, 2016)
3041864  address comments. (gatorsmile, Jan 16, 2016)
6259fd9  reimplement it based on the latest change. (gatorsmile, Jan 18, 2016)
4de6ec1  Merge remote-tracking branch 'upstream/master' (gatorsmile, Jan 18, 2016)
f112026  Merge branch 'unionAllMCMergedNewNew' into unionAllMCMergedNewNewNew (gatorsmile, Jan 18, 2016)
4f71741  address comments. (gatorsmile, Jan 19, 2016)
9422a4f  Merge remote-tracking branch 'upstream/master' (gatorsmile, Jan 19, 2016)
59b5895  Merge branch 'unionAllMCMergedNewNewNew' into unionAllMCMergedNewNewN… (gatorsmile, Jan 19, 2016)
c63f237  address comments. (gatorsmile, Jan 19, 2016)
2e8562d  Merge remote-tracking branch 'upstream/master' into unionAllMCMergedN… (gatorsmile, Jan 20, 2016)
a571998  Merge remote-tracking branch 'upstream/master' into unionAllMCMergedN… (gatorsmile, Jan 20, 2016)
52bdf48  Merge remote-tracking branch 'upstream/master' (gatorsmile, Jan 20, 2016)
c18381e  Merge branch 'unionAllMCMergedNewNewNewNew' into unionAllMCMergedNewN… (gatorsmile, Jan 20, 2016)
@@ -402,8 +402,8 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
overwrite)
}

-    // If there are multiple INSERTS just UNION them together into on query.
-    val query = queries.reduceLeft(Union)
+    // If there are multiple INSERTS just UNION them together into one query.
+    val query = if (queries.length == 1) queries.head else Union(queries)
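
Note: for illustration (hypothetical plans q1, q2, q3), the old reduceLeft call built a left-deep tree of binary Unions, while the new code produces one flat n-ary operator:

    // Old: a left-deep tree of binary Unions
    Seq(q1, q2, q3).reduceLeft(Union(_, _))  // Union(Union(q1, q2), q3)

    // New: a single n-ary Union with all INSERT queries as siblings
    Union(Seq(q1, q2, q3))                   // Union(q1, q2, q3)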
Contributor:
Minor: you might actually just make Union smart enough to return the node itself when there is only one.

Contributor:
Oh, sorry, I assumed it was varargs. Maybe this is hard to do.

Member Author:
After the merge, I am planning to introduce a new rule in the Optimizer that detects whether any child of a Union can be pruned; if only one child remains (e.g., after pruning), we can remove the Union operator from the logical plan. Does this sound good to you? Thank you!
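
A rough sketch of the rule described in this comment (hypothetical; PruneUnionChildren and its pruning criterion are illustrative, not code from this PR):

    object PruneUnionChildren extends Rule[LogicalPlan] {
      def apply(plan: LogicalPlan): LogicalPlan = plan transform {
        case u @ Union(children) =>
          // Hypothetical criterion: drop children that are provably empty.
          val kept = children.filterNot {
            case LocalRelation(_, data) => data.isEmpty
            case _ => false
          }
          kept match {
            case Seq(onlyChild) => onlyChild            // a one-child Union is a no-op
            case _ if kept.size == children.size => u   // nothing was pruned
            case _ => Union(kept)                       // (an all-empty Union would need extra care)
          }
      }
    }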

Contributor:
Now we will eliminate a one-child Union during analysis, so it's OK to just return Union here.

Member Author:
This one is pretty tricky. If we add Union here, many test cases will fail even if we eliminate Union in the early stage of the analyzer. Let me show you my analysis:

In the following test case,

  test("INSERT OVERWRITE TABLE Parquet table") {
    withParquetTable((1 to 10).map(i => (i, s"val_$i")), "t") {
      withTempPath { file =>
        sql("SELECT * FROM t LIMIT 1").write.parquet(file.getCanonicalPath)
        hiveContext.read.parquet(file.getCanonicalPath).registerTempTable("p")
        withTempTable("p") {
          // let's do three overwrites for good measure
          sql("INSERT OVERWRITE TABLE p SELECT * FROM t")
          sql("INSERT OVERWRITE TABLE p SELECT * FROM t")
          sql("INSERT OVERWRITE TABLE p SELECT * FROM t")
          checkAnswer(sql("SELECT * FROM p"), sql("SELECT * FROM t").collect().toSeq)
        }
      }
    }
  }

If we always add Union in the parser as you suggested above, the test case will not execute the statement:
sql("INSERT OVERWRITE TABLE p SELECT * FROM t")

In DataFrame.scala, we have code that handles InsertIntoTable eagerly. :(

  @transient protected[sql] val logicalPlan: LogicalPlan = queryExecution.logical match {
    // For various commands (like DDL) and queries with side effects, we force query optimization to
    // happen right away to let these side effects take place eagerly.
    case _: Command |
         _: InsertIntoTable |
         _: CreateTableUsingAsSelect =>
      LogicalRDD(queryExecution.analyzed.output, queryExecution.toRdd)(sqlContext)
    case _ =>
      queryExecution.analyzed
  }

I assume we still want to keep this as a special case? Please let me know what you think. Thanks! @cloud-fan

Contributor:
Ah, makes sense; let's keep it.


// return With plan if there is CTE
cteRelations.map(With(query, _)).getOrElse(query)
@@ -66,7 +66,8 @@ class Analyzer(
lazy val batches: Seq[Batch] = Seq(
Batch("Substitution", fixedPoint,
CTESubstitution,
-        WindowsSubstitution),
+        WindowsSubstitution,
+        EliminateUnions),
Batch("Resolution", fixedPoint,
ResolveRelations ::
ResolveReferences ::
@@ -1170,6 +1171,15 @@ object EliminateSubQueries extends Rule[LogicalPlan] {
  }
}

+/**
+ * Removes [[Union]] operators from the plan if it has just one child.
+ */
+object EliminateUnions extends Rule[LogicalPlan] {
+  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+    case Union(children) if children.size == 1 => children.head
Contributor:
nit: can be case Union(child :: Nil) => child

Contributor:
That's actually only going to match when the Seq is explicitly a List.
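
A quick plain-Scala illustration of this point: the :: extractor only matches List cons cells, while a Seq(...) pattern matches any one-element Seq:

    def consPattern(xs: Seq[Int]): Boolean = xs match {
      case x :: Nil => true  // fires only when xs is actually a List
      case _        => false
    }

    def seqPattern(xs: Seq[Int]): Boolean = xs match {
      case Seq(x) => true    // fires for any single-element Seq
      case _      => false
    }

    consPattern(List(1))    // true
    consPattern(Vector(1))  // false: a Vector is not a cons cell
    seqPattern(Vector(1))   // true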

+  }
+}

/**
* Cleans up unnecessary Aliases inside the plan. Basically we only need Alias as a top level
* expression in Project(project list) or Aggregate(aggregate expressions) or
@@ -189,6 +189,14 @@ trait CheckAnalysis {
s"but the left table has ${left.output.length} columns and the right has " +
s"${right.output.length}")

+        case s: Union if s.children.exists(_.output.length != s.children.head.output.length) =>
+          val firstError = s.children.find(_.output.length != s.children.head.output.length).get
+          failAnalysis(
+            s"""
+              |Unions can only be performed on tables with the same number of columns,
+              | but one table has '${firstError.output.length}' columns and another table has
+              | '${s.children.head.output.length}' columns""".stripMargin)

case _ => // Fallbacks to the following checks
}

@@ -19,6 +19,9 @@ package org.apache.spark.sql.catalyst.analysis

import javax.annotation.Nullable

+import scala.annotation.tailrec
+import scala.collection.mutable

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.logical._
@@ -27,7 +30,7 @@ import org.apache.spark.sql.types._


/**
- * A collection of [[Rule Rules]] that can be used to coerce differing types that participate in
+ * A collection of [[Rule]] that can be used to coerce differing types that participate in
* operations into compatible ones.
*
* Most of these rules are based on Hive semantics, but they do not introduce any dependencies on
@@ -219,31 +222,59 @@ object HiveTypeCoercion {
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
case p if p.analyzed => p

-    case s @ SetOperation(left, right) if s.childrenResolved
-      && left.output.length == right.output.length && !s.resolved =>
+    case s @ SetOperation(left, right) if s.childrenResolved &&
+        left.output.length == right.output.length && !s.resolved =>
+      val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(left :: right :: Nil)
+      assert(newChildren.length == 2)
Member:
Is this necessary? It looks like it is impossible to have a different length here.

Member Author:
Yeah, it is impossible now, but that relies on the implementation of widenOutputTypes. If others change the code in the future, it might break the assumption and cause a bug on the next line.

+      s.makeCopy(Array(newChildren.head, newChildren.last))

-      // Tracks the list of data types to widen.
-      // Some(dataType) means the right-hand side and the left-hand side have different types,
-      // and there is a target type to widen both sides to.
-      val targetTypes: Seq[Option[DataType]] = left.output.zip(right.output).map {
-        case (lhs, rhs) if lhs.dataType != rhs.dataType =>
-          findWiderTypeForTwo(lhs.dataType, rhs.dataType)
-        case other => None
-      }
+    case s: Union if s.childrenResolved &&
+        s.children.forall(_.output.length == s.children.head.output.length) && !s.resolved =>
+      val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(s.children)
+      s.makeCopy(Array(newChildren))
}

-      if (targetTypes.exists(_.isDefined)) {
-        // There is at least one column to widen.
-        s.makeCopy(Array(widenTypes(left, targetTypes), widenTypes(right, targetTypes)))
-      } else {
-        // If we cannot find any column to widen, then just return the original set.
-        s
-      }
+    /** Build new children with the widest types for each attribute among all the children */
+    private def buildNewChildrenWithWiderTypes(children: Seq[LogicalPlan]): Seq[LogicalPlan] = {
+      require(children.forall(_.output.length == children.head.output.length))
+
+      // Get a sequence of data types, each of which is the widest type of this specific attribute
+      // in all the children
+      val targetTypes: Seq[DataType] =
+        getWidestTypes(children, attrIndex = 0, mutable.Queue[DataType]())
+
+      if (targetTypes.nonEmpty) {
+        // Add an extra Project if the targetTypes are different from the original types.
+        children.map(widenTypes(_, targetTypes))
+      } else {
+        // Unable to find a target type to widen, then just return the original set.
+        children
+      }
+    }

+    /** Get the widest type for each attribute in all the children */
+    @tailrec private def getWidestTypes(
+        children: Seq[LogicalPlan],
+        attrIndex: Int,
+        castedTypes: mutable.Queue[DataType]): Seq[DataType] = {
+      // Return the result after the widen data types have been found for all the children
+      if (attrIndex >= children.head.output.length) return castedTypes.toSeq
+
+      // For the attrIndex-th attribute, find the widest type
+      findWiderCommonType(children.map(_.output(attrIndex).dataType)) match {
+        // If unable to find an appropriate widen type for this column, return an empty Seq
+        case None => Seq.empty[DataType]
+        // Otherwise, record the result in the queue and find the type for the next column
+        case Some(widenType) =>
+          castedTypes.enqueue(widenType)
+          getWidestTypes(children, attrIndex + 1, castedTypes)
+      }
+    }

/** Given a plan, add an extra project on top to widen some columns' data types. */
-    private def widenTypes(plan: LogicalPlan, targetTypes: Seq[Option[DataType]]): LogicalPlan = {
+    private def widenTypes(plan: LogicalPlan, targetTypes: Seq[DataType]): LogicalPlan = {
val casted = plan.output.zip(targetTypes).map {
-        case (e, Some(dt)) if e.dataType != dt => Alias(Cast(e, dt), e.name)()
+        case (e, dt) if e.dataType != dt => Alias(Cast(e, dt), e.name)()
case (e, _) => e
}
Project(casted, plan)
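As a worked example of the widening above (hypothetical schemas, not from the PR): for two children t1 with schema (a: int) and t2 with schema (a: double), the helpers behave as follows:

    // int and double widen to double
    getWidestTypes(Seq(t1, t2), attrIndex = 0, mutable.Queue[DataType]())  // Seq(DoubleType)

    // t1 gains a Project that casts `a` to double ...
    widenTypes(t1, Seq(DoubleType))  // Project(Alias(Cast(a, DoubleType), "a")() :: Nil, t1)

    // ... while t2's attribute already matches and is passed through unchanged
    widenTypes(t2, Seq(DoubleType))  // Project(t2.output, t2)
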
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.analysis.{CleanupAliases, EliminateSubQueri
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.sql.catalyst.planning.ExtractFiltersAndInnerJoins
+import org.apache.spark.sql.catalyst.planning.{ExtractFiltersAndInnerJoins, Unions}
import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, LeftOuter, LeftSemi, RightOuter}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
@@ -45,6 +45,13 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] {
//////////////////////////////////////////////////////////////////////////////////////////
// Optimizer rules start here
//////////////////////////////////////////////////////////////////////////////////////////
+    // - Do the first call of CombineUnions before starting the major Optimizer rules,
+    //   since it can reduce the number of iterations and the other rules could add/move
+    //   extra operators between two adjacent Union operators.
+    // - Call CombineUnions again in Batch("Operator Optimizations"),
+    //   since the other rules might make two separate Union operators adjacent.
+    Batch("Union", Once,
+      CombineUnions) ::
Contributor:
Do you have a case where we must call CombineUnions at the very beginning?

Contributor:
I think this is done for performance.

Contributor:
(we should document it inline)

Member Author:
I will also mention the performance issue in the comment.

@cloud-fan #10451: that PR will push extra Limit operators through Union. I guess, if we do multiple UNION ALLs with one Limit in SQL, the plan could become something like:

Limit 
Union 
Limit 
Union 
Limit
...

Batch("Aggregate", FixedPoint(100),
ReplaceDistinctWithAggregate,
RemoveLiteralFromGroupExpressions) ::
@@ -62,6 +69,7 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] {
ProjectCollapsing,
CombineFilters,
CombineLimits,
+      CombineUnions,
// Constant folding and strength reduction
NullPropagation,
OptimizeIn,
@@ -138,11 +146,9 @@ object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper {
/**
* Maps Attributes from the left side to the corresponding Attribute on the right side.
*/
-  private def buildRewrites(bn: BinaryNode): AttributeMap[Attribute] = {
-    assert(bn.isInstanceOf[Union] || bn.isInstanceOf[Intersect] || bn.isInstanceOf[Except])
-    assert(bn.left.output.size == bn.right.output.size)
-
-    AttributeMap(bn.left.output.zip(bn.right.output))
+  private def buildRewrites(left: LogicalPlan, right: LogicalPlan): AttributeMap[Attribute] = {
+    assert(left.output.size == right.output.size)
+    AttributeMap(left.output.zip(right.output))
}

/**
@@ -176,32 +182,38 @@ object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper {
}

def apply(plan: LogicalPlan): LogicalPlan = plan transform {
-    // Push down filter into union
-    case Filter(condition, u @ Union(left, right)) =>
-      val (deterministic, nondeterministic) = partitionByDeterministic(condition)
-      val rewrites = buildRewrites(u)
-      Filter(nondeterministic,
-        Union(
-          Filter(deterministic, left),
-          Filter(pushToRight(deterministic, rewrites), right)
-        )
-      )

// Push down deterministic projection through UNION ALL
-    case p @ Project(projectList, u @ Union(left, right)) =>
+    case p @ Project(projectList, Union(children)) =>
+      assert(children.nonEmpty)
if (projectList.forall(_.deterministic)) {
-        val rewrites = buildRewrites(u)
-        Union(
-          Project(projectList, left),
-          Project(projectList.map(pushToRight(_, rewrites)), right))
+        val newFirstChild = Project(projectList, children.head)
+        val newOtherChildren = children.tail.map ( child => {
+          val rewrites = buildRewrites(children.head, child)
+          Project(projectList.map(pushToRight(_, rewrites)), child)
+        } )
+        Union(newFirstChild +: newOtherChildren)
} else {
p
}

+    // Push down filter into union
+    case Filter(condition, Union(children)) =>
+      assert(children.nonEmpty)
+      val (deterministic, nondeterministic) = partitionByDeterministic(condition)
+      val newFirstChild = Filter(deterministic, children.head)
+      val newOtherChildren = children.tail.map {
+        child => {
+          val rewrites = buildRewrites(children.head, child)
+          Filter(pushToRight(deterministic, rewrites), child)
+        }
+      }
+      Filter(nondeterministic, Union(newFirstChild +: newOtherChildren))

// Push down filter through INTERSECT
-    case Filter(condition, i @ Intersect(left, right)) =>
+    case Filter(condition, Intersect(left, right)) =>
val (deterministic, nondeterministic) = partitionByDeterministic(condition)
-      val rewrites = buildRewrites(i)
+      val rewrites = buildRewrites(left, right)
Filter(nondeterministic,
Intersect(
Filter(deterministic, left),
@@ -210,9 +222,9 @@ object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper {
)

// Push down filter through EXCEPT
-    case Filter(condition, e @ Except(left, right)) =>
+    case Filter(condition, Except(left, right)) =>
val (deterministic, nondeterministic) = partitionByDeterministic(condition)
-      val rewrites = buildRewrites(e)
+      val rewrites = buildRewrites(left, right)
Filter(nondeterministic,
Except(
Filter(deterministic, left),
@@ -662,6 +674,15 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper {
}
}

+/**
+ * Combines all adjacent [[Union]] operators into a single [[Union]].
+ */
+object CombineUnions extends Rule[LogicalPlan] {
+  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+    case Unions(children) => Union(children)
+  }
+}

/**
* Combines two adjacent [[Filter]] operators into one, merging the
* conditions into one conjunctive predicate.
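To illustrate the rewritten push-down above (hypothetical relations r1 and r2, both producing a column `a`): the deterministic part of a predicate is copied into every Union child, while the nondeterministic part stays above the Union so it is still evaluated once per output row:

    // Before:
    Filter(a > 1 && rand() < 0.5, Union(Seq(r1, r2)))

    // After SetOperationPushDown (attributes rewritten per child via buildRewrites):
    Filter(rand() < 0.5,
      Union(Seq(
        Filter(a > 1, r1),
        Filter(a > 1, r2))))
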
@@ -17,6 +17,9 @@

package org.apache.spark.sql.catalyst.planning

+import scala.annotation.tailrec
+import scala.collection.mutable

import org.apache.spark.Logging
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
@@ -170,17 +173,29 @@ object ExtractFiltersAndInnerJoins extends PredicateHelper {
}
}


Contributor:
I'm not sure I would get rid of this, just use it in your optimization rule.

Member Author:
Sure, I will reimplement it that way.

/**
* A pattern that collects all adjacent unions and returns their children as a Seq.
*/
object Unions {
def unapply(plan: LogicalPlan): Option[Seq[LogicalPlan]] = plan match {
-    case u: Union => Some(collectUnionChildren(u))
+    case u: Union => Some(collectUnionChildren(mutable.Stack(u), Seq.empty[LogicalPlan]))
case _ => None
}

-  private def collectUnionChildren(plan: LogicalPlan): Seq[LogicalPlan] = plan match {
-    case Union(l, r) => collectUnionChildren(l) ++ collectUnionChildren(r)
-    case other => other :: Nil
+  // Doing a depth-first tree traversal to combine all the union children.
+  @tailrec
+  private def collectUnionChildren(
+      plans: mutable.Stack[LogicalPlan],
+      children: Seq[LogicalPlan]): Seq[LogicalPlan] = {
+    if (plans.isEmpty) children
+    else {
+      plans.pop match {
+        case Union(grandchildren) =>
+          grandchildren.reverseMap(plans.push(_))
Contributor:
Instead of reversing here, why not just use a Queue?

Member Author:
Using a stack gives a depth-first tree traversal. For example, users might expect the order of the union children in the following two cases to be the same. Or maybe they do not care?

    val unionQuery1 = Union(Union(testRelation, testRelation2), testRelation)
    val unionQuery2 = Union(testRelation, Union(testRelation2, testRelation))

Contributor:
Order does matter; you are right.

+          collectUnionChildren(plans, children)
+        case other => collectUnionChildren(plans, children :+ other)
+      }
+    }
+  }
}
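
A small usage sketch (hypothetical relations r1, r2, r3): the extractor flattens arbitrarily nested Unions into one child sequence, preserving left-to-right order, which is what CombineUnions relies on:

    val nested = Union(Union(r1, r2) :: r3 :: Nil)

    nested match {
      case Unions(children) => children  // Seq(r1, r2, r3)
    }

    // CombineUnions then rebuilds a single n-ary node: Union(Seq(r1, r2, r3))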
@@ -101,19 +101,6 @@ private[sql] object SetOperation {
def unapply(p: SetOperation): Option[(LogicalPlan, LogicalPlan)] = Some((p.left, p.right))
}

-case class Union(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) {
-
-  override def output: Seq[Attribute] =
-    left.output.zip(right.output).map { case (leftAttr, rightAttr) =>
-      leftAttr.withNullability(leftAttr.nullable || rightAttr.nullable)
-    }
-
-  override def statistics: Statistics = {
-    val sizeInBytes = left.statistics.sizeInBytes + right.statistics.sizeInBytes
-    Statistics(sizeInBytes = sizeInBytes)
-  }
-}

case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) {

override def output: Seq[Attribute] =
@@ -127,6 +114,40 @@ case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(le
override def output: Seq[Attribute] = left.output
}

+/** Factory for constructing new `Union` nodes. */
+object Union {
+  def apply(left: LogicalPlan, right: LogicalPlan): Union = {
+    Union(left :: right :: Nil)
+  }
+}
+
+case class Union(children: Seq[LogicalPlan]) extends LogicalPlan {
+
+  // updating nullability to make all the children consistent
+  override def output: Seq[Attribute] =
+    children.map(_.output).transpose.map(attrs =>
+      attrs.head.withNullability(attrs.exists(_.nullable)))
+
+  override lazy val resolved: Boolean = {
+    // allChildrenCompatible needs to be evaluated after childrenResolved
+    def allChildrenCompatible: Boolean =
+      children.tail.forall( child =>
+        // compare the attribute number with the first child
+        child.output.length == children.head.output.length &&
+        // compare the data types with the first child
+        child.output.zip(children.head.output).forall {
+          case (l, r) => l.dataType == r.dataType }
+      )
+
+    children.length > 1 && childrenResolved && allChildrenCompatible
+  }
+
+  override def statistics: Statistics = {
+    val sizeInBytes = children.map(_.statistics.sizeInBytes).sum
+    Statistics(sizeInBytes = sizeInBytes)
+  }
+}

case class Join(
left: LogicalPlan,
right: LogicalPlan,
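A minimal, self-contained sketch of the transpose trick used in Union.output above, with booleans standing in for attribute nullability:

    // Each inner Seq is one child's per-column nullability.
    val childNullability: Seq[Seq[Boolean]] = Seq(
      Seq(false, true),   // child 1: (a NOT NULL, b nullable)
      Seq(true,  false))  // child 2: (a nullable,  b NOT NULL)

    // transpose groups the i-th column of every child; a Union output
    // column is nullable if it is nullable in any child.
    val merged = childNullability.transpose.map(_.exists(identity))
    // merged == Seq(true, true)
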
@@ -237,6 +237,12 @@ class AnalysisSuite extends AnalysisTest {
checkAnalysis(plan, expected)
}

test("Eliminate the unnecessary union") {
val plan = Union(testRelation :: Nil)
val expected = testRelation
checkAnalysis(plan, expected)
}

test("SPARK-12102: Ignore nullablity when comparing two sides of case") {
val relation = LocalRelation('a.struct('x.int), 'b.struct('x.int.withNullability(false)))
val plan = relation.select(CaseWhen(Seq((Literal(true), 'a.attr)), 'b).as("val"))