From b80ee039082bd6b06a92b5af8c476285910421f0 Mon Sep 17 00:00:00 2001 From: Szehon Ho Date: Sun, 14 Jul 2024 16:53:32 -0700 Subject: [PATCH] [SPARK-48613][SQL] SPJ: Support auto-shuffle one side + less join keys than partition keys ### What changes were proposed in this pull request? This is the final planned SPJ scenario: auto-shuffle one side + less join keys than partition keys. Background: - Auto-shuffle works by creating ShuffleExchange for the non-partitioned side, with a clone of the partitioned side's KeyGroupedPartitioning. - "Less join key than partition key" works by 'projecting' all partition values by join keys (ie, keeping only partition columns that are join columns). It makes a target KeyGroupedShuffleSpec with 'projected' partition values, and then pushes this down to BatchScanExec. The BatchScanExec then 'groups' its projected partition value (except in the skew case but that's a different story..). This combination is hard because the SPJ planning calls is spread in several places in this scenario. Given two sides, a non-partitioned side and a partitioned side, and the join keys are only a subset: 1. EnsureRequirements creates the target KeyGroupedShuffleSpec from the join's required distribution (ie, using only the join keys, not all partition keys). 2. EnsureRequirements copies this to the non-partitoned side's KeyGroupedPartition (for the auto-shuffle case) 3. BatchScanExec groups the partitions (for the partitioned side), including by join keys (if they differ from partition keys). Take the example partition columns (id, name) , and partition values: (1, "bob"), (2, "alice"), (2, "sam"). Projection leaves us (1, 2, 2), and the final grouped partition values are (1, 2). The problem is, that the two sides of the join do not match at all times. After the steps 1 and 2, the partitioned side has the 'projected' partition values (1, 2, 2), and the non-partitioned side creates a matching KeyGroupedPartitioning (1, 2, 2) for ShuffleExechange. But on step 3, the BatchScanExec for partitioned side groups the partitions to become (1, 2), but the non-partitioned side does not group and still retains (1, 2, 2) partitions. This leads to following assert error from the join: ``` requirement failed: PartitioningCollection requires all of its partitionings have the same numPartitions. java.lang.IllegalArgumentException: requirement failed: PartitioningCollection requires all of its partitionings have the same numPartitions. at scala.Predef$.require(Predef.scala:337) at org.apache.spark.sql.catalyst.plans.physical.PartitioningCollection.(partitioning.scala:550) at org.apache.spark.sql.execution.joins.ShuffledJoin.outputPartitioning(ShuffledJoin.scala:49) at org.apache.spark.sql.execution.joins.ShuffledJoin.outputPartitioning$(ShuffledJoin.scala:47) at org.apache.spark.sql.execution.joins.SortMergeJoinExec.outputPartitioning(SortMergeJoinExec.scala:39) at org.apache.spark.sql.execution.exchange.EnsureRequirements.$anonfun$ensureDistributionAndOrdering$1(EnsureRequirements.scala:66) at scala.collection.immutable.Vector1.map(Vector.scala:2140) at scala.collection.immutable.Vector1.map(Vector.scala:385) at org.apache.spark.sql.execution.exchange.EnsureRequirements.org$apache$spark$sql$execution$exchange$EnsureRequirements$$ensureDistributionAndOrdering(EnsureRequirements.scala:65) at org.apache.spark.sql.execution.exchange.EnsureRequirements$$anonfun$1.applyOrElse(EnsureRequirements.scala:657) at org.apache.spark.sql.execution.exchange.EnsureRequirements$$anonfun$1.applyOrElse(EnsureRequirements.scala:632) ``` The fix is to do the de-duplication in first pass. 1. Pushing down join keys to the BatchScanExec to return a de-duped outputPartitioning (partitioned side) 2. Creating the non-partitioned side's KeyGroupedPartitioning with de-duped partition keys (non-partitioned side). ### Why are the changes needed? This is the last planned scenario for SPJ not yet supported. ### How was this patch tested? Update existing unit test in KeyGroupedPartitionSuite ### Was this patch authored or co-authored using generative AI tooling? No. Closes #47064 from szehon-ho/spj_less_join_key_auto_shuffle. Authored-by: Szehon Ho Signed-off-by: Chao Sun --- .../plans/physical/partitioning.scala | 25 ++++++++--------- .../datasources/v2/BatchScanExec.scala | 28 ++++++++++++++----- .../exchange/EnsureRequirements.scala | 26 ++++++++++++++++- .../KeyGroupedPartitioningSuite.scala | 3 +- 4 files changed, 58 insertions(+), 24 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 19595eef10b34..f8e980747bf2a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -434,8 +434,13 @@ object KeyGroupedPartitioning { val projectedOriginalPartitionValues = originalPartitionValues.map(project(expressions, projectionPositions, _)) - KeyGroupedPartitioning(projectedExpressions, projectedPartitionValues.length, - projectedPartitionValues, projectedOriginalPartitionValues) + val finalPartitionValues = projectedPartitionValues + .map(InternalRowComparableWrapper(_, projectedExpressions)) + .distinct + .map(_.row) + + KeyGroupedPartitioning(projectedExpressions, finalPartitionValues.length, + finalPartitionValues, projectedOriginalPartitionValues) } def project( @@ -871,20 +876,12 @@ case class KeyGroupedShuffleSpec( if (results.forall(p => p.isEmpty)) None else Some(results) } - override def canCreatePartitioning: Boolean = { - // Allow one side shuffle for SPJ for now only if partially-clustered is not enabled - // and for join keys less than partition keys only if transforms are not enabled. - val checkExprType = if (SQLConf.get.v2BucketingAllowJoinKeysSubsetOfPartitionKeys) { - e: Expression => e.isInstanceOf[AttributeReference] - } else { - e: Expression => e.isInstanceOf[AttributeReference] || e.isInstanceOf[TransformExpression] - } + override def canCreatePartitioning: Boolean = SQLConf.get.v2BucketingShuffleEnabled && !SQLConf.get.v2BucketingPartiallyClusteredDistributionEnabled && - partitioning.expressions.forall(checkExprType) - } - - + partitioning.expressions.forall { e => + e.isInstanceOf[AttributeReference] || e.isInstanceOf[TransformExpression] + } override def createPartitioning(clustering: Seq[Expression]): Partitioning = { val newExpressions: Seq[Expression] = clustering.zip(partitioning.expressions).map { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala index f949dbf71a371..997576a396d20 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala @@ -118,16 +118,29 @@ case class BatchScanExec( override def outputPartitioning: Partitioning = { super.outputPartitioning match { - case k: KeyGroupedPartitioning if spjParams.commonPartitionValues.isDefined => - // We allow duplicated partition values if - // `spark.sql.sources.v2.bucketing.partiallyClusteredDistribution.enabled` is true - val newPartValues = spjParams.commonPartitionValues.get.flatMap { - case (partValue, numSplits) => Seq.fill(numSplits)(partValue) - } + case k: KeyGroupedPartitioning => val expressions = spjParams.joinKeyPositions match { case Some(projectionPositions) => projectionPositions.map(i => k.expressions(i)) case _ => k.expressions } + + val newPartValues = spjParams.commonPartitionValues match { + case Some(commonPartValues) => + // We allow duplicated partition values if + // `spark.sql.sources.v2.bucketing.partiallyClusteredDistribution.enabled` is true + commonPartValues.flatMap { + case (partValue, numSplits) => Seq.fill(numSplits)(partValue) + } + case None => + spjParams.joinKeyPositions match { + case Some(projectionPositions) => k.partitionValues.map{r => + val projectedRow = KeyGroupedPartitioning.project(expressions, + projectionPositions, r) + InternalRowComparableWrapper(projectedRow, expressions) + }.distinct.map(_.row) + case _ => k.partitionValues + } + } k.copy(expressions = expressions, numPartitions = newPartValues.length, partitionValues = newPartValues) case p => p @@ -279,7 +292,8 @@ case class StoragePartitionJoinParams( case other: StoragePartitionJoinParams => this.commonPartitionValues == other.commonPartitionValues && this.replicatePartitions == other.replicatePartitions && - this.applyPartialClustering == other.applyPartialClustering + this.applyPartialClustering == other.applyPartialClustering && + this.joinKeyPositions == other.joinKeyPositions case _ => false } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index 67d879bdd8bf4..0470aacd4f823 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -175,7 +175,16 @@ case class EnsureRequirements( child case ((child, dist), idx) => if (bestSpecOpt.isDefined && bestSpecOpt.get.isCompatibleWith(specs(idx))) { - child + bestSpecOpt match { + // If keyGroupCompatible = false, we can still perform SPJ + // by shuffling the other side based on join keys (see the else case below). + // Hence we need to ensure that after this call, the outputPartitioning of the + // partitioned side's BatchScanExec is grouped by join keys to match, + // and we do that by pushing down the join keys + case Some(KeyGroupedShuffleSpec(_, _, Some(joinKeyPositions))) => + populateJoinKeyPositions(child, Some(joinKeyPositions)) + case _ => child + } } else { val newPartitioning = bestSpecOpt.map { bestSpec => // Use the best spec to create a new partitioning to re-shuffle this child @@ -578,6 +587,21 @@ case class EnsureRequirements( child, values, joinKeyPositions, reducers, applyPartialClustering, replicatePartitions)) } + + private def populateJoinKeyPositions( + plan: SparkPlan, + joinKeyPositions: Option[Seq[Int]]): SparkPlan = plan match { + case scan: BatchScanExec => + scan.copy( + spjParams = scan.spjParams.copy( + joinKeyPositions = joinKeyPositions + ) + ) + case node => + node.mapChildren(child => populateJoinKeyPositions( + child, joinKeyPositions)) + } + private def reduceCommonPartValues( commonPartValues: Seq[(InternalRow, Int)], expressions: Seq[Expression], diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala index d77a6e8b8ac16..5e5453b4cd500 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala @@ -2168,8 +2168,7 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key -> "true") { val df = createJoinTestDF(Seq("id" -> "item_id")) val shuffles = collectShuffles(df.queryExecution.executedPlan) - assert(shuffles.size == 2, "SPJ should not be triggered for transform expression with" + - "less join keys than partition keys for now.") + assert(shuffles.size == 1, "SPJ should be triggered") checkAnswer(df, Seq(Row(1, "aa", 30.0, 42.0), Row(1, "aa", 30.0, 89.0), Row(1, "aa", 40.0, 42.0),