Skip to content

Commit

Permalink
[SPARK-21274][SQL] Implement INTERSECT ALL clause
Browse files Browse the repository at this point in the history
## What changes were proposed in this pull request?
Implements INTERSECT ALL clause through query rewrites using existing operators in Spark.  Please refer to [Link](https://drive.google.com/open?id=1nyW0T0b_ajUduQoPgZLAsyHK8s3_dko3ulQuxaLpUXE) for the design.

Input Query
``` SQL
SELECT c1 FROM ut1 INTERSECT ALL SELECT c1 FROM ut2
```
Rewritten Query
```SQL
   SELECT c1
    FROM (
         SELECT replicate_row(min_count, c1)
         FROM (
              SELECT c1,
                     IF (vcol1_cnt > vcol2_cnt, vcol2_cnt, vcol1_cnt) AS min_count
              FROM (
                   SELECT   c1, count(vcol1) as vcol1_cnt, count(vcol2) as vcol2_cnt
                   FROM (
                        SELECT c1, true as vcol1, null as vcol2 FROM ut1
                        UNION ALL
                        SELECT c1, null as vcol1, true as vcol2 FROM ut2
                        ) AS union_all
                   GROUP BY c1
                   HAVING vcol1_cnt >= 1 AND vcol2_cnt >= 1
                  )
              )
          )
```

## How was this patch tested?
Added test cases in SQLQueryTestSuite, DataFrameSuite, SetOperationSuite

Author: Dilip Biswal <[email protected]>

Closes #21886 from dilipbiswal/dkb_intersect_all_final.
  • Loading branch information
dilipbiswal authored and gatorsmile committed Jul 30, 2018
1 parent 6690924 commit 65a4bc1
Show file tree
Hide file tree
Showing 15 changed files with 599 additions and 12 deletions.
22 changes: 22 additions & 0 deletions python/pyspark/sql/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1500,6 +1500,28 @@ def intersect(self, other):
"""
return DataFrame(self._jdf.intersect(other._jdf), self.sql_ctx)

@since(2.4)
def intersectAll(self, other):
""" Return a new :class:`DataFrame` containing rows in both this dataframe and other
dataframe while preserving duplicates.
This is equivalent to `INTERSECT ALL` in SQL.
>>> df1 = spark.createDataFrame([("a", 1), ("a", 1), ("b", 3), ("c", 4)], ["C1", "C2"])
>>> df2 = spark.createDataFrame([("a", 1), ("a", 1), ("b", 3)], ["C1", "C2"])
>>> df1.intersectAll(df2).sort("C1", "C2").show()
+---+---+
| C1| C2|
+---+---+
| a| 1|
| a| 1|
| b| 3|
+---+---+
Also as standard in SQL, this function resolves columns by position (not by name).
"""
return DataFrame(self._jdf.intersectAll(other._jdf), self.sql_ctx)

@since(1.3)
def subtract(self, other):
""" Return a new :class:`DataFrame` containing rows in this frame
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -914,7 +914,7 @@ class Analyzer(
// To resolve duplicate expression IDs for Join and Intersect
case j @ Join(left, right, _, _) if !j.duplicateResolved =>
j.copy(right = dedupRight(left, right))
case i @ Intersect(left, right) if !i.duplicateResolved =>
case i @ Intersect(left, right, _) if !i.duplicateResolved =>
i.copy(right = dedupRight(left, right))
case e @ Except(left, right, _) if !e.duplicateResolved =>
e.copy(right = dedupRight(left, right))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -325,11 +325,11 @@ object TypeCoercion {
assert(newChildren.length == 2)
Except(newChildren.head, newChildren.last, isAll)

case s @ Intersect(left, right) if s.childrenResolved &&
case s @ Intersect(left, right, isAll) if s.childrenResolved &&
left.output.length == right.output.length && !s.resolved =>
val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(left :: right :: Nil)
assert(newChildren.length == 2)
Intersect(newChildren.head, newChildren.last)
Intersect(newChildren.head, newChildren.last, isAll)

case s: Union if s.childrenResolved &&
s.children.forall(_.output.length == s.children.head.output.length) && !s.resolved =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ object UnsupportedOperationChecker {
case Except(left, right, _) if right.isStreaming =>
throwError("Except on a streaming DataFrame/Dataset on the right is not supported")

case Intersect(left, right) if left.isStreaming && right.isStreaming =>
case Intersect(left, right, _) if left.isStreaming && right.isStreaming =>
throwError("Intersect between two streaming DataFrames/Datasets is not supported")

case GroupingSets(_, _, child, _) if child.isStreaming =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog)
OptimizeSubqueries) ::
Batch("Replace Operators", fixedPoint,
RewriteExcepAll,
RewriteIntersectAll,
ReplaceIntersectWithSemiJoin,
ReplaceExceptWithFilter,
ReplaceExceptWithAntiJoin,
Expand Down Expand Up @@ -1402,7 +1403,7 @@ object ReplaceDeduplicateWithAggregate extends Rule[LogicalPlan] {
*/
object ReplaceIntersectWithSemiJoin extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case Intersect(left, right) =>
case Intersect(left, right, false) =>
assert(left.output.size == right.output.size)
val joinCond = left.output.zip(right.output).map { case (l, r) => EqualNullSafe(l, r) }
Distinct(Join(left, right, LeftSemi, joinCond.reduceLeftOption(And)))
Expand Down Expand Up @@ -1488,6 +1489,84 @@ object RewriteExcepAll extends Rule[LogicalPlan] {
}
}

/**
* Replaces logical [[Intersect]] operator using a combination of Union, Aggregate
* and Generate operator.
*
* Input Query :
* {{{
* SELECT c1 FROM ut1 INTERSECT ALL SELECT c1 FROM ut2
* }}}
*
* Rewritten Query:
* {{{
* SELECT c1
* FROM (
* SELECT replicate_row(min_count, c1)
* FROM (
* SELECT c1, If (vcol1_cnt > vcol2_cnt, vcol2_cnt, vcol1_cnt) AS min_count
* FROM (
* SELECT c1, count(vcol1) as vcol1_cnt, count(vcol2) as vcol2_cnt
* FROM (
* SELECT true as vcol1, null as , c1 FROM ut1
* UNION ALL
* SELECT null as vcol1, true as vcol2, c1 FROM ut2
* ) AS union_all
* GROUP BY c1
* HAVING vcol1_cnt >= 1 AND vcol2_cnt >= 1
* )
* )
* )
* }}}
*/
object RewriteIntersectAll extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case Intersect(left, right, true) =>
assert(left.output.size == right.output.size)

val trueVcol1 = Alias(Literal(true), "vcol1")()
val nullVcol1 = Alias(Literal(null, BooleanType), "vcol1")()

val trueVcol2 = Alias(Literal(true), "vcol2")()
val nullVcol2 = Alias(Literal(null, BooleanType), "vcol2")()

// Add a projection on the top of left and right plans to project out
// the additional virtual columns.
val leftPlanWithAddedVirtualCols = Project(Seq(trueVcol1, nullVcol2) ++ left.output, left)
val rightPlanWithAddedVirtualCols = Project(Seq(nullVcol1, trueVcol2) ++ right.output, right)

val unionPlan = Union(leftPlanWithAddedVirtualCols, rightPlanWithAddedVirtualCols)

// Expressions to compute count and minimum of both the counts.
val vCol1AggrExpr =
Alias(AggregateExpression(Count(unionPlan.output(0)), Complete, false), "vcol1_count")()
val vCol2AggrExpr =
Alias(AggregateExpression(Count(unionPlan.output(1)), Complete, false), "vcol2_count")()
val ifExpression = Alias(If(
GreaterThan(vCol1AggrExpr.toAttribute, vCol2AggrExpr.toAttribute),
vCol2AggrExpr.toAttribute,
vCol1AggrExpr.toAttribute
), "min_count")()

val aggregatePlan = Aggregate(left.output,
Seq(vCol1AggrExpr, vCol2AggrExpr) ++ left.output, unionPlan)
val filterPlan = Filter(And(GreaterThanOrEqual(vCol1AggrExpr.toAttribute, Literal(1L)),
GreaterThanOrEqual(vCol2AggrExpr.toAttribute, Literal(1L))), aggregatePlan)
val projectMinPlan = Project(left.output ++ Seq(ifExpression), filterPlan)

// Apply the replicator to replicate rows based on min_count
val genRowPlan = Generate(
ReplicateRows(Seq(ifExpression.toAttribute) ++ left.output),
unrequiredChildIndex = Nil,
outer = false,
qualifier = None,
left.output,
projectMinPlan
)
Project(left.output, genRowPlan)
}
}

/**
* Removes literals from group expressions in [[Aggregate]], as they have no effect to the result
* but only makes the grouping key bigger.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -533,7 +533,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
case SqlBaseParser.UNION =>
Distinct(Union(left, right))
case SqlBaseParser.INTERSECT if all =>
throw new ParseException("INTERSECT ALL is not supported.", ctx)
Intersect(left, right, isAll = true)
case SqlBaseParser.INTERSECT =>
Intersect(left, right)
case SqlBaseParser.EXCEPT if all =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,12 @@ object SetOperation {
def unapply(p: SetOperation): Option[(LogicalPlan, LogicalPlan)] = Some((p.left, p.right))
}

case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) {
case class Intersect(
left: LogicalPlan,
right: LogicalPlan,
isAll: Boolean = false) extends SetOperation(left, right) {

override def nodeName: String = getClass.getSimpleName + ( if ( isAll ) "All" else "" )

override def output: Seq[Attribute] =
left.output.zip(right.output).map { case (leftAttr, rightAttr) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,11 @@ package org.apache.spark.sql.catalyst.optimizer
import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{Alias, GreaterThan, Literal, ReplicateRows}
import org.apache.spark.sql.catalyst.expressions.{Alias, And, GreaterThan, GreaterThanOrEqual, If, Literal, ReplicateRows}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.types.BooleanType

class SetOperationSuite extends PlanTest {
object Optimize extends RuleExecutor[LogicalPlan] {
Expand Down Expand Up @@ -166,4 +167,33 @@ class SetOperationSuite extends PlanTest {
))
comparePlans(expectedPlan, rewrittenPlan)
}

test("INTERSECT ALL rewrite") {
val input = Intersect(testRelation, testRelation2, isAll = true)
val rewrittenPlan = RewriteIntersectAll(input)
val leftRelation = testRelation
.select(Literal(true).as("vcol1"), Literal(null, BooleanType).as("vcol2"), 'a, 'b, 'c)
val rightRelation = testRelation2
.select(Literal(null, BooleanType).as("vcol1"), Literal(true).as("vcol2"), 'd, 'e, 'f)
val planFragment = leftRelation.union(rightRelation)
.groupBy('a, 'b, 'c)(count('vcol1).as("vcol1_count"),
count('vcol2).as("vcol2_count"), 'a, 'b, 'c)
.where(And(GreaterThanOrEqual('vcol1_count, Literal(1L)),
GreaterThanOrEqual('vcol2_count, Literal(1L))))
.select('a, 'b, 'c,
If(GreaterThan('vcol1_count, 'vcol2_count), 'vcol2_count, 'vcol1_count).as("min_count"))
.analyze
val multiplerAttr = planFragment.output.last
val output = planFragment.output.dropRight(1)
val expectedPlan = Project(output,
Generate(
ReplicateRows(Seq(multiplerAttr) ++ output),
Nil,
false,
None,
output,
planFragment
))
comparePlans(expectedPlan, rewrittenPlan)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ class PlanParserSuite extends AnalysisTest {
intercept("select * from a minus all select * from b", "MINUS ALL is not supported.")
assertEqual("select * from a minus distinct select * from b", a.except(b))
assertEqual("select * from a intersect select * from b", a.intersect(b))
intercept("select * from a intersect all select * from b", "INTERSECT ALL is not supported.")
assertEqual("select * from a intersect distinct select * from b", a.intersect(b))
}

Expand Down
19 changes: 18 additions & 1 deletion sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1934,6 +1934,23 @@ class Dataset[T] private[sql](
Intersect(planWithBarrier, other.planWithBarrier)
}

/**
* Returns a new Dataset containing rows only in both this Dataset and another Dataset while
* preserving the duplicates.
* This is equivalent to `INTERSECT ALL` in SQL.
*
* @note Equality checking is performed directly on the encoded representation of the data
* and thus is not affected by a custom `equals` function defined on `T`. Also as standard
* in SQL, this function resolves columns by position (not by name).
*
* @group typedrel
* @since 2.4.0
*/
def intersectAll(other: Dataset[T]): Dataset[T] = withSetOperator {
Intersect(logicalPlan, other.logicalPlan, isAll = true)
}


/**
* Returns a new Dataset containing rows in this Dataset but not in another Dataset.
* This is equivalent to `EXCEPT DISTINCT` in SQL.
Expand Down Expand Up @@ -1961,7 +1978,7 @@ class Dataset[T] private[sql](
* @since 2.4.0
*/
def exceptAll(other: Dataset[T]): Dataset[T] = withSetOperator {
Except(planWithBarrier, other.planWithBarrier, isAll = true)
Except(logicalPlan, other.logicalPlan, isAll = true)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -529,9 +529,13 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case logical.Distinct(child) =>
throw new IllegalStateException(
"logical distinct operator should have been replaced by aggregate in the optimizer")
case logical.Intersect(left, right) =>
case logical.Intersect(left, right, false) =>
throw new IllegalStateException(
"logical intersect operator should have been replaced by semi-join in the optimizer")
"logical intersect operator should have been replaced by semi-join in the optimizer")
case logical.Intersect(left, right, true) =>
throw new IllegalStateException(
"logical intersect operator should have been replaced by union, aggregate" +
"and generate operators in the optimizer")
case logical.Except(left, right, false) =>
throw new IllegalStateException(
"logical except operator should have been replaced by anti-join in the optimizer")
Expand Down
Loading

0 comments on commit 65a4bc1

Please sign in to comment.