Skip to content

Commit

Permalink
remove the old operator union #2.
Browse files Browse the repository at this point in the history
  • Loading branch information
gatorsmile committed Jan 5, 2016
1 parent 5d031a7 commit c1f66f7
Show file tree
Hide file tree
Showing 10 changed files with 106 additions and 46 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,16 @@ trait CheckAnalysis {
s"but the left table has ${left.output.length} columns and the right has " +
s"${right.output.length}")

case s: Unions if s.children.exists(_.output.length != s.children.head.output.length) =>
s.children.filter(_.output.length != s.children.head.output.length).foreach { child =>
failAnalysis(
s"""
|Unions can only be performed on tables with the same number of columns,
| but the table '${child.simpleString}' has '${child.output.length}' columns
| and the first table '${s.children.head.simpleString}' has
| '${s.children.head.output.length}' columns""".stripMargin)
}

case _ => // Fallbacks to the following checks
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ object HiveTypeCoercion {
* - LongType to DoubleType
* - DecimalType to Double
*
* This rule is only applied to Union/Except/Intersect
* This rule is only applied to Unions/Except/Intersect
*/
object WidenSetOperationTypes extends Rule[LogicalPlan] {

Expand All @@ -212,29 +212,59 @@ object HiveTypeCoercion {
case other => None
}

def castOutput(plan: LogicalPlan): LogicalPlan = {
val casted = plan.output.zip(castedTypes).map {
case (e, Some(dt)) if e.dataType != dt =>
Alias(Cast(e, dt), e.name)()
case (e, _) => e
}
Project(casted, plan)
if (castedTypes.exists(_.isDefined)) {
(castOutput(left, castedTypes), castOutput(right, castedTypes))
} else {
(left, right)
}
}

private[this] def widenOutputTypes(
planName: String,
children: Seq[LogicalPlan]): Seq[LogicalPlan] = {
require(children.forall(_.output.length == children.head.output.length))

val castedTypes: Seq[Option[DataType]] =
children.tail.foldLeft(children.head.output.map(a => Option(a.dataType))) {
case (currentOutputDataTypes, child) => {
currentOutputDataTypes.zip(child.output).map {
case (Some(dt), a2) if dt != a2.dataType =>
findWiderTypeForTwo(dt, a2.dataType)
case other => None
}
}
}

if (castedTypes.exists(_.isDefined)) {
(castOutput(left), castOutput(right))
children.map(castOutput(_, castedTypes))
} else {
(left, right)
children
}
}

private[this] def castOutput(
plan: LogicalPlan,
castedTypes: Seq[Option[DataType]]): LogicalPlan = {
val casted = plan.output.zip(castedTypes).map {
case (e, Some(dt)) if e.dataType != dt =>
Alias(Cast(e, dt), e.name)()
case (e, _) => e
}
Project(casted, plan)
}

def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
case p if p.analyzed => p

case s @ SetOperation(left, right) if s.childrenResolved
&& left.output.length == right.output.length && !s.resolved =>
val (newLeft, newRight) = widenOutputTypes(s.nodeName, left, right)
s.makeCopy(Array(newLeft, newRight))

case s: Unions if s.childrenResolved &&
s.children.forall(_.output.length == s.children.head.output.length) && !s.resolved =>
val newChildren: Seq[LogicalPlan] = widenOutputTypes(s.nodeName, s.children)
s.makeCopy(Array(newChildren))
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper {
* Maps Attributes from the left side to the corresponding Attribute on the right side.
*/
private def buildRewrites(bn: BinaryNode): AttributeMap[Attribute] = {
(bn.isInstanceOf[Intersect] || bn.isInstanceOf[Except])
assert(bn.isInstanceOf[Intersect] || bn.isInstanceOf[Except])
assert(bn.left.output.size == bn.right.output.size)

AttributeMap(bn.left.output.zip(bn.right.output))
Expand Down Expand Up @@ -580,13 +580,15 @@ object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper {
* Combines all adjacent [[Unions]] operators into a single [[Unions]].
*/
object CombineUnions extends Rule[LogicalPlan] {
private def collectUnionChildren(plan: LogicalPlan): Seq[LogicalPlan] = plan match {
case Unions(children) => children.flatMap(collectUnionChildren)
case other => other :: Nil
private def buildUnionChildren(children: Seq[LogicalPlan]): Seq[LogicalPlan] =
children.foldLeft(Seq.empty[LogicalPlan]) { (newChildren, child) => child match {
case Unions(grandchildren) => newChildren ++ grandchildren
case other => newChildren ++ Seq(other)
}
}

def apply(plan: LogicalPlan): LogicalPlan = plan transformDown {
case u: Unions => Unions(collectUnionChildren(u))
def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
case u @ Unions(children) => Unions(buildUnionChildren(children))
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,8 @@ case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(le
override def output: Seq[Attribute] = left.output
}

/** Factory for constructing new `AppendColumn` nodes. */
object Unions {
/** Factory for constructing new `Unions` nodes. */
object Union {
def apply(left: LogicalPlan, right: LogicalPlan): Unions = {
Unions (left :: right :: Nil)
}
Expand All @@ -131,6 +131,12 @@ case class Unions(children: Seq[LogicalPlan]) extends LogicalPlan {
}
}

override lazy val resolved: Boolean =
childrenResolved &&
children.forall(_.output.length == children.head.output.length) &&
children.forall(_.output.zip(children.head.output).forall {
case (l, r) => l.dataType == r.dataType })

override def statistics: Statistics = {
val sizeInBytes = children.map(_.statistics.sizeInBytes).sum
Statistics(sizeInBytes = sizeInBytes)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ import org.scalatest.BeforeAndAfter
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.logical.{Union, Project, LocalRelation}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project, Union, Unions}
import org.apache.spark.sql.catalyst.{SimpleCatalystConf, TableIdentifier}
import org.apache.spark.sql.types._
import org.apache.spark.sql.catalyst.{TableIdentifier, SimpleCatalystConf}

class DecimalPrecisionSuite extends SparkFunSuite with BeforeAndAfter {
val conf = new SimpleCatalystConf(true)
Expand Down Expand Up @@ -70,7 +70,7 @@ class DecimalPrecisionSuite extends SparkFunSuite with BeforeAndAfter {
Union(Project(Seq(Alias(left, "l")()), relation),
Project(Seq(Alias(right, "r")()), relation))
val (l, r) = analyzer.execute(plan).collect {
case Union(left, right) => (left.output.head, right.output.head)
case Unions(Seq(child1, child2)) => (child1.output.head, child2.output.head)
}.head
assert(l.dataType === expectedType)
assert(r.dataType === expectedType)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -384,11 +384,11 @@ class HiveTypeCoercionSuite extends PlanTest {
val wt = HiveTypeCoercion.WidenSetOperationTypes
val expectedTypes = Seq(StringType, DecimalType.SYSTEM_DEFAULT, FloatType, DoubleType)

val r1 = wt(Union(left, right)).asInstanceOf[Union]
val r1 = wt(Union(left, right)).asInstanceOf[Unions]
val r2 = wt(Except(left, right)).asInstanceOf[Except]
val r3 = wt(Intersect(left, right)).asInstanceOf[Intersect]
checkOutput(r1.left, expectedTypes)
checkOutput(r1.right, expectedTypes)
checkOutput(r1.children.head, expectedTypes)
checkOutput(r1.children.last, expectedTypes)
checkOutput(r2.left, expectedTypes)
checkOutput(r2.right, expectedTypes)
checkOutput(r3.left, expectedTypes)
Expand All @@ -410,12 +410,12 @@ class HiveTypeCoercionSuite extends PlanTest {
AttributeReference("r", DecimalType(5, 5))())
val expectedType1 = Seq(DecimalType(10, 8))

val r1 = dp(Union(left1, right1)).asInstanceOf[Union]
val r1 = dp(Union(left1, right1)).asInstanceOf[Unions]
val r2 = dp(Except(left1, right1)).asInstanceOf[Except]
val r3 = dp(Intersect(left1, right1)).asInstanceOf[Intersect]

checkOutput(r1.left, expectedType1)
checkOutput(r1.right, expectedType1)
checkOutput(r1.children.head, expectedType1)
checkOutput(r1.children.last, expectedType1)
checkOutput(r2.left, expectedType1)
checkOutput(r2.right, expectedType1)
checkOutput(r3.left, expectedType1)
Expand All @@ -427,23 +427,23 @@ class HiveTypeCoercionSuite extends PlanTest {
val expectedTypes = Seq(DecimalType(10, 5), DecimalType(10, 5), DecimalType(15, 5),
DecimalType(25, 5), DoubleType, DoubleType)

rightTypes.zip(expectedTypes).map { case (rType, expectedType) =>
rightTypes.zip(expectedTypes).foreach { case (rType, expectedType) =>
val plan2 = LocalRelation(
AttributeReference("r", rType)())

val r1 = dp(Union(plan1, plan2)).asInstanceOf[Union]
val r1 = dp(Union(plan1, plan2)).asInstanceOf[Unions]
val r2 = dp(Except(plan1, plan2)).asInstanceOf[Except]
val r3 = dp(Intersect(plan1, plan2)).asInstanceOf[Intersect]

checkOutput(r1.right, Seq(expectedType))
checkOutput(r1.children.last, Seq(expectedType))
checkOutput(r2.right, Seq(expectedType))
checkOutput(r3.right, Seq(expectedType))

val r4 = dp(Union(plan2, plan1)).asInstanceOf[Union]
val r4 = dp(Union(plan2, plan1)).asInstanceOf[Unions]
val r5 = dp(Except(plan2, plan1)).asInstanceOf[Except]
val r6 = dp(Intersect(plan2, plan1)).asInstanceOf[Intersect]

checkOutput(r4.left, Seq(expectedType))
checkOutput(r4.children.last, Seq(expectedType))
checkOutput(r5.left, Seq(expectedType))
checkOutput(r6.left, Seq(expectedType))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,46 +38,51 @@ class SetOperationPushDownSuite extends PlanTest {

val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
val testRelation2 = LocalRelation('d.int, 'e.int, 'f.int)
val testUnion = Unions(testRelation, testRelation2)
val testUnion = Union(testRelation, testRelation2)
val testIntersect = Intersect(testRelation, testRelation2)
val testExcept = Except(testRelation, testRelation2)

test("union: combine unions into one unions") {
val unionQuery1 = Unions(Unions(testRelation, testRelation2), testRelation)
val unionQuery2 = Unions(testRelation, Unions(testRelation2, testRelation))
val unionQuery1 = Union(Union(testRelation, testRelation2), testRelation)
val unionQuery2 = Union(testRelation, Union(testRelation2, testRelation))
val unionOptimized1 = Optimize.execute(unionQuery1.analyze)
val unionOptimized2 = Optimize.execute(unionQuery2.analyze)

comparePlans(unionOptimized1, unionOptimized2)

val combinedUnions = Unions(unionOptimized1 :: unionOptimized2 :: Nil)
val combinedUnionsOptimized = Optimize.execute(combinedUnions.analyze)
val unionQuery3 = Unions(unionQuery1, unionQuery2)
val unionQuery3 = Union(unionQuery1, unionQuery2)
val unionOptimized3 = Optimize.execute(unionQuery3.analyze)
comparePlans(combinedUnionsOptimized, unionOptimized3)
}

test("union/intersect/except: filter to each side") {
val unionQuery = testUnion.where('a === 1)
test("intersect/except: filter to each side") {
val intersectQuery = testIntersect.where('b < 10)
val exceptQuery = testExcept.where('c >= 5)

val unionOptimized = Optimize.execute(unionQuery.analyze)
val intersectOptimized = Optimize.execute(intersectQuery.analyze)
val exceptOptimized = Optimize.execute(exceptQuery.analyze)

val unionCorrectAnswer =
Unions(testRelation.where('a === 1) :: testRelation2.where('d === 1) :: Nil).analyze
val intersectCorrectAnswer =
Intersect(testRelation.where('b < 10), testRelation2.where('e < 10)).analyze
val exceptCorrectAnswer =
Except(testRelation.where('c >= 5), testRelation2.where('f >= 5)).analyze

comparePlans(unionOptimized, unionCorrectAnswer)
comparePlans(intersectOptimized, intersectCorrectAnswer)
comparePlans(exceptOptimized, exceptCorrectAnswer)
}

test("union: project to each side") {
ignore("union: filter to each side") {
val unionQuery = testUnion.where('a === 1)
val unionOptimized = Optimize.execute(unionQuery.analyze)
val unionCorrectAnswer =
Unions(testRelation.where('a === 1) :: testRelation2.where('d === 1) :: Nil).analyze

comparePlans(unionOptimized, unionCorrectAnswer)
}

ignore("union: project to each side") {
val unionQuery = testUnion.select('a)
val unionOptimized = Optimize.execute(unionQuery.analyze)
val unionCorrectAnswer =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1005,7 +1005,7 @@ class DataFrame private[sql](
* @since 1.3.0
*/
def unionAll(other: DataFrame): DataFrame = withPlan {
Unions(logicalPlan, other.logicalPlan)
Union(logicalPlan, other.logicalPlan)
}

/**
Expand Down
4 changes: 3 additions & 1 deletion sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
Original file line number Diff line number Diff line change
Expand Up @@ -608,7 +608,9 @@ class Dataset[T] private[sql](
* duplicate items. As such, it is analogous to `UNION ALL` in SQL.
* @since 1.6.0
*/
def union(other: Dataset[T]): Dataset[T] = withPlan[T](other)(Union)
def union(other: Dataset[T]): Dataset[T] = withPlan[T](other){ (left, right) =>
Unions(left :: right :: Nil)
}

/**
* Returns a new [[Dataset]] where any elements present in `other` have been removed.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1211,7 +1211,12 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
}

// If there are multiple INSERTS just UNION them together into on query.
val query = queries.reduceLeft(Union)
val query =
if (queries.length == 1) {
queries.head
} else {
Unions(queries)
}

// return With plan if there is CTE
cteRelations.map(With(query, _)).getOrElse(query)
Expand Down

0 comments on commit c1f66f7

Please sign in to comment.