[SPARK-31869][SQL] BroadcastHashJoinExec can utilize the build side for its output partitioning #28676

Closed · wants to merge 19 commits

SQLConf.scala
@@ -2665,6 +2665,17 @@ object SQLConf {
.checkValue(_ > 0, "The difference must be positive.")
.createWithDefault(4)

val BROADCAST_HASH_JOIN_OUTPUT_PARTITIONING_EXPAND_LIMIT =
buildConf("spark.sql.execution.broadcastHashJoin.outputPartitioningExpandLimit")
.internal()
.doc("The maximum number of partitionings that a HashPartitioning can be expanded to. " +
"This configuration is applicable only for inner joins and can be set to '0' to disable " +
"this feature.")
.version("3.1.0")
.intConf
.checkValue(_ >= 0, "The value must be non-negative.")
.createWithDefault(8)

/**
* Holds information about keys that have been deprecated.
*
@@ -2975,6 +2986,9 @@ class SQLConf extends Serializable with Logging {
LegacyBehaviorPolicy.withName(getConf(SQLConf.LEGACY_TIME_PARSER_POLICY))
}

def broadcastHashJoinOutputPartitioningExpandLimit: Int =
getConf(BROADCAST_HASH_JOIN_OUTPUT_PARTITIONING_EXPAND_LIMIT)

/**
* Returns the [[Resolver]] for the current configuration, which can be used to determine if two
* identifiers are equal.
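Usage note (not part of the diff): the new entry behaves like any other internal SQLConf setting and can be overridden per session. A minimal sketch, assuming an active SparkSession named spark:

import org.apache.spark.sql.internal.SQLConf

// Raise the cap from its default of 8, or set it to 0 to disable the expansion.
spark.sessionState.conf.setConf(SQLConf.BROADCAST_HASH_JOIN_OUTPUT_PARTITIONING_EXPAND_LIMIT, 16)
// Equivalently, through the string key:
spark.conf.set("spark.sql.execution.broadcastHashJoin.outputPartitioningExpandLimit", "0")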
BroadcastHashJoinExec.scala
@@ -17,6 +17,8 @@

package org.apache.spark.sql.execution.joins

import scala.collection.mutable

import org.apache.spark.TaskContext
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
@@ -26,7 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide}
import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, UnspecifiedDistribution}
+import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, HashPartitioning, Partitioning, PartitioningCollection, UnspecifiedDistribution}
import org.apache.spark.sql.execution.{CodegenSupport, SparkPlan}
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.types.{BooleanType, LongType}
@@ -51,7 +53,7 @@ case class BroadcastHashJoinExec(
"numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))

override def requiredChildDistribution: Seq[Distribution] = {
-    val mode = HashedRelationBroadcastMode(buildKeys)
+    val mode = HashedRelationBroadcastMode(buildBoundKeys)
buildSide match {
case BuildLeft =>
BroadcastDistribution(mode) :: UnspecifiedDistribution :: Nil
@@ -60,6 +62,73 @@ case class BroadcastHashJoinExec(
}
}

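  // For inner joins, every output row satisfies streamedKey = buildKey for each
  // join key pair, so a HashPartitioning expressed in streamed keys remains
  // valid when build keys are substituted.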
override lazy val outputPartitioning: Partitioning = {
joinType match {
case _: InnerLike if sqlContext.conf.broadcastHashJoinOutputPartitioningExpandLimit > 0 =>
streamedPlan.outputPartitioning match {
case h: HashPartitioning => expandOutputPartitioning(h)
case c: PartitioningCollection => expandOutputPartitioning(c)
case other => other
}
case _ => streamedPlan.outputPartitioning
}
}

// A one-to-many mapping from a streamed key to build keys.
private lazy val streamedKeyToBuildKeyMapping = {
val mapping = mutable.Map.empty[Expression, Seq[Expression]]
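    // Key the map by the canonicalized streamed key so lookups are not defeated
    // by cosmetic differences such as expression IDs or qualifiers.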
streamedKeys.zip(buildKeys).foreach {
case (streamedKey, buildKey) =>
val key = streamedKey.canonicalized
mapping.get(key) match {
case Some(v) => mapping.put(key, v :+ buildKey)
case None => mapping.put(key, Seq(buildKey))
}
}
mapping.toMap
}

// Expands the given partitioning collection recursively.
private def expandOutputPartitioning(
partitioning: PartitioningCollection): PartitioningCollection = {
PartitioningCollection(partitioning.partitionings.flatMap {
case h: HashPartitioning => expandOutputPartitioning(h).partitionings
case c: PartitioningCollection => Seq(expandOutputPartitioning(c))
case other => Seq(other)
})
}

// Expands the given hash partitioning by substituting streamed keys with build keys.
// For example, if the expressions for the given partitioning are Seq("a", "b", "c")
// where the streamed keys are Seq("b", "c") and the build keys are Seq("x", "y"),
// the expanded partitioning will have the following expressions:
// Seq("a", "b", "c"), Seq("a", "b", "y"), Seq("a", "x", "c"), Seq("a", "x", "y").
// The expanded expressions are returned as a PartitioningCollection.
private def expandOutputPartitioning(partitioning: HashPartitioning): PartitioningCollection = {
val maxNumCombinations = sqlContext.conf.broadcastHashJoinOutputPartitioningExpandLimit
var currentNumCombinations = 0

def generateExprCombinations(
current: Seq[Expression],
accumulated: Seq[Expression]): Seq[Seq[Expression]] = {
if (currentNumCombinations >= maxNumCombinations) {
Nil
} else if (current.isEmpty) {
currentNumCombinations += 1
Seq(accumulated)
} else {
val buildKeysOpt = streamedKeyToBuildKeyMapping.get(current.head.canonicalized)
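        // Branch at each position: keep the streamed expression itself, and also
        // substitute every build key the join equates with it (if any).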
generateExprCombinations(current.tail, accumulated :+ current.head) ++
buildKeysOpt.map(_.flatMap(b => generateExprCombinations(current.tail, accumulated :+ b)))
.getOrElse(Nil)
}
}

PartitioningCollection(
generateExprCombinations(partitioning.expressions, Nil)
.map(HashPartitioning(_, partitioning.numPartitions)))
}

protected override def doExecute(): RDD[InternalRow] = {
val numOutputRows = longMetric("numOutputRows")

@@ -135,13 +204,13 @@ case class BroadcastHashJoinExec(
ctx: CodegenContext,
input: Seq[ExprCode]): (ExprCode, String) = {
ctx.currentVars = input
-    if (streamedKeys.length == 1 && streamedKeys.head.dataType == LongType) {
+    if (streamedBoundKeys.length == 1 && streamedBoundKeys.head.dataType == LongType) {
      // generate the join key as Long
-      val ev = streamedKeys.head.genCode(ctx)
+      val ev = streamedBoundKeys.head.genCode(ctx)
      (ev, ev.isNull)
    } else {
      // generate the join key as UnsafeRow
-      val ev = GenerateUnsafeProjection.createCode(ctx, streamedKeys)
+      val ev = GenerateUnsafeProjection.createCode(ctx, streamedBoundKeys)
(ev, s"${ev.value}.anyNull()")
}
}
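To make the combinatorics concrete, here is a self-contained sketch of generateExprCombinations over plain strings; the keys, mapping, and limit mirror the doc-comment example above, and none of it is Spark API:

val limit = 8                                        // the expand limit (default)
val mapping = Map("b" -> Seq("x"), "c" -> Seq("y"))  // streamed key -> build keys
var count = 0

def combos(current: Seq[String], acc: Seq[String]): Seq[Seq[String]] = {
  if (count >= limit) Nil                            // stop once the cap is hit
  else if (current.isEmpty) { count += 1; Seq(acc) } // one complete combination
  else {
    combos(current.tail, acc :+ current.head) ++
      mapping.get(current.head)
        .map(_.flatMap(b => combos(current.tail, acc :+ b)))
        .getOrElse(Nil)
  }
}

combos(Seq("a", "b", "c"), Nil).foreach(println)
// List(a, b, c)
// List(a, b, y)
// List(a, x, c)
// List(a, x, y)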
HashJoin.scala
@@ -62,21 +62,30 @@ trait HashJoin extends BaseJoinExec {
protected lazy val (buildKeys, streamedKeys) = {
require(leftKeys.map(_.dataType) == rightKeys.map(_.dataType),
"Join keys from two sides should have same types")
-    val lkeys = bindReferences(HashJoin.rewriteKeyExpr(leftKeys), left.output)
-    val rkeys = bindReferences(HashJoin.rewriteKeyExpr(rightKeys), right.output)
    buildSide match {
-      case BuildLeft => (lkeys, rkeys)
-      case BuildRight => (rkeys, lkeys)
+      case BuildLeft => (leftKeys, rightKeys)
+      case BuildRight => (rightKeys, leftKeys)
}
}

@transient private lazy val (buildOutput, streamedOutput) = {
buildSide match {
case BuildLeft => (left.output, right.output)
case BuildRight => (right.output, left.output)
}
}

@transient protected lazy val buildBoundKeys =
bindReferences(HashJoin.rewriteKeyExpr(buildKeys), buildOutput)

@transient protected lazy val streamedBoundKeys =
bindReferences(HashJoin.rewriteKeyExpr(streamedKeys), streamedOutput)

protected def buildSideKeyGenerator(): Projection =
-    UnsafeProjection.create(buildKeys)
+    UnsafeProjection.create(buildBoundKeys)

protected def streamSideKeyGenerator(): UnsafeProjection =
-    UnsafeProjection.create(streamedKeys)
+    UnsafeProjection.create(streamedBoundKeys)

@transient private[this] lazy val boundCondition = if (condition.isDefined) {
Predicate.create(condition.get, streamedPlan.output ++ buildPlan.output).eval _
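The refactoring above keeps the logical join keys (buildKeys / streamedKeys) unbound so that the new outputPartitioning can compare them against the streamed plan's partitioning expressions, while execution-time paths switch to the bound variants (buildBoundKeys / streamedBoundKeys). A small sketch of what binding does, with a schema invented for illustration:

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.IntegerType

val a = AttributeReference("a", IntegerType)()
val b = AttributeReference("b", IntegerType)()
// Binding rewrites the named attribute into a positional lookup against the
// given output; here b becomes input[1, int, true].
val bound: Expression = BindReferences.bindReference(b, Seq(a, b))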
ShuffledHashJoinExec.scala
@@ -55,7 +55,8 @@ case class ShuffledHashJoinExec(
val buildTime = longMetric("buildTime")
val start = System.nanoTime()
val context = TaskContext.get()
-    val relation = HashedRelation(iter, buildKeys, taskMemoryManager = context.taskMemoryManager())
+    val relation = HashedRelation(
+      iter, buildBoundKeys, taskMemoryManager = context.taskMemoryManager())
buildTime += NANOSECONDS.toMillis(System.nanoTime() - start)
buildDataSize += relation.estimatedSize
// This relation is usually used until the end of task.
AdaptiveQueryExecSuite.scala
@@ -557,7 +557,8 @@ class AdaptiveQueryExecSuite

withSQLConf(
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
-      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") {
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80",
+      SQLConf.BROADCAST_HASH_JOIN_OUTPUT_PARTITIONING_EXPAND_LIMIT.key -> "0") {
val (plan, adaptivePlan) = runAdaptiveAndVerifyResult(
"SELECT * FROM testData " +
"join testData2 t2 ON key = t2.a " +
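For context on why the suite pins the new limit to 0 above: once the build-side keys are advertised in the output partitioning, downstream operators may no longer need an exchange, which would change the plan shapes this test asserts on. A rough sketch of a query shape that benefits, with DataFrames t1 (containing column key) and t2 (containing column a) assumed rather than taken from the PR:

import org.apache.spark.sql.functions.broadcast
import spark.implicits._

// t1 is hash-partitioned by "key"; t2 is small enough to broadcast.
val joined = t1.repartition($"key").join(broadcast(t2), $"key" === $"a")
// The inner join guarantees key = a on every output row, so the join can now
// report HashPartitioning over "key" and over "a". An aggregation on "a" can
// then reuse the existing distribution instead of inserting another Exchange.
joined.groupBy($"a").count().explain()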