diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
index c56a5c015f32d..866b382a1d808 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
@@ -83,7 +83,24 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
         numPartitionsSet.headOption
       }
 
-      val targetNumPartitions = requiredNumPartitions.getOrElse(childrenNumPartitions.max)
+      // If there are non-shuffle children that satisfy the required distribution, we have
+      // some tradeoffs when picking the expected number of shuffle partitions:
+      // 1. We should avoid shuffling these children.
+      // 2. We should have a reasonable parallelism.
+      val nonShuffleChildrenNumPartitions =
+        childrenIndexes.map(children).filterNot(_.isInstanceOf[ShuffleExchangeExec])
+          .map(_.outputPartitioning.numPartitions)
+      val expectedChildrenNumPartitions = if (nonShuffleChildrenNumPartitions.nonEmpty) {
+        // Here we pick the max number of partitions among these non-shuffle children as the
+        // expected number of shuffle partitions. However, if it's smaller than
+        // `conf.numShufflePartitions`, we pick `conf.numShufflePartitions` as the
+        // expected number of shuffle partitions.
+        math.max(nonShuffleChildrenNumPartitions.max, conf.numShufflePartitions)
+      } else {
+        childrenNumPartitions.max
+      }
+
+      val targetNumPartitions = requiredNumPartitions.getOrElse(expectedChildrenNumPartitions)
 
       children = children.zip(requiredChildDistributions).zipWithIndex.map {
         case ((child, distribution), index) if childrenIndexes.contains(index) =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ReduceNumShufflePartitionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ReduceNumShufflePartitionsSuite.scala
index 4d408cd8ebd70..21ec1ac9bda08 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ReduceNumShufflePartitionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ReduceNumShufflePartitionsSuite.scala
@@ -274,6 +274,7 @@ class ReduceNumShufflePartitionsSuite extends SparkFunSuite with BeforeAndAfterA
         .setMaster("local[*]")
         .setAppName("test")
         .set(UI_ENABLED, false)
+        .set(SQLConf.SHUFFLE_PARTITIONS.key, "5")
         .set(SQLConf.SHUFFLE_MAX_NUM_POSTSHUFFLE_PARTITIONS.key, "5")
         .set(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, "true")
         .set(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key, "-1")
@@ -507,7 +508,7 @@ class ReduceNumShufflePartitionsSuite extends SparkFunSuite with BeforeAndAfterA
         join,
         expectedAnswer.collect())
 
-      // Then, let's make sure we do not reduce number of ppst shuffle partitions.
+      // Then, let's make sure we do not reduce number of post shuffle partitions.
       val finalPlan = join.queryExecution.executedPlan
         .asInstanceOf[AdaptiveSparkPlanExec].executedPlan
       val shuffleReaders = finalPlan.collect {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
index 7043b6d396977..a585f215ad681 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
@@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.expressions
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
 import org.apache.spark.sql.execution.{DataSourceScanExec, SortExec}
+import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
 import org.apache.spark.sql.execution.datasources.BucketingUtils
 import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
 import org.apache.spark.sql.execution.joins.SortMergeJoinExec
@@ -382,8 +383,16 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils {
           joined.sort("bucketed_table1.k", "bucketed_table2.k"),
           df1.join(df2, joinCondition(df1, df2), joinType).sort("df1.k", "df2.k"))
 
-        assert(joined.queryExecution.executedPlan.isInstanceOf[SortMergeJoinExec])
-        val joinOperator = joined.queryExecution.executedPlan.asInstanceOf[SortMergeJoinExec]
+        val joinOperator = if (joined.sqlContext.conf.adaptiveExecutionEnabled) {
+          val executedPlan =
+            joined.queryExecution.executedPlan.asInstanceOf[AdaptiveSparkPlanExec].executedPlan
+          assert(executedPlan.isInstanceOf[SortMergeJoinExec])
+          executedPlan.asInstanceOf[SortMergeJoinExec]
+        } else {
+          val executedPlan = joined.queryExecution.executedPlan
+          assert(executedPlan.isInstanceOf[SortMergeJoinExec])
+          executedPlan.asInstanceOf[SortMergeJoinExec]
+        }
 
         // check existence of shuffle
         assert(
@@ -795,4 +804,22 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils {
     }
   }
 
+  test("SPARK-29655 Read bucketed tables obeys spark.sql.shuffle.partitions") {
+    withSQLConf(
+      SQLConf.SHUFFLE_PARTITIONS.key -> "5",
+      SQLConf.SHUFFLE_MAX_NUM_POSTSHUFFLE_PARTITIONS.key -> "7") {
+      val bucketSpec = Some(BucketSpec(6, Seq("i", "j"), Nil))
+      Seq(false, true).foreach { enableAdaptive =>
+        withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> s"$enableAdaptive") {
+          val bucketedTableTestSpecLeft = BucketedTableTestSpec(bucketSpec, expectedShuffle = false)
+          val bucketedTableTestSpecRight = BucketedTableTestSpec(None, expectedShuffle = true)
+          testBucketing(
+            bucketedTableTestSpecLeft = bucketedTableTestSpecLeft,
+            bucketedTableTestSpecRight = bucketedTableTestSpecRight,
+            joinCondition = joinCondition(Seq("i", "j"))
+          )
+        }
+      }
+    }
+  }
 }
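Standalone illustration (not part of the patch): the sketch below mirrors the partition-count selection that the new EnsureRequirements logic above encodes, i.e. prefer the largest non-shuffle child but never go below spark.sql.shuffle.partitions. The object name, method name, and sample values are hypothetical and only approximate the rule's behavior under simplified inputs.

object ExpectedPartitionsSketch {
  // Hypothetical helper: `nonShufflePartitionCounts` stands in for the output
  // partition counts of children that are not ShuffleExchangeExec,
  // `allChildrenPartitionCounts` for all co-partitioned children, and
  // `defaultShufflePartitions` for conf.numShufflePartitions.
  def expectedNumPartitions(
      nonShufflePartitionCounts: Seq[Int],
      allChildrenPartitionCounts: Seq[Int],
      defaultShufflePartitions: Int): Int = {
    if (nonShufflePartitionCounts.nonEmpty) {
      // Avoid re-shuffling the non-shuffle children, but keep a reasonable
      // parallelism by never dropping below the configured shuffle partitions.
      math.max(nonShufflePartitionCounts.max, defaultShufflePartitions)
    } else {
      // All children are shuffles anyway: align on the largest one.
      allChildrenPartitionCounts.max
    }
  }

  def main(args: Array[String]): Unit = {
    // Bucketed side with 6 buckets, spark.sql.shuffle.partitions = 5:
    // the bucketed side's 6 partitions win, so it is not re-shuffled.
    println(expectedNumPartitions(Seq(6), Seq(6, 5), 5)) // 6
    // Bucketed side with only 3 buckets: keep the configured parallelism of 5.
    println(expectedNumPartitions(Seq(3), Seq(3, 5), 5)) // 5
  }
}

This is the tradeoff the SPARK-29655 test exercises: with 6 buckets and spark.sql.shuffle.partitions set to 5, the bucketed side stays unshuffled while the other side is shuffled to 6 partitions.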