[SPARK-48613][SQL] SPJ: Support auto-shuffle one side + less join keys than partition keys

### What changes were proposed in this pull request?

This is the final planned SPJ scenario: auto-shuffle one side combined with fewer join keys than partition keys. Background:

- Auto-shuffle works by creating a ShuffleExchange for the non-partitioned side, using a clone of the partitioned side's KeyGroupedPartitioning.
- "Fewer join keys than partition keys" works by 'projecting' all partition values onto the join keys (i.e., keeping only the partition columns that are also join columns). It builds a target KeyGroupedShuffleSpec with the 'projected' partition values and pushes it down to BatchScanExec, which then 'groups' its projected partition values (except in the skew case, but that is a separate story).

This combination is hard because the SPJ planning logic is spread across several places in this scenario. Given a non-partitioned side and a partitioned side, where the join keys are only a subset of the partition keys, planning proceeds as follows:

1. EnsureRequirements creates the target KeyGroupedShuffleSpec from the join's required distribution (i.e., using only the join keys, not all partition keys).
2. EnsureRequirements copies this to the non-partitioned side's KeyGroupedPartitioning (for the auto-shuffle case).
3. BatchScanExec groups the partitions on the partitioned side, including by join keys (if they differ from the partition keys).

Take, for example, partition columns (id, name) and partition values (1, "bob"), (2, "alice"), (2, "sam"), with id as the only join key.
Projection leaves (1, 2, 2), and the final grouped partition values are (1, 2).
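As a self-contained illustration of the project-then-group steps on this example (plain Scala collections stand in for Spark's InternalRow; this is not the actual Spark code):

```scala
object ProjectionExample {
  // Partition columns: (id, name); the only join key is id, at position 0.
  val partitionValues: Seq[Seq[Any]] =
    Seq(Seq(1, "bob"), Seq(2, "alice"), Seq(2, "sam"))
  val joinKeyPositions: Seq[Int] = Seq(0)

  // 'Project': keep only the join-key columns of each partition value.
  val projected: Seq[Seq[Any]] =
    partitionValues.map(row => joinKeyPositions.map(row))

  // 'Group': de-duplicate the projected partition values.
  val grouped: Seq[Seq[Any]] = projected.distinct

  def main(args: Array[String]): Unit = {
    println(projected)  // List(List(1), List(2), List(2))
    println(grouped)    // List(List(1), List(2))
  }
}
```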

The problem is that the two sides of the join do not always match. After steps 1 and 2, the partitioned side has the 'projected' partition values (1, 2, 2), and the non-partitioned side creates a matching KeyGroupedPartitioning (1, 2, 2) for its ShuffleExchange. But in step 3, BatchScanExec groups the partitioned side's partitions into (1, 2), while the non-partitioned side does not group and still retains the (1, 2, 2) partitions. This leads to the following assertion error from the join:

```
requirement failed: PartitioningCollection requires all of its partitionings have the same numPartitions.
java.lang.IllegalArgumentException: requirement failed: PartitioningCollection requires all of its partitionings have the same numPartitions.
	at scala.Predef$.require(Predef.scala:337)
	at org.apache.spark.sql.catalyst.plans.physical.PartitioningCollection.<init>(partitioning.scala:550)
	at org.apache.spark.sql.execution.joins.ShuffledJoin.outputPartitioning(ShuffledJoin.scala:49)
	at org.apache.spark.sql.execution.joins.ShuffledJoin.outputPartitioning$(ShuffledJoin.scala:47)
	at org.apache.spark.sql.execution.joins.SortMergeJoinExec.outputPartitioning(SortMergeJoinExec.scala:39)
	at org.apache.spark.sql.execution.exchange.EnsureRequirements.$anonfun$ensureDistributionAndOrdering$1(EnsureRequirements.scala:66)
	at scala.collection.immutable.Vector1.map(Vector.scala:2140)
	at scala.collection.immutable.Vector1.map(Vector.scala:385)
	at org.apache.spark.sql.execution.exchange.EnsureRequirements.org$apache$spark$sql$execution$exchange$EnsureRequirements$$ensureDistributionAndOrdering(EnsureRequirements.scala:65)
	at org.apache.spark.sql.execution.exchange.EnsureRequirements$$anonfun$1.applyOrElse(EnsureRequirements.scala:657)
	at org.apache.spark.sql.execution.exchange.EnsureRequirements$$anonfun$1.applyOrElse(EnsureRequirements.scala:632)
```

The fix is to do the de-duplication in the first pass:

1. Push the join keys down to BatchScanExec so that it returns a de-duplicated outputPartitioning (partitioned side).
2. Create the non-partitioned side's KeyGroupedPartitioning with de-duplicated partition values (non-partitioned side).
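Both fixes rely on the same de-duplication idiom: wrap each partition value so that rows compare by value, call distinct, then unwrap. A minimal sketch of that idiom, with a hypothetical case class standing in for Spark's InternalRowComparableWrapper (which is needed because InternalRow implementations do not, in general, define value-based equality):

```scala
// Hypothetical stand-in for org.apache.spark.sql.catalyst.util.InternalRowComparableWrapper.
// A case class gives structural equals/hashCode, which is what Seq.distinct needs;
// the real wrapper derives them from the row's values and the partition expressions' types.
final case class ComparableRow(row: Seq[Any])

def dedupPartitionValues(partitionValues: Seq[Seq[Any]]): Seq[Seq[Any]] =
  partitionValues
    .map(ComparableRow(_))  // wrap: gain value-based equality
    .distinct               // drop duplicate partition values
    .map(_.row)             // unwrap

// dedupPartitionValues(Seq(Seq(1), Seq(2), Seq(2))) == Seq(Seq(1), Seq(2))
```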

### Why are the changes needed?

This is the last planned scenario for SPJ not yet supported.

### How was this patch tested?

Updated an existing unit test in KeyGroupedPartitioningSuite.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes apache#47064 from szehon-ho/spj_less_join_key_auto_shuffle.

Authored-by: Szehon Ho <[email protected]>
Signed-off-by: Chao Sun <[email protected]>
szehon-ho authored and jingz-db committed Jul 22, 2024
1 parent 6c41d75 commit b80ee03
Showing 4 changed files with 58 additions and 24 deletions.
`partitioning.scala`:

```diff
@@ -434,8 +434,13 @@ object KeyGroupedPartitioning {
     val projectedOriginalPartitionValues =
       originalPartitionValues.map(project(expressions, projectionPositions, _))
 
-    KeyGroupedPartitioning(projectedExpressions, projectedPartitionValues.length,
-      projectedPartitionValues, projectedOriginalPartitionValues)
+    val finalPartitionValues = projectedPartitionValues
+      .map(InternalRowComparableWrapper(_, projectedExpressions))
+      .distinct
+      .map(_.row)
+
+    KeyGroupedPartitioning(projectedExpressions, finalPartitionValues.length,
+      finalPartitionValues, projectedOriginalPartitionValues)
   }
 
   def project(
```
```diff
@@ -871,20 +876,12 @@ case class KeyGroupedShuffleSpec(
     if (results.forall(p => p.isEmpty)) None else Some(results)
   }
 
-  override def canCreatePartitioning: Boolean = {
-    // Allow one side shuffle for SPJ for now only if partially-clustered is not enabled
-    // and for join keys less than partition keys only if transforms are not enabled.
-    val checkExprType = if (SQLConf.get.v2BucketingAllowJoinKeysSubsetOfPartitionKeys) {
-      e: Expression => e.isInstanceOf[AttributeReference]
-    } else {
-      e: Expression => e.isInstanceOf[AttributeReference] || e.isInstanceOf[TransformExpression]
-    }
+  override def canCreatePartitioning: Boolean =
     SQLConf.get.v2BucketingShuffleEnabled &&
       !SQLConf.get.v2BucketingPartiallyClusteredDistributionEnabled &&
-      partitioning.expressions.forall(checkExprType)
-  }
-
-
+      partitioning.expressions.forall { e =>
+        e.isInstanceOf[AttributeReference] || e.isInstanceOf[TransformExpression]
+      }
 
   override def createPartitioning(clustering: Seq[Expression]): Partitioning = {
     val newExpressions: Seq[Expression] = clustering.zip(partitioning.expressions).map {
```
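Note the behavioral change in canCreatePartitioning: the expression-type check no longer depends on whether join keys may be a subset of the partition keys, so one-side shuffle is now also allowed for TransformExpression partition transforms; the requirement that partially-clustered distribution be disabled is unchanged.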
`BatchScanExec.scala`:

```diff
@@ -118,16 +118,29 @@ case class BatchScanExec(
 
   override def outputPartitioning: Partitioning = {
     super.outputPartitioning match {
-      case k: KeyGroupedPartitioning if spjParams.commonPartitionValues.isDefined =>
-        // We allow duplicated partition values if
-        // `spark.sql.sources.v2.bucketing.partiallyClusteredDistribution.enabled` is true
-        val newPartValues = spjParams.commonPartitionValues.get.flatMap {
-          case (partValue, numSplits) => Seq.fill(numSplits)(partValue)
-        }
+      case k: KeyGroupedPartitioning =>
         val expressions = spjParams.joinKeyPositions match {
           case Some(projectionPositions) => projectionPositions.map(i => k.expressions(i))
           case _ => k.expressions
         }
+
+        val newPartValues = spjParams.commonPartitionValues match {
+          case Some(commonPartValues) =>
+            // We allow duplicated partition values if
+            // `spark.sql.sources.v2.bucketing.partiallyClusteredDistribution.enabled` is true
+            commonPartValues.flatMap {
+              case (partValue, numSplits) => Seq.fill(numSplits)(partValue)
+            }
+          case None =>
+            spjParams.joinKeyPositions match {
+              case Some(projectionPositions) => k.partitionValues.map{r =>
+                val projectedRow = KeyGroupedPartitioning.project(expressions,
+                  projectionPositions, r)
+                InternalRowComparableWrapper(projectedRow, expressions)
+              }.distinct.map(_.row)
+              case _ => k.partitionValues
+            }
+        }
         k.copy(expressions = expressions, numPartitions = newPartValues.length,
           partitionValues = newPartValues)
       case p => p
@@ -279,7 +292,8 @@ case class StoragePartitionJoinParams(
     case other: StoragePartitionJoinParams =>
       this.commonPartitionValues == other.commonPartitionValues &&
       this.replicatePartitions == other.replicatePartitions &&
-      this.applyPartialClustering == other.applyPartialClustering
+      this.applyPartialClustering == other.applyPartialClustering &&
+      this.joinKeyPositions == other.joinKeyPositions
     case _ =>
       false
   }
```
`EnsureRequirements.scala`:

```diff
@@ -175,7 +175,16 @@ case class EnsureRequirements(
           child
         case ((child, dist), idx) =>
           if (bestSpecOpt.isDefined && bestSpecOpt.get.isCompatibleWith(specs(idx))) {
-            child
+            bestSpecOpt match {
+              // If keyGroupCompatible = false, we can still perform SPJ
+              // by shuffling the other side based on join keys (see the else case below).
+              // Hence we need to ensure that after this call, the outputPartitioning of the
+              // partitioned side's BatchScanExec is grouped by join keys to match,
+              // and we do that by pushing down the join keys
+              case Some(KeyGroupedShuffleSpec(_, _, Some(joinKeyPositions))) =>
+                populateJoinKeyPositions(child, Some(joinKeyPositions))
+              case _ => child
+            }
           } else {
             val newPartitioning = bestSpecOpt.map { bestSpec =>
               // Use the best spec to create a new partitioning to re-shuffle this child
@@ -578,6 +587,21 @@ case class EnsureRequirements(
         child, values, joinKeyPositions, reducers, applyPartialClustering, replicatePartitions))
   }
 
+
+  private def populateJoinKeyPositions(
+      plan: SparkPlan,
+      joinKeyPositions: Option[Seq[Int]]): SparkPlan = plan match {
+    case scan: BatchScanExec =>
+      scan.copy(
+        spjParams = scan.spjParams.copy(
+          joinKeyPositions = joinKeyPositions
+        )
+      )
+    case node =>
+      node.mapChildren(child => populateJoinKeyPositions(
+        child, joinKeyPositions))
+  }
+
   private def reduceCommonPartValues(
       commonPartValues: Seq[(InternalRow, Int)],
       expressions: Seq[Expression],
```
`KeyGroupedPartitioningSuite.scala`:

```diff
@@ -2168,8 +2168,7 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase {
       SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key -> "true") {
       val df = createJoinTestDF(Seq("id" -> "item_id"))
       val shuffles = collectShuffles(df.queryExecution.executedPlan)
-      assert(shuffles.size == 2, "SPJ should not be triggered for transform expression with" +
-        "less join keys than partition keys for now.")
+      assert(shuffles.size == 1, "SPJ should be triggered")
       checkAnswer(df, Seq(Row(1, "aa", 30.0, 42.0),
         Row(1, "aa", 30.0, 89.0),
         Row(1, "aa", 40.0, 42.0),
```
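The updated assertion expects exactly one shuffle: the ShuffleExchange that auto-shuffle inserts on the non-partitioned side. The partitioned side is now read with SPJ and requires no shuffle, whereas previously this scenario fell back to shuffling both sides.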
