diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index b2e0a5b2390c2..07fb260a77ea0 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1500,6 +1500,28 @@ def intersect(self, other): """ return DataFrame(self._jdf.intersect(other._jdf), self.sql_ctx) + @since(2.4) + def intersectAll(self, other): + """ Return a new :class:`DataFrame` containing rows in both this dataframe and other + dataframe while preserving duplicates. + + This is equivalent to `INTERSECT ALL` in SQL. + >>> df1 = spark.createDataFrame([("a", 1), ("a", 1), ("b", 3), ("c", 4)], ["C1", "C2"]) + >>> df2 = spark.createDataFrame([("a", 1), ("a", 1), ("b", 3)], ["C1", "C2"]) + + >>> df1.intersectAll(df2).sort("C1", "C2").show() + +---+---+ + | C1| C2| + +---+---+ + | a| 1| + | a| 1| + | b| 3| + +---+---+ + + Also as standard in SQL, this function resolves columns by position (not by name). + """ + return DataFrame(self._jdf.intersectAll(other._jdf), self.sql_ctx) + @since(1.3) def subtract(self, other): """ Return a new :class:`DataFrame` containing rows in this frame diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 8abb1c70d4919..9965cd654bcb5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -914,7 +914,7 @@ class Analyzer( // To resolve duplicate expression IDs for Join and Intersect case j @ Join(left, right, _, _) if !j.duplicateResolved => j.copy(right = dedupRight(left, right)) - case i @ Intersect(left, right) if !i.duplicateResolved => + case i @ Intersect(left, right, _) if !i.duplicateResolved => i.copy(right = dedupRight(left, right)) case e @ Except(left, right, _) if !e.duplicateResolved => e.copy(right = dedupRight(left, right)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index f9edca53d571e..7dd26b62b1fc4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -325,11 +325,11 @@ object TypeCoercion { assert(newChildren.length == 2) Except(newChildren.head, newChildren.last, isAll) - case s @ Intersect(left, right) if s.childrenResolved && + case s @ Intersect(left, right, isAll) if s.childrenResolved && left.output.length == right.output.length && !s.resolved => val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(left :: right :: Nil) assert(newChildren.length == 2) - Intersect(newChildren.head, newChildren.last) + Intersect(newChildren.head, newChildren.last, isAll) case s: Union if s.childrenResolved && s.children.forall(_.output.length == s.children.head.output.length) && !s.resolved => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala index c9a3ee47a02be..cff4cee09427f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala @@ -309,7 +309,7 @@ object 
UnsupportedOperationChecker { case Except(left, right, _) if right.isStreaming => throwError("Except on a streaming DataFrame/Dataset on the right is not supported") - case Intersect(left, right) if left.isStreaming && right.isStreaming => + case Intersect(left, right, _) if left.isStreaming && right.isStreaming => throwError("Intersect between two streaming DataFrames/Datasets is not supported") case GroupingSets(_, _, child, _) if child.isStreaming => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 193f6591c9a8b..105623c767d66 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -136,6 +136,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) OptimizeSubqueries) :: Batch("Replace Operators", fixedPoint, RewriteExcepAll, + RewriteIntersectAll, ReplaceIntersectWithSemiJoin, ReplaceExceptWithFilter, ReplaceExceptWithAntiJoin, @@ -1402,7 +1403,7 @@ object ReplaceDeduplicateWithAggregate extends Rule[LogicalPlan] { */ object ReplaceIntersectWithSemiJoin extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case Intersect(left, right) => + case Intersect(left, right, false) => assert(left.output.size == right.output.size) val joinCond = left.output.zip(right.output).map { case (l, r) => EqualNullSafe(l, r) } Distinct(Join(left, right, LeftSemi, joinCond.reduceLeftOption(And))) @@ -1488,6 +1489,84 @@ object RewriteExcepAll extends Rule[LogicalPlan] { } } +/** + * Replaces logical [[Intersect]] operator using a combination of Union, Aggregate + * and Generate operator. + * + * Input Query : + * {{{ + * SELECT c1 FROM ut1 INTERSECT ALL SELECT c1 FROM ut2 + * }}} + * + * Rewritten Query: + * {{{ + * SELECT c1 + * FROM ( + * SELECT replicate_row(min_count, c1) + * FROM ( + * SELECT c1, If (vcol1_cnt > vcol2_cnt, vcol2_cnt, vcol1_cnt) AS min_count + * FROM ( + * SELECT c1, count(vcol1) as vcol1_cnt, count(vcol2) as vcol2_cnt + * FROM ( + * SELECT true as vcol1, null as , c1 FROM ut1 + * UNION ALL + * SELECT null as vcol1, true as vcol2, c1 FROM ut2 + * ) AS union_all + * GROUP BY c1 + * HAVING vcol1_cnt >= 1 AND vcol2_cnt >= 1 + * ) + * ) + * ) + * }}} + */ +object RewriteIntersectAll extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case Intersect(left, right, true) => + assert(left.output.size == right.output.size) + + val trueVcol1 = Alias(Literal(true), "vcol1")() + val nullVcol1 = Alias(Literal(null, BooleanType), "vcol1")() + + val trueVcol2 = Alias(Literal(true), "vcol2")() + val nullVcol2 = Alias(Literal(null, BooleanType), "vcol2")() + + // Add a projection on the top of left and right plans to project out + // the additional virtual columns. + val leftPlanWithAddedVirtualCols = Project(Seq(trueVcol1, nullVcol2) ++ left.output, left) + val rightPlanWithAddedVirtualCols = Project(Seq(nullVcol1, trueVcol2) ++ right.output, right) + + val unionPlan = Union(leftPlanWithAddedVirtualCols, rightPlanWithAddedVirtualCols) + + // Expressions to compute count and minimum of both the counts. 
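+ // count(vcol1) only counts rows that came from the left child, because vcol1 is null for
+ // right-side rows and count() ignores nulls; count(vcol2) is the symmetric count for the
+ // right child. The If expression below keeps the smaller of the two counts, which is how
+ // many times each distinct row must appear in the INTERSECT ALL output.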
+ val vCol1AggrExpr = + Alias(AggregateExpression(Count(unionPlan.output(0)), Complete, false), "vcol1_count")() + val vCol2AggrExpr = + Alias(AggregateExpression(Count(unionPlan.output(1)), Complete, false), "vcol2_count")() + val ifExpression = Alias(If( + GreaterThan(vCol1AggrExpr.toAttribute, vCol2AggrExpr.toAttribute), + vCol2AggrExpr.toAttribute, + vCol1AggrExpr.toAttribute + ), "min_count")() + + val aggregatePlan = Aggregate(left.output, + Seq(vCol1AggrExpr, vCol2AggrExpr) ++ left.output, unionPlan) + val filterPlan = Filter(And(GreaterThanOrEqual(vCol1AggrExpr.toAttribute, Literal(1L)), + GreaterThanOrEqual(vCol2AggrExpr.toAttribute, Literal(1L))), aggregatePlan) + val projectMinPlan = Project(left.output ++ Seq(ifExpression), filterPlan) + + // Apply the replicator to replicate rows based on min_count + val genRowPlan = Generate( + ReplicateRows(Seq(ifExpression.toAttribute) ++ left.output), + unrequiredChildIndex = Nil, + outer = false, + qualifier = None, + left.output, + projectMinPlan + ) + Project(left.output, genRowPlan) + } +} + /** * Removes literals from group expressions in [[Aggregate]], as they have no effect to the result * but only makes the grouping key bigger. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 8b3c0686181fd..8a8db6df37094 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -533,7 +533,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging case SqlBaseParser.UNION => Distinct(Union(left, right)) case SqlBaseParser.INTERSECT if all => - throw new ParseException("INTERSECT ALL is not supported.", ctx) + Intersect(left, right, isAll = true) case SqlBaseParser.INTERSECT => Intersect(left, right) case SqlBaseParser.EXCEPT if all => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 498a13a62bd22..13b51304d7f89 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -164,7 +164,12 @@ object SetOperation { def unapply(p: SetOperation): Option[(LogicalPlan, LogicalPlan)] = Some((p.left, p.right)) } -case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) { +case class Intersect( + left: LogicalPlan, + right: LogicalPlan, + isAll: Boolean = false) extends SetOperation(left, right) { + + override def nodeName: String = getClass.getSimpleName + ( if ( isAll ) "All" else "" ) override def output: Seq[Attribute] = left.output.zip(right.output).map { case (leftAttr, rightAttr) => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala index f002aa3aacaba..cb744be400603 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala @@ -20,10 +20,11 @@ package org.apache.spark.sql.catalyst.optimizer import 
org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.expressions.{Alias, GreaterThan, Literal, ReplicateRows} +import org.apache.spark.sql.catalyst.expressions.{Alias, And, GreaterThan, GreaterThanOrEqual, If, Literal, ReplicateRows} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.types.BooleanType class SetOperationSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { @@ -166,4 +167,33 @@ class SetOperationSuite extends PlanTest { )) comparePlans(expectedPlan, rewrittenPlan) } + + test("INTERSECT ALL rewrite") { + val input = Intersect(testRelation, testRelation2, isAll = true) + val rewrittenPlan = RewriteIntersectAll(input) + val leftRelation = testRelation + .select(Literal(true).as("vcol1"), Literal(null, BooleanType).as("vcol2"), 'a, 'b, 'c) + val rightRelation = testRelation2 + .select(Literal(null, BooleanType).as("vcol1"), Literal(true).as("vcol2"), 'd, 'e, 'f) + val planFragment = leftRelation.union(rightRelation) + .groupBy('a, 'b, 'c)(count('vcol1).as("vcol1_count"), + count('vcol2).as("vcol2_count"), 'a, 'b, 'c) + .where(And(GreaterThanOrEqual('vcol1_count, Literal(1L)), + GreaterThanOrEqual('vcol2_count, Literal(1L)))) + .select('a, 'b, 'c, + If(GreaterThan('vcol1_count, 'vcol2_count), 'vcol2_count, 'vcol1_count).as("min_count")) + .analyze + val multiplerAttr = planFragment.output.last + val output = planFragment.output.dropRight(1) + val expectedPlan = Project(output, + Generate( + ReplicateRows(Seq(multiplerAttr) ++ output), + Nil, + false, + None, + output, + planFragment + )) + comparePlans(expectedPlan, rewrittenPlan) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index 629e3c4f3fcfb..9be0ec5af78ef 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -70,7 +70,6 @@ class PlanParserSuite extends AnalysisTest { intercept("select * from a minus all select * from b", "MINUS ALL is not supported.") assertEqual("select * from a minus distinct select * from b", a.except(b)) assertEqual("select * from a intersect select * from b", a.intersect(b)) - intercept("select * from a intersect all select * from b", "INTERSECT ALL is not supported.") assertEqual("select * from a intersect distinct select * from b", a.intersect(b)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index e6a3b0adcdaa6..d36c8d13acca9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -1934,6 +1934,23 @@ class Dataset[T] private[sql]( Intersect(planWithBarrier, other.planWithBarrier) } + /** + * Returns a new Dataset containing rows only in both this Dataset and another Dataset while + * preserving the duplicates. + * This is equivalent to `INTERSECT ALL` in SQL. + * + * @note Equality checking is performed directly on the encoded representation of the data + * and thus is not affected by a custom `equals` function defined on `T`. 
Also as standard + * in SQL, this function resolves columns by position (not by name). + * + * @group typedrel + * @since 2.4.0 + */ + def intersectAll(other: Dataset[T]): Dataset[T] = withSetOperator { + Intersect(logicalPlan, other.logicalPlan, isAll = true) + } + + /** * Returns a new Dataset containing rows in this Dataset but not in another Dataset. * This is equivalent to `EXCEPT DISTINCT` in SQL. @@ -1961,7 +1978,7 @@ class Dataset[T] private[sql]( * @since 2.4.0 */ def exceptAll(other: Dataset[T]): Dataset[T] = withSetOperator { - Except(planWithBarrier, other.planWithBarrier, isAll = true) + Except(logicalPlan, other.logicalPlan, isAll = true) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 3f5fd3dbb9e2f..75eff8a88312b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -529,9 +529,13 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case logical.Distinct(child) => throw new IllegalStateException( "logical distinct operator should have been replaced by aggregate in the optimizer") - case logical.Intersect(left, right) => + case logical.Intersect(left, right, false) => throw new IllegalStateException( - "logical intersect operator should have been replaced by semi-join in the optimizer") + "logical intersect operator should have been replaced by semi-join in the optimizer") + case logical.Intersect(left, right, true) => + throw new IllegalStateException( + "logical intersect operator should have been replaced by union, aggregate" + + "and generate operators in the optimizer") case logical.Except(left, right, false) => throw new IllegalStateException( "logical except operator should have been replaced by anti-join in the optimizer") diff --git a/sql/core/src/test/resources/sql-tests/inputs/intersect-all.sql b/sql/core/src/test/resources/sql-tests/inputs/intersect-all.sql new file mode 100644 index 0000000000000..ff4395c3e7447 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/intersect-all.sql @@ -0,0 +1,123 @@ +CREATE TEMPORARY VIEW tab1 AS SELECT * FROM VALUES + (1, 2), + (1, 2), + (1, 3), + (1, 3), + (2, 3), + (null, null), + (null, null) + AS tab1(k, v); +CREATE TEMPORARY VIEW tab2 AS SELECT * FROM VALUES + (1, 2), + (1, 2), + (2, 3), + (3, 4), + (null, null), + (null, null) + AS tab2(k, v); + +-- Basic INTERSECT ALL +SELECT * FROM tab1 +INTERSECT ALL +SELECT * FROM tab2; + +-- INTERSECT ALL same table in both branches +SELECT * FROM tab1 +INTERSECT ALL +SELECT * FROM tab1 WHERE k = 1; + +-- Empty left relation +SELECT * FROM tab1 WHERE k > 2 +INTERSECT ALL +SELECT * FROM tab2; + +-- Empty right relation +SELECT * FROM tab1 +INTERSECT ALL +SELECT * FROM tab2 WHERE k > 3; + +-- Type Coerced INTERSECT ALL +SELECT * FROM tab1 +INTERSECT ALL +SELECT CAST(1 AS BIGINT), CAST(2 AS BIGINT); + +-- Error as types of two side are not compatible +SELECT * FROM tab1 +INTERSECT ALL +SELECT array(1), 2; + +-- Mismatch on number of columns across both branches +SELECT k FROM tab1 +INTERSECT ALL +SELECT k, v FROM tab2; + +-- Basic +SELECT * FROM tab2 +INTERSECT ALL +SELECT * FROM tab1 +INTERSECT ALL +SELECT * FROM tab2; + +-- Chain of different `set operations +-- We need to parenthesize the following two queries to enforce +-- certain order of evaluation of operators. 
After fix to +-- SPARK-24966 this can be removed. +SELECT * FROM tab1 +EXCEPT +SELECT * FROM tab2 +UNION ALL +( +SELECT * FROM tab1 +INTERSECT ALL +SELECT * FROM tab2 +); + +-- Chain of different `set operations +SELECT * FROM tab1 +EXCEPT +SELECT * FROM tab2 +EXCEPT +( +SELECT * FROM tab1 +INTERSECT ALL +SELECT * FROM tab2 +); + +-- Join under intersect all +SELECT * +FROM (SELECT tab1.k, + tab2.v + FROM tab1 + JOIN tab2 + ON tab1.k = tab2.k) +INTERSECT ALL +SELECT * +FROM (SELECT tab1.k, + tab2.v + FROM tab1 + JOIN tab2 + ON tab1.k = tab2.k); + +-- Join under intersect all (2) +SELECT * +FROM (SELECT tab1.k, + tab2.v + FROM tab1 + JOIN tab2 + ON tab1.k = tab2.k) +INTERSECT ALL +SELECT * +FROM (SELECT tab2.v AS k, + tab1.k AS v + FROM tab1 + JOIN tab2 + ON tab1.k = tab2.k); + +-- Group by under intersect all +SELECT v FROM tab1 GROUP BY v +INTERSECT ALL +SELECT k FROM tab2 GROUP BY k; + +-- Clean-up +DROP VIEW IF EXISTS tab1; +DROP VIEW IF EXISTS tab2; diff --git a/sql/core/src/test/resources/sql-tests/results/intersect-all.sql.out b/sql/core/src/test/resources/sql-tests/results/intersect-all.sql.out new file mode 100644 index 0000000000000..792791bc51628 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/intersect-all.sql.out @@ -0,0 +1,241 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 17 + + +-- !query 0 +CREATE TEMPORARY VIEW tab1 AS SELECT * FROM VALUES + (1, 2), + (1, 2), + (1, 3), + (1, 3), + (2, 3), + (null, null), + (null, null) + AS tab1(k, v) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +CREATE TEMPORARY VIEW tab2 AS SELECT * FROM VALUES + (1, 2), + (1, 2), + (2, 3), + (3, 4), + (null, null), + (null, null) + AS tab2(k, v) +-- !query 1 schema +struct<> +-- !query 1 output + + + +-- !query 2 +SELECT * FROM tab1 +INTERSECT ALL +SELECT * FROM tab2 +-- !query 2 schema +struct +-- !query 2 output +1 2 +1 2 +2 3 +NULL NULL +NULL NULL + + +-- !query 3 +SELECT * FROM tab1 +INTERSECT ALL +SELECT * FROM tab1 WHERE k = 1 +-- !query 3 schema +struct +-- !query 3 output +1 2 +1 2 +1 3 +1 3 + + +-- !query 4 +SELECT * FROM tab1 WHERE k > 2 +INTERSECT ALL +SELECT * FROM tab2 +-- !query 4 schema +struct +-- !query 4 output + + + +-- !query 5 +SELECT * FROM tab1 +INTERSECT ALL +SELECT * FROM tab2 WHERE k > 3 +-- !query 5 schema +struct +-- !query 5 output + + + +-- !query 6 +SELECT * FROM tab1 +INTERSECT ALL +SELECT CAST(1 AS BIGINT), CAST(2 AS BIGINT) +-- !query 6 schema +struct +-- !query 6 output +1 2 + + +-- !query 7 +SELECT * FROM tab1 +INTERSECT ALL +SELECT array(1), 2 +-- !query 7 schema +struct<> +-- !query 7 output +org.apache.spark.sql.AnalysisException +IntersectAll can only be performed on tables with the compatible column types. 
array <> int at the first column of the second table; + + +-- !query 8 +SELECT k FROM tab1 +INTERSECT ALL +SELECT k, v FROM tab2 +-- !query 8 schema +struct<> +-- !query 8 output +org.apache.spark.sql.AnalysisException +IntersectAll can only be performed on tables with the same number of columns, but the first table has 1 columns and the second table has 2 columns; + + +-- !query 9 +SELECT * FROM tab2 +INTERSECT ALL +SELECT * FROM tab1 +INTERSECT ALL +SELECT * FROM tab2 +-- !query 9 schema +struct +-- !query 9 output +1 2 +1 2 +2 3 +NULL NULL +NULL NULL + + +-- !query 10 +SELECT * FROM tab1 +EXCEPT +SELECT * FROM tab2 +UNION ALL +( +SELECT * FROM tab1 +INTERSECT ALL +SELECT * FROM tab2 +) +-- !query 10 schema +struct +-- !query 10 output +1 2 +1 2 +1 3 +2 3 +NULL NULL +NULL NULL + + +-- !query 11 +SELECT * FROM tab1 +EXCEPT +SELECT * FROM tab2 +EXCEPT +( +SELECT * FROM tab1 +INTERSECT ALL +SELECT * FROM tab2 +) +-- !query 11 schema +struct +-- !query 11 output +1 3 + + +-- !query 12 +SELECT * +FROM (SELECT tab1.k, + tab2.v + FROM tab1 + JOIN tab2 + ON tab1.k = tab2.k) +INTERSECT ALL +SELECT * +FROM (SELECT tab1.k, + tab2.v + FROM tab1 + JOIN tab2 + ON tab1.k = tab2.k) +-- !query 12 schema +struct +-- !query 12 output +1 2 +1 2 +1 2 +1 2 +1 2 +1 2 +1 2 +1 2 +2 3 + + +-- !query 13 +SELECT * +FROM (SELECT tab1.k, + tab2.v + FROM tab1 + JOIN tab2 + ON tab1.k = tab2.k) +INTERSECT ALL +SELECT * +FROM (SELECT tab2.v AS k, + tab1.k AS v + FROM tab1 + JOIN tab2 + ON tab1.k = tab2.k) +-- !query 13 schema +struct +-- !query 13 output + + + +-- !query 14 +SELECT v FROM tab1 GROUP BY v +INTERSECT ALL +SELECT k FROM tab2 GROUP BY k +-- !query 14 schema +struct +-- !query 14 output +2 +3 +NULL + + +-- !query 15 +DROP VIEW IF EXISTS tab1 +-- !query 15 schema +struct<> +-- !query 15 output + + + +-- !query 16 +DROP VIEW IF EXISTS tab2 +-- !query 16 schema +struct<> +-- !query 16 output + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index af0735920cc29..b0e22a51e7611 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -749,6 +749,60 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(df4.schema.forall(!_.nullable)) } + test("intersectAll") { + checkAnswer( + lowerCaseDataWithDuplicates.intersectAll(lowerCaseDataWithDuplicates), + Row(1, "a") :: + Row(2, "b") :: + Row(2, "b") :: + Row(3, "c") :: + Row(3, "c") :: + Row(3, "c") :: + Row(4, "d") :: Nil) + checkAnswer(lowerCaseData.intersectAll(upperCaseData), Nil) + + // check null equality + checkAnswer( + nullInts.intersectAll(nullInts), + Row(1) :: + Row(2) :: + Row(3) :: + Row(null) :: Nil) + + // Duplicate nulls are preserved. 
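+ // allNulls has four NULL-only rows on each side; the rewrite groups NULLs together
+ // (null-safe grouping), so min(4, 4) = 4 NULL rows are expected back.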
+ checkAnswer( + allNulls.intersectAll(allNulls), + Row(null) :: Row(null) :: Row(null) :: Row(null) :: Nil) + + val df_left = Seq(1, 2, 2, 3, 3, 4).toDF("id") + val df_right = Seq(1, 2, 2, 3).toDF("id") + + checkAnswer( + df_left.intersectAll(df_right), + Row(1) :: Row(2) :: Row(2) :: Row(3) :: Nil) + } + + test("intersectAll - nullability") { + val nonNullableInts = Seq(Tuple1(1), Tuple1(3)).toDF() + assert(nonNullableInts.schema.forall(!_.nullable)) + + val df1 = nonNullableInts.intersectAll(nullInts) + checkAnswer(df1, Row(1) :: Row(3) :: Nil) + assert(df1.schema.forall(!_.nullable)) + + val df2 = nullInts.intersectAll(nonNullableInts) + checkAnswer(df2, Row(1) :: Row(3) :: Nil) + assert(df2.schema.forall(!_.nullable)) + + val df3 = nullInts.intersectAll(nullInts) + checkAnswer(df3, Row(1) :: Row(2) :: Row(3) :: Row(null) :: Nil) + assert(df3.schema.forall(_.nullable)) + + val df4 = nonNullableInts.intersectAll(nonNullableInts) + checkAnswer(df4, Row(1) :: Row(3) :: Nil) + assert(df4.schema.forall(!_.nullable)) + } + test("udf") { val foo = udf((a: Int, b: String) => a.toString + b) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala index 0cfe260e52152..deea9dbb30aae 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala @@ -136,6 +136,19 @@ private[sql] trait SQLTestData { self => df } + protected lazy val lowerCaseDataWithDuplicates: DataFrame = { + val df = spark.sparkContext.parallelize( + LowerCaseData(1, "a") :: + LowerCaseData(2, "b") :: + LowerCaseData(2, "b") :: + LowerCaseData(3, "c") :: + LowerCaseData(3, "c") :: + LowerCaseData(3, "c") :: + LowerCaseData(4, "d") :: Nil).toDF() + df.createOrReplaceTempView("lowerCaseData") + df + } + protected lazy val arrayData: RDD[ArrayData] = { val rdd = spark.sparkContext.parallelize( ArrayData(Seq(1, 2, 3), Seq(Seq(1, 2, 3))) ::
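As a quick end-to-end illustration of the behavior the tests above assert, here is a minimal standalone Scala sketch, assuming a Spark build that includes this patch (2.4.0+). The object name, temp view names, and sample data are invented for the example and simply mirror the Python doctest earlier in the diff; they are not part of the patch itself.

import org.apache.spark.sql.SparkSession

object IntersectAllExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("intersectAll-example")
      .master("local[*]")
      .getOrCreate()
    import spark.implicits._

    // Same rows as the Python doctest added in this patch.
    val left = Seq(("a", 1), ("a", 1), ("b", 3), ("c", 4)).toDF("C1", "C2")
    val right = Seq(("a", 1), ("a", 1), ("b", 3)).toDF("C1", "C2")

    // Dataset API: duplicates are preserved; each row is kept
    // min(count in left, count in right) times.
    left.intersectAll(right).orderBy("C1", "C2").show()

    // SQL form that the parser accepts after this change instead of
    // throwing "INTERSECT ALL is not supported."
    left.createOrReplaceTempView("t1")
    right.createOrReplaceTempView("t2")
    spark.sql("SELECT * FROM t1 INTERSECT ALL SELECT * FROM t2").show()

    spark.stop()
  }
}

Both the Dataset call and the SQL query should return (a, 1) twice and (b, 3) once, matching the doctest output shown in the dataframe.py hunk.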