diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 9be0497e46603..c1aa3932c6f18 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -2665,6 +2665,17 @@ object SQLConf {
     .checkValue(_ > 0, "The difference must be positive.")
     .createWithDefault(4)
 
+  val BROADCAST_HASH_JOIN_OUTPUT_PARTITIONING_EXPAND_LIMIT =
+    buildConf("spark.sql.execution.broadcastHashJoin.outputPartitioningExpandLimit")
+      .internal()
+      .doc("The maximum number of partitionings that a HashPartitioning can be expanded to. " +
+        "This configuration is applicable only for BroadcastHashJoin inner joins and can be " +
+        "set to '0' to disable this feature.")
+      .version("3.1.0")
+      .intConf
+      .checkValue(_ >= 0, "The value must be non-negative.")
+      .createWithDefault(8)
+
   /**
    * Holds information about keys that have been deprecated.
    *
@@ -2975,6 +2986,9 @@ class SQLConf extends Serializable with Logging {
     LegacyBehaviorPolicy.withName(getConf(SQLConf.LEGACY_TIME_PARSER_POLICY))
   }
 
+  def broadcastHashJoinOutputPartitioningExpandLimit: Int =
+    getConf(BROADCAST_HASH_JOIN_OUTPUT_PARTITIONING_EXPAND_LIMIT)
+
   /**
    * Returns the [[Resolver]] for the current configuration, which can be used to determine if two
    * identifiers are equal.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
index 707ed1402d1ae..71faad9829a42 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.execution.joins
 
+import scala.collection.mutable
+
 import org.apache.spark.TaskContext
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.rdd.RDD
@@ -26,7 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide}
 import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, UnspecifiedDistribution}
+import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, HashPartitioning, Partitioning, PartitioningCollection, UnspecifiedDistribution}
 import org.apache.spark.sql.execution.{CodegenSupport, SparkPlan}
 import org.apache.spark.sql.execution.metric.SQLMetrics
 import org.apache.spark.sql.types.{BooleanType, LongType}
@@ -51,7 +53,7 @@ case class BroadcastHashJoinExec(
     "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
 
   override def requiredChildDistribution: Seq[Distribution] = {
-    val mode = HashedRelationBroadcastMode(buildKeys)
+    val mode = HashedRelationBroadcastMode(buildBoundKeys)
     buildSide match {
       case BuildLeft =>
         BroadcastDistribution(mode) :: UnspecifiedDistribution :: Nil
@@ -60,6 +62,73 @@ case class BroadcastHashJoinExec(
     }
   }
 
+  override lazy val outputPartitioning: Partitioning = {
+    joinType match {
+      case _: InnerLike if sqlContext.conf.broadcastHashJoinOutputPartitioningExpandLimit > 0 =>
+        streamedPlan.outputPartitioning match {
+          case h: HashPartitioning => expandOutputPartitioning(h)
+          case c: PartitioningCollection => expandOutputPartitioning(c)
+          case other => other
+        }
+      case _ => streamedPlan.outputPartitioning
+    }
+  }
+
+  // A one-to-many mapping from a streamed key to build keys.
+  private lazy val streamedKeyToBuildKeyMapping = {
+    val mapping = mutable.Map.empty[Expression, Seq[Expression]]
+    streamedKeys.zip(buildKeys).foreach {
+      case (streamedKey, buildKey) =>
+        val key = streamedKey.canonicalized
+        mapping.get(key) match {
+          case Some(v) => mapping.put(key, v :+ buildKey)
+          case None => mapping.put(key, Seq(buildKey))
+        }
+    }
+    mapping.toMap
+  }
+
+  // Expands the given partitioning collection recursively.
+  private def expandOutputPartitioning(
+      partitioning: PartitioningCollection): PartitioningCollection = {
+    PartitioningCollection(partitioning.partitionings.flatMap {
+      case h: HashPartitioning => expandOutputPartitioning(h).partitionings
+      case c: PartitioningCollection => Seq(expandOutputPartitioning(c))
+      case other => Seq(other)
+    })
+  }
+
+  // Expands the given hash partitioning by substituting streamed keys with build keys.
+  // For example, if the expressions for the given partitioning are Seq("a", "b", "c")
+  // where the streamed keys are Seq("b", "c") and the build keys are Seq("x", "y"),
+  // the expanded partitioning will have the following expressions:
+  // Seq("a", "b", "c"), Seq("a", "b", "y"), Seq("a", "x", "c"), Seq("a", "x", "y").
+  // The expanded expressions are returned as a PartitioningCollection.
+  private def expandOutputPartitioning(partitioning: HashPartitioning): PartitioningCollection = {
+    val maxNumCombinations = sqlContext.conf.broadcastHashJoinOutputPartitioningExpandLimit
+    var currentNumCombinations = 0
+
+    def generateExprCombinations(
+        current: Seq[Expression],
+        accumulated: Seq[Expression]): Seq[Seq[Expression]] = {
+      if (currentNumCombinations >= maxNumCombinations) {
+        Nil
+      } else if (current.isEmpty) {
+        currentNumCombinations += 1
+        Seq(accumulated)
+      } else {
+        val buildKeysOpt = streamedKeyToBuildKeyMapping.get(current.head.canonicalized)
+        generateExprCombinations(current.tail, accumulated :+ current.head) ++
+          buildKeysOpt.map(_.flatMap(b => generateExprCombinations(current.tail, accumulated :+ b)))
+            .getOrElse(Nil)
+      }
+    }
+
+    PartitioningCollection(
+      generateExprCombinations(partitioning.expressions, Nil)
+        .map(HashPartitioning(_, partitioning.numPartitions)))
+  }
+
   protected override def doExecute(): RDD[InternalRow] = {
     val numOutputRows = longMetric("numOutputRows")
 
@@ -135,13 +204,13 @@ case class BroadcastHashJoinExec(
       ctx: CodegenContext,
       input: Seq[ExprCode]): (ExprCode, String) = {
     ctx.currentVars = input
-    if (streamedKeys.length == 1 && streamedKeys.head.dataType == LongType) {
+    if (streamedBoundKeys.length == 1 && streamedBoundKeys.head.dataType == LongType) {
       // generate the join key as Long
-      val ev = streamedKeys.head.genCode(ctx)
+      val ev = streamedBoundKeys.head.genCode(ctx)
       (ev, ev.isNull)
     } else {
       // generate the join key as UnsafeRow
-      val ev = GenerateUnsafeProjection.createCode(ctx, streamedKeys)
+      val ev = GenerateUnsafeProjection.createCode(ctx, streamedBoundKeys)
       (ev, s"${ev.value}.anyNull()")
     }
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
index c7c3e1672f034..7c3c53b0fa54c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
@@ -62,21 +62,30 @@
 trait HashJoin extends BaseJoinExec {
 
   protected lazy val (buildKeys, streamedKeys) = {
     require(leftKeys.map(_.dataType) == rightKeys.map(_.dataType),
       "Join keys from two sides should have same types")
-    val lkeys = bindReferences(HashJoin.rewriteKeyExpr(leftKeys), left.output)
-    val rkeys = bindReferences(HashJoin.rewriteKeyExpr(rightKeys), right.output)
     buildSide match {
-      case BuildLeft => (lkeys, rkeys)
-      case BuildRight => (rkeys, lkeys)
+      case BuildLeft => (leftKeys, rightKeys)
+      case BuildRight => (rightKeys, leftKeys)
     }
   }
 
+  @transient private lazy val (buildOutput, streamedOutput) = {
+    buildSide match {
+      case BuildLeft => (left.output, right.output)
+      case BuildRight => (right.output, left.output)
+    }
+  }
+
+  @transient protected lazy val buildBoundKeys =
+    bindReferences(HashJoin.rewriteKeyExpr(buildKeys), buildOutput)
+  @transient protected lazy val streamedBoundKeys =
+    bindReferences(HashJoin.rewriteKeyExpr(streamedKeys), streamedOutput)
+
   protected def buildSideKeyGenerator(): Projection =
-    UnsafeProjection.create(buildKeys)
+    UnsafeProjection.create(buildBoundKeys)
 
   protected def streamSideKeyGenerator(): UnsafeProjection =
-    UnsafeProjection.create(streamedKeys)
+    UnsafeProjection.create(streamedBoundKeys)
 
   @transient private[this] lazy val boundCondition = if (condition.isDefined) {
     Predicate.create(condition.get, streamedPlan.output ++ buildPlan.output).eval _
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
index 2b7cd65e7d96f..1120850fdddaf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
@@ -55,7 +55,8 @@ case class ShuffledHashJoinExec(
     val buildTime = longMetric("buildTime")
     val start = System.nanoTime()
     val context = TaskContext.get()
-    val relation = HashedRelation(iter, buildKeys, taskMemoryManager = context.taskMemoryManager())
+    val relation = HashedRelation(
+      iter, buildBoundKeys, taskMemoryManager = context.taskMemoryManager())
     buildTime += NANOSECONDS.toMillis(System.nanoTime() - start)
     buildDataSize += relation.estimatedSize
     // This relation is usually used until the end of task.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
index c696d3f648ed1..511e0cf0b3817 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
@@ -557,7 +557,8 @@ class AdaptiveQueryExecSuite
     withSQLConf(
       SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
-      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") {
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80",
+      SQLConf.BROADCAST_HASH_JOIN_OUTPUT_PARTITIONING_EXPAND_LIMIT.key -> "0") {
       val (plan, adaptivePlan) = runAdaptiveAndVerifyResult(
         "SELECT * FROM testData " +
           "join testData2 t2 ON key = t2.a " +
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
index f7d5a899df1c9..7ff945f5cbfb4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
@@ -21,13 +21,15 @@ import scala.reflect.ClassTag
 
 import org.apache.spark.AccumulatorSuite
 import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession}
-import org.apache.spark.sql.catalyst.expressions.{BitwiseAnd, BitwiseOr, Cast, Literal, ShiftLeft}
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, BitwiseAnd, BitwiseOr, Cast, Expression, Literal, ShiftLeft}
 import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide}
+import org.apache.spark.sql.catalyst.plans.Inner
 import org.apache.spark.sql.catalyst.plans.logical.BROADCAST
-import org.apache.spark.sql.execution.{SparkPlan, WholeStageCodegenExec}
+import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, PartitioningCollection}
+import org.apache.spark.sql.execution.{DummySparkPlan, SparkPlan, WholeStageCodegenExec}
 import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, DisableAdaptiveExecutionSuite, EnableAdaptiveExecutionSuite}
 import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
-import org.apache.spark.sql.execution.exchange.EnsureRequirements
+import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ShuffleExchangeExec}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SQLTestUtils
@@ -415,6 +417,216 @@ abstract class BroadcastJoinSuiteBase extends QueryTest with SQLTestUtils
       assert(e.getMessage.contains(s"Could not execute broadcast in $timeout secs."))
     }
   }
+
+  test("broadcast join where streamed side's output partitioning is HashPartitioning") {
+    withTable("t1", "t3") {
+      withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "500") {
+        val df1 = (0 until 100).map(i => (i % 5, i % 13)).toDF("i1", "j1")
+        val df2 = (0 until 20).map(i => (i % 7, i % 11)).toDF("i2", "j2")
+        val df3 = (0 until 100).map(i => (i % 5, i % 13)).toDF("i3", "j3")
+        df1.write.format("parquet").bucketBy(8, "i1", "j1").saveAsTable("t1")
+        df3.write.format("parquet").bucketBy(8, "i3", "j3").saveAsTable("t3")
+        val t1 = spark.table("t1")
+        val t3 = spark.table("t3")
+
+        // join1 is a broadcast join where df2 is broadcasted. Note that output partitioning
+        // on the streamed side (t1) is HashPartitioning (bucketed files).
+        val join1 = t1.join(df2, t1("i1") === df2("i2") && t1("j1") === df2("j2"))
+        val plan1 = join1.queryExecution.executedPlan
+        assert(collect(plan1) { case e: ShuffleExchangeExec => e }.isEmpty)
+        val broadcastJoins = collect(plan1) { case b: BroadcastHashJoinExec => b }
+        assert(broadcastJoins.size == 1)
+        assert(broadcastJoins(0).outputPartitioning.isInstanceOf[PartitioningCollection])
+        val p = broadcastJoins(0).outputPartitioning.asInstanceOf[PartitioningCollection]
+        assert(p.partitionings.size == 4)
+        // Verify all the combinations of output partitioning.
+        Seq(Seq(t1("i1"), t1("j1")),
+          Seq(t1("i1"), df2("j2")),
+          Seq(df2("i2"), t1("j1")),
+          Seq(df2("i2"), df2("j2"))).foreach { expected =>
+          val expectedExpressions = expected.map(_.expr)
+          assert(p.partitionings.exists {
+            case h: HashPartitioning => expressionsEqual(h.expressions, expectedExpressions)
+          })
+        }
+
+        // Join on the columns from the broadcasted side (i2, j2) and make sure output
+        // partitioning is maintained by checking no shuffle exchange is introduced.
+        val join2 = join1.join(t3, join1("i2") === t3("i3") && join1("j2") === t3("j3"))
+        val plan2 = join2.queryExecution.executedPlan
+        assert(collect(plan2) { case s: SortMergeJoinExec => s }.size == 1)
+        assert(collect(plan2) { case b: BroadcastHashJoinExec => b }.size == 1)
+        assert(collect(plan2) { case e: ShuffleExchangeExec => e }.isEmpty)
+
+        // Validate the data with broadcast join off.
+        withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+          val df = join1.join(t3, join1("i2") === t3("i3") && join1("j2") === t3("j3"))
+          checkAnswer(join2, df)
+        }
+      }
+    }
+  }
+
+  test("broadcast join where streamed side's output partitioning is PartitioningCollection") {
+    withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "500") {
+      val t1 = (0 until 100).map(i => (i % 5, i % 13)).toDF("i1", "j1")
+      val t2 = (0 until 100).map(i => (i % 5, i % 13)).toDF("i2", "j2")
+      val t3 = (0 until 20).map(i => (i % 7, i % 11)).toDF("i3", "j3")
+      val t4 = (0 until 100).map(i => (i % 5, i % 13)).toDF("i4", "j4")
+
+      // join1 is a sort merge join (shuffle on both sides).
+      val join1 = t1.join(t2, t1("i1") === t2("i2"))
+      val plan1 = join1.queryExecution.executedPlan
+      assert(collect(plan1) { case s: SortMergeJoinExec => s }.size == 1)
+      assert(collect(plan1) { case e: ShuffleExchangeExec => e }.size == 2)
+
+      // join2 is a broadcast join where t3 is broadcasted. Note that output partitioning
+      // on the streamed side (join1) is PartitioningCollection (sort merge join).
+      val join2 = join1.join(t3, join1("i1") === t3("i3"))
+      val plan2 = join2.queryExecution.executedPlan
+      assert(collect(plan2) { case s: SortMergeJoinExec => s }.size == 1)
+      assert(collect(plan2) { case e: ShuffleExchangeExec => e }.size == 2)
+      val broadcastJoins = collect(plan2) { case b: BroadcastHashJoinExec => b }
+      assert(broadcastJoins.size == 1)
+      assert(broadcastJoins(0).outputPartitioning.isInstanceOf[PartitioningCollection])
+      val p = broadcastJoins(0).outputPartitioning.asInstanceOf[PartitioningCollection]
+      assert(p.partitionings.size == 3)
+      // Verify all the combinations of output partitioning.
+      Seq(Seq(t1("i1")), Seq(t2("i2")), Seq(t3("i3"))).foreach { expected =>
+        val expectedExpressions = expected.map(_.expr)
+        assert(p.partitionings.exists {
+          case h: HashPartitioning => expressionsEqual(h.expressions, expectedExpressions)
+        })
+      }
+
+      // Join on the column from the broadcasted side (i3) and make sure output partitioning
+      // is maintained by checking no shuffle exchange is introduced. Note that one extra
+      // ShuffleExchangeExec is from t4, not from join2.
+      val join3 = join2.join(t4, join2("i3") === t4("i4"))
+      val plan3 = join3.queryExecution.executedPlan
+      assert(collect(plan3) { case s: SortMergeJoinExec => s }.size == 2)
+      assert(collect(plan3) { case b: BroadcastHashJoinExec => b }.size == 1)
+      assert(collect(plan3) { case e: ShuffleExchangeExec => e }.size == 3)
+
+      // Validate the data with broadcast join off.
+      withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+        val df = join2.join(t4, join2("i3") === t4("i4"))
+        checkAnswer(join3, df)
+      }
+    }
+  }
+
+  test("BroadcastHashJoinExec output partitioning scenarios for inner join") {
+    val l1 = AttributeReference("l1", LongType)()
+    val l2 = AttributeReference("l2", LongType)()
+    val l3 = AttributeReference("l3", LongType)()
+    val r1 = AttributeReference("r1", LongType)()
+    val r2 = AttributeReference("r2", LongType)()
+    val r3 = AttributeReference("r3", LongType)()
+
+    // Streamed side has a HashPartitioning.
+    var bhj = BroadcastHashJoinExec(
+      leftKeys = Seq(l2, l3),
+      rightKeys = Seq(r1, r2),
+      Inner,
+      BuildRight,
+      None,
+      left = DummySparkPlan(outputPartitioning = HashPartitioning(Seq(l1, l2, l3), 1)),
+      right = DummySparkPlan())
+    var expected = PartitioningCollection(Seq(
+      HashPartitioning(Seq(l1, l2, l3), 1),
+      HashPartitioning(Seq(l1, l2, r2), 1),
+      HashPartitioning(Seq(l1, r1, l3), 1),
+      HashPartitioning(Seq(l1, r1, r2), 1)))
+    assert(bhj.outputPartitioning === expected)
+
+    // Streamed side has a PartitioningCollection.
+    bhj = BroadcastHashJoinExec(
+      leftKeys = Seq(l1, l2, l3),
+      rightKeys = Seq(r1, r2, r3),
+      Inner,
+      BuildRight,
+      None,
+      left = DummySparkPlan(outputPartitioning = PartitioningCollection(Seq(
+        HashPartitioning(Seq(l1, l2), 1), HashPartitioning(Seq(l3), 1)))),
+      right = DummySparkPlan())
+    expected = PartitioningCollection(Seq(
+      HashPartitioning(Seq(l1, l2), 1),
+      HashPartitioning(Seq(l1, r2), 1),
+      HashPartitioning(Seq(r1, l2), 1),
+      HashPartitioning(Seq(r1, r2), 1),
+      HashPartitioning(Seq(l3), 1),
+      HashPartitioning(Seq(r3), 1)))
+    assert(bhj.outputPartitioning === expected)
+
+    // Streamed side has a nested PartitioningCollection.
+    bhj = BroadcastHashJoinExec(
+      leftKeys = Seq(l1, l2, l3),
+      rightKeys = Seq(r1, r2, r3),
+      Inner,
+      BuildRight,
+      None,
+      left = DummySparkPlan(outputPartitioning = PartitioningCollection(Seq(
+        PartitioningCollection(Seq(HashPartitioning(Seq(l1), 1), HashPartitioning(Seq(l2), 1))),
+        HashPartitioning(Seq(l3), 1)))),
+      right = DummySparkPlan())
+    expected = PartitioningCollection(Seq(
+      PartitioningCollection(Seq(
+        HashPartitioning(Seq(l1), 1),
+        HashPartitioning(Seq(r1), 1),
+        HashPartitioning(Seq(l2), 1),
+        HashPartitioning(Seq(r2), 1))),
+      HashPartitioning(Seq(l3), 1),
+      HashPartitioning(Seq(r3), 1)))
+    assert(bhj.outputPartitioning === expected)
+
+    // One-to-many mapping case ("l1" = "r1" AND "l1" = "r2").
+    bhj = BroadcastHashJoinExec(
+      leftKeys = Seq(l1, l1),
+      rightKeys = Seq(r1, r2),
+      Inner,
+      BuildRight,
+      None,
+      left = DummySparkPlan(outputPartitioning = HashPartitioning(Seq(l1, l2), 1)),
+      right = DummySparkPlan())
+    expected = PartitioningCollection(Seq(
+      HashPartitioning(Seq(l1, l2), 1),
+      HashPartitioning(Seq(r1, l2), 1),
+      HashPartitioning(Seq(r2, l2), 1)))
+    assert(bhj.outputPartitioning === expected)
+  }
+
+  test("BroadcastHashJoinExec output partitioning size should be limited with a config") {
+    val l1 = AttributeReference("l1", LongType)()
+    val l2 = AttributeReference("l2", LongType)()
+    val r1 = AttributeReference("r1", LongType)()
+    val r2 = AttributeReference("r2", LongType)()
+
+    val expected = Seq(
+      HashPartitioning(Seq(l1, l2), 1),
+      HashPartitioning(Seq(l1, r2), 1),
+      HashPartitioning(Seq(r1, l2), 1),
+      HashPartitioning(Seq(r1, r2), 1))
+
+    Seq(1, 2, 3, 4).foreach { limit =>
+      withSQLConf(
+        SQLConf.BROADCAST_HASH_JOIN_OUTPUT_PARTITIONING_EXPAND_LIMIT.key -> s"$limit") {
+        val bhj = BroadcastHashJoinExec(
+          leftKeys = Seq(l1, l2),
+          rightKeys = Seq(r1, r2),
+          Inner,
+          BuildRight,
+          None,
+          left = DummySparkPlan(outputPartitioning = HashPartitioning(Seq(l1, l2), 1)),
+          right = DummySparkPlan())
+        assert(bhj.outputPartitioning === PartitioningCollection(expected.take(limit)))
+      }
+    }
+  }
+
+  private def expressionsEqual(l: Seq[Expression], r: Seq[Expression]): Boolean = {
+    l.length == r.length && l.zip(r).forall { case (e1, e2) => e1.semanticEquals(e2) }
+  }
 }
 
 class BroadcastJoinSuite extends BroadcastJoinSuiteBase with DisableAdaptiveExecutionSuite
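Note: the sketch below is a minimal, Spark-free illustration of the key-substitution expansion that generateExprCombinations performs in BroadcastHashJoinExec above. It operates on plain strings instead of Catalyst expressions, and the names ExpandPartitioningSketch and expandCombinations are hypothetical helpers used only for illustration, not Spark APIs.

object ExpandPartitioningSketch {

  // Generate every combination of the partitioning expressions where a streamed
  // key may be replaced by the build keys it is joined with, stopping once the
  // configured limit is reached (mirrors currentNumCombinations/maxNumCombinations).
  def expandCombinations(
      exprs: Seq[String],
      streamedToBuild: Map[String, Seq[String]],
      limit: Int): Seq[Seq[String]] = {
    var count = 0

    def loop(current: Seq[String], acc: Seq[String]): Seq[Seq[String]] = {
      if (count >= limit) {
        Nil
      } else if (current.isEmpty) {
        count += 1
        Seq(acc)
      } else {
        // Keep the streamed expression, and additionally branch into every
        // build-side expression it maps to (if any).
        val substitutes = streamedToBuild.getOrElse(current.head, Nil)
        loop(current.tail, acc :+ current.head) ++
          substitutes.flatMap(b => loop(current.tail, acc :+ b))
      }
    }

    loop(exprs, Nil)
  }

  def main(args: Array[String]): Unit = {
    // Streamed keys "b" and "c" are joined to build keys "x" and "y" respectively,
    // as in the doc comment of expandOutputPartitioning above.
    val mapping = Map("b" -> Seq("x"), "c" -> Seq("y"))
    // All four combinations: List(a, b, c), List(a, b, y), List(a, x, c), List(a, x, y).
    expandCombinations(Seq("a", "b", "c"), mapping, limit = 8).foreach(println)
    // With limit = 2 (cf. outputPartitioningExpandLimit), only the first two survive.
    expandCombinations(Seq("a", "b", "c"), mapping, limit = 2).foreach(println)
  }
}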