Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-31869][SQL] BroadcastHashJoinExec can utilize the build side for its output partitioning #28676

Closed
wants to merge 19 commits into from
Closed
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

package org.apache.spark.sql.execution.joins

import scala.collection.mutable

import org.apache.spark.TaskContext
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
Expand All @@ -26,7 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide}
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, UnspecifiedDistribution}
import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, HashPartitioning, Partitioning, PartitioningCollection, UnspecifiedDistribution}
import org.apache.spark.sql.execution.{CodegenSupport, SparkPlan}
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.types.{BooleanType, LongType}
Expand Down Expand Up @@ -60,6 +62,92 @@ case class BroadcastHashJoinExec(
}
}

override def outputPartitioning: Partitioning = {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

val or lazy val?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed it to lazy val

// Split the join keys into the build (broadcast) side keys and the streamed side keys.
val (buildKeys, streamedKeys) = buildSide match {
  case BuildLeft => (leftKeys, rightKeys)
  case BuildRight => (rightKeys, leftKeys)
}

joinType match {
  case _: InnerLike =>
    // For inner-like joins, try to extend the streamed side's partitioning with an
    // equivalent partitioning expressed in terms of the build side keys.
    streamedPlan.outputPartitioning match {
      case h: HashPartitioning =>
        getBuildSidePartitioning(h, streamedKeys, buildKeys) match {
          case Some(p) => PartitioningCollection(Seq(h, p))
          case None => h
        }
      case c: PartitioningCollection =>
        // Lazily look for the first HashPartitioning in the collection that yields a build
        // side partitioning. The iterator stops at the first hit, so no nonlocal `return`
        // out of a closure is needed.
        val buildSidePartitioning = c.partitionings.iterator
          .collect { case h: HashPartitioning => h }
          .map(h => getBuildSidePartitioning(h, streamedKeys, buildKeys))
          .collectFirst { case Some(p) => p }
        buildSidePartitioning match {
          case Some(p) => PartitioningCollection(c.partitionings :+ p)
          case None => c
        }
      case other => other
    }
  // Every other join type simply reports the streamed side's partitioning.
  case _ => streamedPlan.outputPartitioning
}
}

/**
 * Returns a partitioning for the build side if the following conditions are met:
 *  - `streamedKeys` satisfy `streamedPartitioning` (see `satisfiesPartitioning`).
 *  - There is a one-to-one mapping from streamed keys to build keys, where keys are
 *    compared through their canonicalized form.
 *
 * The build side partitioning has its expressions in the same order as the corresponding
 * expressions in the streamed side partitioning. For example, for the following setup:
 *  - streamed partitioning expressions: Seq(s2, s1)
 *  - streamed keys: Seq(s1, s2)
 *  - build keys: Seq(b1, b2)
 * the expressions in the build side partitioning will be Seq(b2, b1), not Seq(b1, b2).
 *
 * @param streamedPartitioning the hash partitioning reported by the streamed side
 * @param streamedKeys the join keys of the streamed side
 * @param buildKeys the join keys of the build side, positionally paired with `streamedKeys`
 * @return the equivalent build side partitioning, or None if the conditions are not met
 */
private def getBuildSidePartitioning(
    streamedPartitioning: HashPartitioning,
    streamedKeys: Seq[Expression],
    buildKeys: Seq[Expression]): Option[HashPartitioning] = {
  if (!satisfiesPartitioning(streamedKeys, streamedPartitioning)) {
    None
  } else {
    // Maps the canonicalized streamed key to the first build key it is paired with.
    val streamedKeyToBuildKeyMap = mutable.Map.empty[Expression, Expression]

    // `forall` stops at the first streamed key that is paired with a build key which is not
    // `semanticEquals` to the one recorded for it (one-to-many mapping). The first pairing
    // is compared with itself, so a build key that is not `semanticEquals` to itself is
    // rejected as well.
    val isManyToOneOrBetter = streamedKeys.zip(buildKeys).forall {
      case (streamedKey, buildKey) =>
        streamedKeyToBuildKeyMap
          .getOrElseUpdate(streamedKey.canonicalized, buildKey)
          .semanticEquals(buildKey)
    }

    // Ensure the one-to-one mapping from streamed keys to build keys. Build keys are compared
    // by their canonicalized form so that this check agrees with `semanticEquals` above.
    val distinctBuildKeys = streamedKeyToBuildKeyMap.values.map(_.canonicalized).toSet
    val isOneToOne =
      isManyToOneOrBetter && distinctBuildKeys.size == streamedKeyToBuildKeyMap.size

    if (!isOneToOne) {
      None
    } else {
      // The final expressions are built by mapping streamed partitioning expressions ->
      // streamed keys -> build keys.
      val buildPartitioningExpressions = streamedPartitioning.expressions.map { e =>
        streamedKeyToBuildKeyMap(e.canonicalized)
      }
      Some(HashPartitioning(buildPartitioningExpressions, streamedPartitioning.numPartitions))
    }
  }
}

// Returns true if `partitioning` has exactly as many expressions as there are `keys` and
// each of its expressions is `semanticEquals` to at least one of the `keys`.
private def satisfiesPartitioning(
    keys: Seq[Expression],
    partitioning: HashPartitioning): Boolean = {
  val partitioningExprs = partitioning.expressions
  if (partitioningExprs.length != keys.length) {
    false
  } else {
    partitioningExprs.forall { expr =>
      keys.exists(key => key.semanticEquals(expr))
    }
  }
}

protected override def doExecute(): RDD[InternalRow] = {
val numOutputRows = longMetric("numOutputRows")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -554,7 +554,7 @@ class AdaptiveQueryExecSuite
val smj = findTopLevelSortMergeJoin(plan)
assert(smj.size == 2)
val smj2 = findTopLevelSortMergeJoin(adaptivePlan)
assert(smj2.size == 2, origPlan.toString)
assert(smj2.size == 1, origPlan.toString)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,11 @@ import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession}
import org.apache.spark.sql.catalyst.expressions.{BitwiseAnd, BitwiseOr, Cast, Literal, ShiftLeft}
import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide}
import org.apache.spark.sql.catalyst.plans.logical.BROADCAST
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, PartitioningCollection}
import org.apache.spark.sql.execution.{SparkPlan, WholeStageCodegenExec}
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, DisableAdaptiveExecutionSuite, EnableAdaptiveExecutionSuite}
import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
import org.apache.spark.sql.execution.exchange.EnsureRequirements
import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ShuffleExchangeExec}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SQLTestUtils
Expand Down Expand Up @@ -415,6 +416,95 @@ abstract class BroadcastJoinSuiteBase extends QueryTest with SQLTestUtils
assert(e.getMessage.contains(s"Could not execute broadcast in $timeout secs."))
}
}

test("broadcast join where streamed side's output partitioning is PartitioningCollection") {
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "500") {
val t1 = (0 until 100).map(i => (i % 5, i % 13)).toDF("i1", "j1")
val t2 = (0 until 100).map(i => (i % 5, i % 13)).toDF("i2", "j2")
val t3 = (0 until 20).map(i => (i % 7, i % 11)).toDF("i3", "j3")
val t4 = (0 until 100).map(i => (i % 5, i % 13)).toDF("i4", "j4")

// join1 is a sort merge join (shuffle on the both sides).
val join1 = t1.join(t2, t1("i1") === t2("i2"))
val plan1 = join1.queryExecution.executedPlan
assert(collect(plan1) { case s: SortMergeJoinExec => s }.size == 1)
assert(collect(plan1) { case e: ShuffleExchangeExec => e }.size == 2)

// join2 is a broadcast join where t3 is broadcasted. Note that output partitioning on the
// streamed side (join1) is PartitioningCollection (sort merge join)
val join2 = join1.join(t3, join1("i1") === t3("i3"))
val plan2 = join2.queryExecution.executedPlan
assert(collect(plan2) { case s: SortMergeJoinExec => s }.size == 1)
assert(collect(plan2) { case e: ShuffleExchangeExec => e }.size == 2)
val broadcastJoins = collect(plan2) { case b: BroadcastHashJoinExec => b }
assert(broadcastJoins.size == 1)
broadcastJoins(0).outputPartitioning match {
case p: PartitioningCollection
if p.partitionings.forall(_.isInstanceOf[HashPartitioning]) =>
// two partitionings from sort merge join and one from build side.
assert(p.partitionings.size == 3)
case _ => fail()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: For better test error messages,

assert(broadcastJoins(0).outputPartitioning.isInstanceOf[PartitioningCollection])
val p = broadcastJoins(0).outputPartitioning.asInstanceOf[PartitioningCollection]
...

Or, could you add error messages in fail()?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed as suggested.

}

// Join on the column from the broadcasted side (i3) and make sure output partitioning
// is maintained by checking no shuffle exchange is introduced. Note that one extra
// ShuffleExchangeExec is from t4, not from join2.
val join3 = join2.join(t4, join2("i3") === t4("i4"))
val plan3 = join3.queryExecution.executedPlan
assert(collect(plan3) { case s: SortMergeJoinExec => s }.size == 2)
assert(collect(plan3) { case b: BroadcastHashJoinExec => b }.size == 1)
assert(collect(plan3) { case e: ShuffleExchangeExec => e }.size == 3)

// Validate the data with broadcast join off.
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
  val df = join2.join(t4, join2("i3") === t4("i4"))
  // `QueryTest.sameRows` reports a mismatch through its return value instead of failing
  // the test, so the result has to be asserted on explicitly.
  val mismatch = QueryTest.sameRows(join3.collect().toSeq, df.collect().toSeq)
  assert(mismatch.isEmpty, mismatch.getOrElse(""))
}
}
}

// Checks that a broadcast hash join over a bucketed (hash partitioned) streamed side reports
// an additional partitioning on the build side keys, and that a later join on the build side
// columns therefore needs no shuffle.
test("broadcast join where streamed side's output partitioning is HashPartitioning") {
  withTable("t1", "t3") {
    withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "500") {
      val df1 = (0 until 100).map(i => (i % 5, i % 13)).toDF("i1", "j1")
      val df2 = (0 until 20).map(i => (i % 7, i % 11)).toDF("i2", "j2")
      val df3 = (0 until 100).map(i => (i % 5, i % 13)).toDF("i3", "j3")
      df1.write.format("parquet").bucketBy(8, "i1", "j1").saveAsTable("t1")
      df3.write.format("parquet").bucketBy(8, "i3", "j3").saveAsTable("t3")
      val t1 = spark.table("t1")
      val t3 = spark.table("t3")

      // join1 is a broadcast join where df2 is broadcasted. Note that output partitioning on
      // the streamed side (t1) is HashPartitioning (bucketed files).
      val join1 = t1.join(df2, t1("i1") === df2("i2") && t1("j1") === df2("j2"))
      val plan1 = join1.queryExecution.executedPlan
      assert(collect(plan1) { case e: ShuffleExchangeExec => e }.isEmpty)
      val broadcastJoins = collect(plan1) { case b: BroadcastHashJoinExec => b }
      assert(broadcastJoins.size == 1)
      broadcastJoins(0).outputPartitioning match {
        case p: PartitioningCollection
            if p.partitionings.forall(_.isInstanceOf[HashPartitioning]) =>
          // one partitioning from streamed side and one from build side.
          assert(p.partitionings.size == 2)
        case other =>
          // Report the actual partitioning so that a failure can be diagnosed.
          fail(s"Expected a PartitioningCollection of HashPartitioning, but got: $other")
      }

      // Join on the column from the broadcasted side (i2, j2) and make sure output
      // partitioning is maintained by checking no shuffle exchange is introduced.
      val join2 = join1.join(t3, join1("i2") === t3("i3") && join1("j2") === t3("j3"))
      val plan2 = join2.queryExecution.executedPlan
      assert(collect(plan2) { case s: SortMergeJoinExec => s }.size == 1)
      assert(collect(plan2) { case b: BroadcastHashJoinExec => b }.size == 1)
      assert(collect(plan2) { case e: ShuffleExchangeExec => e }.isEmpty)

      // Validate the data with broadcast join off.
      withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
        val df = join1.join(t3, join1("i2") === t3("i3") && join1("j2") === t3("j3"))
        // `QueryTest.sameRows` reports a mismatch through its return value instead of
        // failing the test, so the result has to be asserted on explicitly.
        val mismatch = QueryTest.sameRows(join2.collect().toSeq, df.collect().toSeq)
        assert(mismatch.isEmpty, mismatch.getOrElse(""))
      }
    }
  }
}
}

// Runs the tests defined in BroadcastJoinSuiteBase with DisableAdaptiveExecutionSuite mixed in.
class BroadcastJoinSuite extends BroadcastJoinSuiteBase with DisableAdaptiveExecutionSuite
Expand Down