diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
index cc1a5e835d9cd..0d64493170c25 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -92,8 +92,13 @@ case class ClusteredDistribution(
 }
 
 /**
- * Represents data where tuples have been clustered according to the hash of the given
- * `expressions`. The hash function is defined as `HashPartitioning.partitionIdExpression`, so only
+ * If `exceptNull` is false: represents data where tuples have been clustered according to the
+ * hash of the given `expressions`.
+ * If `exceptNull` is true: represents data where tuples have been clustered according to the
+ * hash of the given `expressions`, except for NULL keys, which may be placed in any partition.
+ * This is typically used for join keys, where the placement of NULL keys does not matter
+ * because NULL is never considered equal to any value.
+ * The hash function is defined as `HashPartitioning.partitionIdExpression`, so only
  * [[HashPartitioning]] can satisfy this distribution.
  *
  * This is a strictly stronger guarantee than [[ClusteredDistribution]]. Given a tuple and the
@@ -101,7 +106,8 @@ case class ClusteredDistribution(
  */
 case class HashClusteredDistribution(
     expressions: Seq[Expression],
-    requiredNumPartitions: Option[Int] = None) extends Distribution {
+    requiredNumPartitions: Option[Int] = None,
+    exceptNull: Boolean = false) extends Distribution {
   require(
     expressions != Nil,
     "The expressions for hash of a HashClusteredDistribution should not be Nil. " +
@@ -112,7 +118,7 @@ case class HashClusteredDistribution(
     assert(requiredNumPartitions.isEmpty || requiredNumPartitions.get == numPartitions,
       s"This HashClusteredDistribution requires ${requiredNumPartitions.get} partitions, but " +
         s"the actual number of partitions is $numPartitions.")
-    HashPartitioning(expressions, numPartitions)
+    HashPartitioning(expressions, numPartitions, exceptNull)
   }
 }
 
@@ -207,12 +213,16 @@ case object SinglePartition extends Partitioning {
 }
 
 /**
- * Represents a partitioning where rows are split up across partitions based on the hash
- * of `expressions`. All rows where `expressions` evaluate to the same values are guaranteed to be
+ * If `exceptNull` is false: represents a partitioning where rows are split up across partitions
+ * based on the hash of `expressions`.
+ * If `exceptNull` is true: represents a partitioning where rows are split up across partitions
+ * based on the hash of `expressions`, except for null keys, which are the only keys that are not
+ * co-partitioned. All rows where `expressions` evaluate to the same values are guaranteed to be
  * in the same partition.
  */
-case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
-  extends Expression with Partitioning with Unevaluable {
+case class HashPartitioning(
+    expressions: Seq[Expression], numPartitions: Int, exceptNull: Boolean = false)
+  extends Expression with Partitioning with Unevaluable {
 
   override def children: Seq[Expression] = expressions
   override def nullable: Boolean = false
@@ -222,8 +232,9 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
     super.satisfies0(required) || {
       required match {
         case h: HashClusteredDistribution =>
-          expressions.length == h.expressions.length && expressions.zip(h.expressions).forall {
-            case (l, r) => l.semanticEquals(r)
+          expressions.length == h.expressions.length && (h.exceptNull || !exceptNull) &&
+            expressions.zip(h.expressions).forall {
+              case (l, r) => l.semanticEquals(r)
           }
         case ClusteredDistribution(requiredClustering, _) =>
           expressions.forall(x => requiredClustering.exists(_.semanticEquals(x)))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
index ffb2abf7d9d0f..d179bb19b809f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
@@ -24,8 +24,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution._
-import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, ShuffledHashJoinExec,
-  SortMergeJoinExec}
+import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, ShuffledHashJoinExec, SortMergeJoinExec}
 import org.apache.spark.sql.internal.SQLConf
 
 /**
@@ -163,13 +162,13 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
       rightPartitioning: Partitioning): (Seq[Expression], Seq[Expression]) = {
     if (leftKeys.forall(_.deterministic) && rightKeys.forall(_.deterministic)) {
       leftPartitioning match {
-        case HashPartitioning(leftExpressions, _)
+        case HashPartitioning(leftExpressions, _, _)
           if leftExpressions.length == leftKeys.length &&
             leftKeys.forall(x => leftExpressions.exists(_.semanticEquals(x))) =>
           reorder(leftKeys, rightKeys, leftExpressions, leftKeys)
         case _ => rightPartitioning match {
-          case HashPartitioning(rightExpressions, _)
+          case HashPartitioning(rightExpressions, _, _)
             if rightExpressions.length == rightKeys.length &&
               rightKeys.forall(x => rightExpressions.exists(_.semanticEquals(x))) =>
             reorder(leftKeys, rightKeys, rightExpressions, rightKeys)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
index bbd1a3f005d74..1c254887dfc39 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
@@ -190,7 +190,7 @@ object ShuffleExchangeExec {
       serializer: Serializer): ShuffleDependency[Int, InternalRow, InternalRow] = {
     val part: Partitioner = newPartitioning match {
       case RoundRobinPartitioning(numPartitions) => new HashPartitioner(numPartitions)
-      case HashPartitioning(_, n) =>
+      case HashPartitioning(_, n, _) =>
         new Partitioner {
           override def numPartitions: Int = n
          // For HashPartitioning, the partitioning key is already a valid partition ID, as we use
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
index 2de2f30eb05d3..992d15075f965 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
@@ -70,15 +70,27 @@ case class SortMergeJoinExec(
     // For left and right outer joins, the output is partitioned by the streamed input's join keys.
     case LeftOuter => left.outputPartitioning
     case RightOuter => right.outputPartitioning
-    case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions)
+    case FullOuter =>
+      // The output of a full outer join is still hash partitioned on the join keys, except for
+      // NULL keys, which are the only keys that are not co-partitioned.
+      (left.outputPartitioning, right.outputPartitioning) match {
+        case (l: HashPartitioning, r: HashPartitioning) =>
+          PartitioningCollection(Seq(l.copy(exceptNull = true), r.copy(exceptNull = true)))
+        case _ => UnknownPartitioning(left.outputPartitioning.numPartitions)
+      }
     case LeftExistence(_) => left.outputPartitioning
     case x =>
       throw new IllegalArgumentException(
         s"${getClass.getSimpleName} should not take $x as the JoinType")
   }
 
-  override def requiredChildDistribution: Seq[Distribution] =
-    HashClusteredDistribution(leftKeys) :: HashClusteredDistribution(rightKeys) :: Nil
+  override def requiredChildDistribution: Seq[Distribution] = joinType match {
+    case Inner | LeftOuter | RightOuter | FullOuter =>
+      HashClusteredDistribution(leftKeys, exceptNull = true) ::
+        HashClusteredDistribution(rightKeys, exceptNull = true) :: Nil
+    case _ => HashClusteredDistribution(leftKeys) :: HashClusteredDistribution(rightKeys) :: Nil
+  }
+
 
   override def outputOrdering: Seq[SortOrder] = joinType match {
     // For inner join, orders of both sides keys should be kept.
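To make the intent of the new `exceptNull` flag concrete, here is a small standalone sketch (not part of the patch; the object name is illustrative) that exercises the `satisfies0` rule added above. A NULL-relaxed `HashPartitioning`, such as the output a full outer sort-merge join now reports, satisfies only a NULL-relaxed `HashClusteredDistribution`, while a strict `HashPartitioning` satisfies both.

```scala
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.physical.{HashClusteredDistribution, HashPartitioning}
import org.apache.spark.sql.types.LongType

// Illustration only; not part of the patch.
object ExceptNullSatisfiesSketch extends App {
  val key = AttributeReference("a", LongType)()

  // Strict hash partitioning: NULL keys are hashed like any other value.
  val strict = HashPartitioning(Seq(key), numPartitions = 10)
  // NULL-relaxed partitioning, e.g. what a full outer SortMergeJoinExec now reports.
  val relaxed = HashPartitioning(Seq(key), numPartitions = 10, exceptNull = true)

  // A strict partitioning satisfies both forms of the distribution.
  assert(strict.satisfies(HashClusteredDistribution(Seq(key))))
  assert(strict.satisfies(HashClusteredDistribution(Seq(key), exceptNull = true)))

  // A relaxed partitioning only satisfies the relaxed distribution (the `h.exceptNull || !exceptNull`
  // check), so an operator that needs NULL keys co-partitioned still gets a shuffle.
  assert(relaxed.satisfies(HashClusteredDistribution(Seq(key), exceptNull = true)))
  assert(!relaxed.satisfies(HashClusteredDistribution(Seq(key))))
}
```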
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index 7e96a6c541760..50cbf2370d924 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -18,21 +18,21 @@
 package org.apache.spark.sql.execution
 
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{execution, Row}
+import org.apache.spark.sql.{QueryTest, Row, execution}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, LeftOuter, RightOuter}
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Repartition}
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution.columnar.InMemoryRelation
-import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReusedExchangeExec, ReuseExchange, ShuffleExchangeExec}
+import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReuseExchange, ReusedExchangeExec, ShuffleExchangeExec}
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types._
 
-class PlannerSuite extends SharedSQLContext {
+class PlannerSuite extends QueryTest with SharedSQLContext {
   import testImplicits._
 
   setupTestData()
@@ -683,6 +683,41 @@ class PlannerSuite extends SharedSQLContext {
       case _ => fail()
     }
   }
+
+  test("EnsureRequirements doesn't add shuffle between 2 successive full outer joins on the same " +
+    "key") {
+    val df1 = spark.range(1, 100, 1, 2).filter(_ % 2 == 0).selectExpr("id as a1")
+    val df2 = spark.range(1, 100, 1, 2).selectExpr("id as b2")
+    val df3 = spark.range(1, 100, 1, 2).selectExpr("id as a3")
+    val fullOuterJoins = df1
+      .join(df2, col("a1") === col("b2"), "full_outer")
+      .join(df3, col("a1") === col("a3"), "full_outer")
+    assert(
+      fullOuterJoins.queryExecution.executedPlan.collect { case e: ShuffleExchangeExec => e }
+        .length === 3)
+    val expected = (1 until 100).filter(_ % 2 == 0).map(i => Row(i, i, i)) ++
+      (1 until 100).filterNot(_ % 2 == 0).map(Row(null, _, null)) ++
+      (1 until 100).filterNot(_ % 2 == 0).map(Row(null, null, _))
+    checkAnswer(fullOuterJoins, expected)
+  }
+
+  test("EnsureRequirements still adds shuffle for non-successive full outer joins on the same key")
+  {
+    withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0") {
+      val df1 = spark.range(1, 100).selectExpr("id as a1")
+      val df2 = spark.range(1, 100).selectExpr("id as b2")
+      val df3 = spark.range(1, 100).selectExpr("id as a3")
+      val df4 = spark.range(1, 100).selectExpr("id as a4")
+
+      val fullOuterJoins = df1
+        .join(df2, col("a1") === col("b2"), "full_outer")
+        .join(df3, col("a1") === col("a3"), "left_outer")
+        .join(df4, col("a3") === col("a4"), "full_outer")
+      fullOuterJoins.explain(true)
+      assert(
+        fullOuterJoins.queryExecution.executedPlan.collect { case e: ShuffleExchangeExec => e }
+          .length === 5)
+    }
+  }
 }
 
 // Used for unit-testing EnsureRequirements
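For reviewers who want to see the planner effect outside the test suite, the following self-contained snippet (illustrative only; the object name, master URL, and the simplified inputs are not part of the patch) mirrors the first new test above: with this change the second full outer join can reuse the co-partitioning of the first, so the plan should contain three exchanges rather than four.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col

// Illustration only; names and session setup are placeholders, not part of the patch.
object FullOuterCoPartitionExample extends App {
  val spark = SparkSession.builder()
    .master("local[2]")
    .appName("full-outer-co-partitioning")
    .getOrCreate()

  val df1 = spark.range(1, 100, 1, 2).selectExpr("id as a1")
  val df2 = spark.range(1, 100, 1, 2).selectExpr("id as b2")
  val df3 = spark.range(1, 100, 1, 2).selectExpr("id as a3")

  // Two successive full outer joins on the same key: NULL join keys never match, so their
  // placement does not matter, and the second join can consume the first join's partitioning.
  val joined = df1
    .join(df2, col("a1") === col("b2"), "full_outer")
    .join(df3, col("a1") === col("a3"), "full_outer")

  // With the patch applied, the plan should show three ShuffleExchangeExec nodes (one per input);
  // before the patch a fourth exchange re-partitions the first join's output.
  joined.explain()

  spark.stop()
}
```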