diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
index 19595eef10b34..f8e980747bf2a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -434,8 +434,13 @@ object KeyGroupedPartitioning {
     val projectedOriginalPartitionValues =
       originalPartitionValues.map(project(expressions, projectionPositions, _))
 
-    KeyGroupedPartitioning(projectedExpressions, projectedPartitionValues.length,
-      projectedPartitionValues, projectedOriginalPartitionValues)
+    val finalPartitionValues = projectedPartitionValues
+      .map(InternalRowComparableWrapper(_, projectedExpressions))
+      .distinct
+      .map(_.row)
+
+    KeyGroupedPartitioning(projectedExpressions, finalPartitionValues.length,
+      finalPartitionValues, projectedOriginalPartitionValues)
   }
 
   def project(
@@ -871,20 +876,12 @@ case class KeyGroupedShuffleSpec(
     if (results.forall(p => p.isEmpty)) None else Some(results)
   }
 
-  override def canCreatePartitioning: Boolean = {
-    // Allow one side shuffle for SPJ for now only if partially-clustered is not enabled
-    // and for join keys less than partition keys only if transforms are not enabled.
-    val checkExprType = if (SQLConf.get.v2BucketingAllowJoinKeysSubsetOfPartitionKeys) {
-      e: Expression => e.isInstanceOf[AttributeReference]
-    } else {
-      e: Expression => e.isInstanceOf[AttributeReference] || e.isInstanceOf[TransformExpression]
-    }
+  override def canCreatePartitioning: Boolean =
     SQLConf.get.v2BucketingShuffleEnabled &&
       !SQLConf.get.v2BucketingPartiallyClusteredDistributionEnabled &&
-      partitioning.expressions.forall(checkExprType)
-  }
-
-
+      partitioning.expressions.forall { e =>
+        e.isInstanceOf[AttributeReference] || e.isInstanceOf[TransformExpression]
+      }
 
   override def createPartitioning(clustering: Seq[Expression]): Partitioning = {
     val newExpressions: Seq[Expression] = clustering.zip(partitioning.expressions).map {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala
index f949dbf71a371..997576a396d20 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala
@@ -118,16 +118,29 @@ case class BatchScanExec(
 
   override def outputPartitioning: Partitioning = {
     super.outputPartitioning match {
-      case k: KeyGroupedPartitioning if spjParams.commonPartitionValues.isDefined =>
-        // We allow duplicated partition values if
-        // `spark.sql.sources.v2.bucketing.partiallyClusteredDistribution.enabled` is true
-        val newPartValues = spjParams.commonPartitionValues.get.flatMap {
-          case (partValue, numSplits) => Seq.fill(numSplits)(partValue)
-        }
+      case k: KeyGroupedPartitioning =>
         val expressions = spjParams.joinKeyPositions match {
           case Some(projectionPositions) => projectionPositions.map(i => k.expressions(i))
           case _ => k.expressions
         }
+
+        val newPartValues = spjParams.commonPartitionValues match {
+          case Some(commonPartValues) =>
+            // We allow duplicated partition values if
+            // `spark.sql.sources.v2.bucketing.partiallyClusteredDistribution.enabled` is true
+            commonPartValues.flatMap {
+              case (partValue, numSplits) => Seq.fill(numSplits)(partValue)
+            }
+          case None =>
+            spjParams.joinKeyPositions match {
+              case Some(projectionPositions) => k.partitionValues.map{r =>
+                val projectedRow = KeyGroupedPartitioning.project(expressions,
+                  projectionPositions, r)
+                InternalRowComparableWrapper(projectedRow, expressions)
+              }.distinct.map(_.row)
+              case _ => k.partitionValues
+            }
+        }
         k.copy(expressions = expressions, numPartitions = newPartValues.length,
           partitionValues = newPartValues)
       case p => p
@@ -279,7 +292,8 @@ case class StoragePartitionJoinParams(
     case other: StoragePartitionJoinParams =>
       this.commonPartitionValues == other.commonPartitionValues &&
         this.replicatePartitions == other.replicatePartitions &&
-        this.applyPartialClustering == other.applyPartialClustering
+        this.applyPartialClustering == other.applyPartialClustering &&
+        this.joinKeyPositions == other.joinKeyPositions
     case _ => false
   }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
index 67d879bdd8bf4..0470aacd4f823 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
@@ -175,7 +175,16 @@ case class EnsureRequirements(
           child
         case ((child, dist), idx) =>
           if (bestSpecOpt.isDefined && bestSpecOpt.get.isCompatibleWith(specs(idx))) {
-            child
+            bestSpecOpt match {
+              // If keyGroupCompatible = false, we can still perform SPJ
+              // by shuffling the other side based on join keys (see the else case below).
+              // Hence we need to ensure that after this call, the outputPartitioning of the
+              // partitioned side's BatchScanExec is grouped by join keys to match,
+              // and we do that by pushing down the join keys
+              case Some(KeyGroupedShuffleSpec(_, _, Some(joinKeyPositions))) =>
+                populateJoinKeyPositions(child, Some(joinKeyPositions))
+              case _ => child
+            }
           } else {
             val newPartitioning = bestSpecOpt.map { bestSpec =>
               // Use the best spec to create a new partitioning to re-shuffle this child
@@ -578,6 +587,21 @@ case class EnsureRequirements(
         child, values, joinKeyPositions, reducers, applyPartialClustering, replicatePartitions))
   }
 
+
+  private def populateJoinKeyPositions(
+      plan: SparkPlan,
+      joinKeyPositions: Option[Seq[Int]]): SparkPlan = plan match {
+    case scan: BatchScanExec =>
+      scan.copy(
+        spjParams = scan.spjParams.copy(
+          joinKeyPositions = joinKeyPositions
+        )
+      )
+    case node =>
+      node.mapChildren(child => populateJoinKeyPositions(
+        child, joinKeyPositions))
+  }
+
   private def reduceCommonPartValues(
       commonPartValues: Seq[(InternalRow, Int)],
       expressions: Seq[Expression],
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
index d77a6e8b8ac16..5e5453b4cd500 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
@@ -2168,8 +2168,7 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase {
       SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key -> "true") {
       val df = createJoinTestDF(Seq("id" -> "item_id"))
       val shuffles = collectShuffles(df.queryExecution.executedPlan)
-      assert(shuffles.size == 2, "SPJ should not be triggered for transform expression with" +
-        "less join keys than partition keys for now.")
+      assert(shuffles.size == 1, "SPJ should be triggered")
       checkAnswer(df, Seq(Row(1, "aa", 30.0, 42.0), Row(1, "aa", 30.0, 89.0), Row(1, "aa", 40.0, 42.0),