Skip to content

Commit

Permalink
Address PR comments
Browse files Browse the repository at this point in the history
  • Loading branch information
imback82 committed Jul 16, 2020
1 parent 794890f commit afa5aca
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2665,6 +2665,15 @@ object SQLConf {
.checkValue(_ > 0, "The difference must be positive.")
.createWithDefault(4)

// Integer config entry capping how many partitionings a single HashPartitioning may be
// expanded into. Only strictly positive values pass validation; the default is 8.
val BROADCAST_HASH_JOIN_OUTPUT_PARTITIONING_EXPAND_LIMIT =
  buildConf("spark.sql.execution.broadcastHashJoin.outputPartitioningExpandLimit")
    .doc("The maximum number of partitionings that a HashPartitioning can be expanded to. " +
      "This configuration is applicable only for inner joins.")
    .version("3.1.0")
    .intConf
    .checkValue(limit => limit > 0, "The value must be positive.")
    .createWithDefault(8)

/**
* Holds information about keys that have been deprecated.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, HashPartitioning, Partitioning, PartitioningCollection, UnspecifiedDistribution}
import org.apache.spark.sql.execution.{CodegenSupport, SparkPlan}
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{BooleanType, LongType}

/**
Expand Down Expand Up @@ -64,7 +65,7 @@ case class BroadcastHashJoinExec(

override lazy val outputPartitioning: Partitioning = {
joinType match {
case _: InnerLike =>
case Inner =>
streamedPlan.outputPartitioning match {
case h: HashPartitioning => expandOutputPartitioning(h)
case c: PartitioningCollection => expandOutputPartitioning(c)
Expand Down Expand Up @@ -105,22 +106,36 @@ case class BroadcastHashJoinExec(
// Seq("a", "b", "c"), Seq("a", "b", "y"), Seq("a", "x", "c"), Seq("a", "x", "y").
// The expanded expressions are returned as PartitioningCollection.
// Expands the given HashPartitioning into a PartitioningCollection of HashPartitionings built
// from combinations of its expressions, where an expression may be swapped for the keys found
// for it in streamedKeyToBuildKeyMapping. The number of combinations is capped by
// SQLConf.BROADCAST_HASH_JOIN_OUTPUT_PARTITIONING_EXPAND_LIMIT.
//
// NOTE(review): this block appears to be a rendered diff whose +/- markers were lost, so some
// statements show up twice (an older and a newer form). The notes below mark the spots that
// look duplicated -- confirm against the actual file before relying on this text.
private def expandOutputPartitioning(partitioning: HashPartitioning): PartitioningCollection = {
// Upper bound on the number of expression combinations to emit, read from the session conf.
val maxNumCombinations = sqlContext.conf.getConf(
SQLConf.BROADCAST_HASH_JOIN_OUTPUT_PARTITIONING_EXPAND_LIMIT)
// Running count of combinations emitted so far; mutated by the nested helper below.
var currentNumCombinations = 0

// Recursive helper: `current` holds the expressions still to be processed and `accumulated`
// the choices made so far for the expressions already consumed.
def generateExprCombinations(
current: Seq[Expression],
accumulated: Seq[Expression]): Seq[Seq[Expression]] = {
// NOTE(review): the next line looks like the pre-change `if` header; the line after it looks
// like its replacement (the same `current.isEmpty` test reappears as an `else if`).
if (current.isEmpty) {
// NOTE(review): the counter only grows when a combination is emitted, and the build-key
// branches below are entered only while the counter is below the limit, so this `>` guard
// looks like it can never be true -- `>=` may have been intended. Confirm.
if (currentNumCombinations > maxNumCombinations) {
Nil
} else if (current.isEmpty) {
// Every expression has been consumed: count and emit the accumulated combination.
currentNumCombinations += 1
Seq(accumulated)
} else {
// Keys mapped to the head expression, looked up by its canonicalized form (an Option).
val buildKeys = streamedKeyToBuildKeyMapping.get(current.head.canonicalized)
// First keep the head expression itself, then append the alternatives that use mapped keys.
generateExprCombinations(current.tail, accumulated :+ current.head) ++
// NOTE(review): the next two lines look like the pre-change form of the `buildKeys.map`
// expression that follows them (same shape, without the limit check).
buildKeys.map { _.flatMap(b => generateExprCombinations(current.tail, accumulated :+ b))
}.getOrElse(Nil)
buildKeys.map { bKeys =>
bKeys.flatMap { bKey =>
// Only descend into another alternative while the limit has not been reached.
if (currentNumCombinations < maxNumCombinations) {
generateExprCombinations(current.tail, accumulated :+ bKey)
} else {
Nil
}
}
}.getOrElse(Nil)
}
}

// Wrap each generated expression list in a HashPartitioning that keeps the input's
// numPartitions.
PartitioningCollection(
// NOTE(review): the next two lines and the two after them spell the same expression with
// different line breaks; presumably one pair is the pre-change formatting.
generateExprCombinations(partitioning.expressions, Nil).map(
HashPartitioning(_, partitioning.numPartitions)))
generateExprCombinations(partitioning.expressions, Nil)
.map(HashPartitioning(_, partitioning.numPartitions)))
}

protected override def doExecute(): RDD[InternalRow] = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -596,6 +596,34 @@ abstract class BroadcastJoinSuiteBase extends QueryTest with SQLTestUtils
assert(bhj.outputPartitioning === expected)
}

test("BroadcastHashJoinExec output partitioning size should be limited with a config") {
  // Join keys: l1/l2 are the left keys, r1/r2 the matching right keys.
  val l1 = AttributeReference("l1", LongType)()
  val l2 = AttributeReference("l2", LongType)()
  val r1 = AttributeReference("r1", LongType)()
  val r2 = AttributeReference("r2", LongType)()

  // The complete expansion, listed in the order the join is expected to produce it. A limit
  // of n must yield exactly the first n entries.
  val allExpanded = Seq(
    Seq(l1, l2),
    Seq(l1, r2),
    Seq(r1, l2),
    Seq(r1, r2)).map(HashPartitioning(_, 1))

  for (limit <- 1 to 4) {
    withSQLConf(
      SQLConf.BROADCAST_HASH_JOIN_OUTPUT_PARTITIONING_EXPAND_LIMIT.key -> limit.toString) {
      val leftPlan = DummySparkPlan(outputPartitioning = HashPartitioning(Seq(l1, l2), 1))
      val rightPlan = DummySparkPlan()
      val join = BroadcastHashJoinExec(
        leftKeys = Seq(l1, l2),
        rightKeys = Seq(r1, r2),
        Inner,
        BuildRight,
        None,
        left = leftPlan,
        right = rightPlan)
      val expectedPartitioning = PartitioningCollection(allExpanded.take(limit))
      assert(join.outputPartitioning === expectedPartitioning)
    }
  }
}

// Returns true when both sequences have the same length and every pair of expressions at the
// same position satisfies `semanticEquals`.
private def expressionsEqual(l: Seq[Expression], r: Seq[Expression]): Boolean = {
  if (l.length != r.length) {
    false
  } else {
    l.zip(r).forall { pair => pair._1.semanticEquals(pair._2) }
  }
}
Expand Down

0 comments on commit afa5aca

Please sign in to comment.