Read bucketed tables obeys numShufflePartitions
wangyum committed Nov 12, 2019
1 parent 45109c7 commit 73a4943
Showing 4 changed files with 31 additions and 87 deletions.
SQLConf.scala
@@ -349,14 +349,6 @@ object SQLConf {
     .checkValue(_ > 0, "The value of spark.sql.shuffle.partitions must be positive")
     .createWithDefault(200)
 
-  val SHUFFLE_WITHOUT_SHUFFLE_SIDE_RATIO =
-    buildConf("spark.sql.shuffle.withoutShuffleSideRatio")
-      .doc("The maximum number of without shuffle partition ratio lower than this config " +
-        "will not add shuffle exchange for it.")
-      .doubleConf
-      .checkValue(ratio => ratio > 0 && ratio <= 1, "The ratio value must be in [0, 1].")
-      .createWithDefault(1.0)
-
   val ADAPTIVE_EXECUTION_ENABLED = buildConf("spark.sql.adaptive.enabled")
     .doc("When true, enable adaptive query execution.")
     .booleanConf
@@ -2170,8 +2162,6 @@ class SQLConf extends Serializable with Logging {
 
   def numShufflePartitions: Int = getConf(SHUFFLE_PARTITIONS)
 
-  def withoutShuffleSideRatio: Double = getConf(SHUFFLE_WITHOUT_SHUFFLE_SIDE_RATIO)
-
   def adaptiveExecutionEnabled: Boolean = getConf(ADAPTIVE_EXECUTION_ENABLED)
 
   def targetPostShuffleInputSize: Long = getConf(SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE)
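With the ratio knob removed, this code path is governed only by spark.sql.shuffle.partitions. A minimal spark-shell sketch, not part of the diff; the pre-built `spark` session is assumed:

// Only spark.sql.shuffle.partitions remains; spark.sql.shuffle.withoutShuffleSideRatio
// no longer exists after this commit.
spark.conf.set("spark.sql.shuffle.partitions", "5")
assert(spark.conf.get("spark.sql.shuffle.partitions") == "5")
// SQL equivalent:
spark.sql("SET spark.sql.shuffle.partitions=5")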
EnsureRequirements.scala
@@ -83,19 +83,13 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
         numPartitionsSet.headOption
       }
 
-      val maxNumPartition = childrenNumPartitions.max
-      val withoutShuffleChildrenNumPartitions =
+      val nonShuffleChildrenNumPartitions =
         childrenIndexes.filterNot(children(_).isInstanceOf[ShuffleExchangeExec])
           .map(children(_).outputPartitioning.numPartitions).toSet
-      val expectedChildrenNumPartitions = if (withoutShuffleChildrenNumPartitions.nonEmpty) {
-        val withoutShuffleMaxNumPartition = withoutShuffleChildrenNumPartitions.max
-        if (withoutShuffleMaxNumPartition * 1.0 / maxNumPartition >= conf.withoutShuffleSideRatio) {
-          withoutShuffleMaxNumPartition
-        } else {
-          maxNumPartition
-        }
+      val expectedChildrenNumPartitions = if (nonShuffleChildrenNumPartitions.nonEmpty) {
+        math.max(nonShuffleChildrenNumPartitions.max, conf.numShufflePartitions)
       } else {
-        maxNumPartition
+        childrenNumPartitions.max
       }
 
       val targetNumPartitions = requiredNumPartitions.getOrElse(expectedChildrenNumPartitions)
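To make the new rule concrete, here is a standalone sketch of the decision EnsureRequirements now makes, with plain Ints standing in for child plans. The object and helper names (ExpectedPartitionsSketch, expectedChildrenNumPartitions) are illustrative, not the Spark source:

// Hypothetical helper mirroring the logic in the hunk above.
object ExpectedPartitionsSketch {
  def expectedChildrenNumPartitions(
      childrenNumPartitions: Seq[Int],    // partition counts of all children needing the distribution
      nonShufflePartitions: Set[Int],     // counts of children that are not ShuffleExchangeExec
      numShufflePartitions: Int): Int = { // spark.sql.shuffle.partitions
    if (nonShufflePartitions.nonEmpty) {
      // Use at least numShufflePartitions, so a small bucketed (non-shuffle) side
      // no longer drags the whole join down to its bucket count.
      math.max(nonShufflePartitions.max, numShufflePartitions)
    } else {
      childrenNumPartitions.max
    }
  }

  def main(args: Array[String]): Unit = {
    // 4-bucket table joined with a shuffled side, spark.sql.shuffle.partitions = 5:
    // the target is 5, so the bucketed side is re-shuffled too.
    println(expectedChildrenNumPartitions(Seq(4, 5), Set(4), 5)) // 5
    // 6-bucket table, spark.sql.shuffle.partitions = 5: 6 wins, no extra shuffle on that side.
    println(expectedChildrenNumPartitions(Seq(6, 5), Set(6), 5)) // 6
  }
}

Under these assumptions, the second case mirrors the new BucketedReadSuite test below: with 6 buckets and 5 shuffle partitions, the bucketed side keeps its 6 partitions and only the non-bucketed side is shuffled.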
AdaptiveQueryExecSuite.scala
@@ -17,9 +17,9 @@
 
 package org.apache.spark.sql.execution.adaptive
 
-import org.apache.spark.sql.{DataFrame, QueryTest}
+import org.apache.spark.sql.QueryTest
 import org.apache.spark.sql.execution.{ReusedSubqueryExec, SparkPlan}
-import org.apache.spark.sql.execution.exchange.{Exchange, ShuffleExchangeExec}
+import org.apache.spark.sql.execution.exchange.Exchange
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BuildRight, SortMergeJoinExec}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
@@ -474,32 +474,4 @@ class AdaptiveQueryExecSuite
       }
     }
   }
-
-  test("Enable adaptive execution should not add ShuffleExchange") {
-    def findTopLevelShuffleExchangeExec(df: DataFrame): Seq[ShuffleExchangeExec] = {
-      collect(df.queryExecution.executedPlan) {
-        case s: ShuffleExchangeExec => s
-      }
-    }
-
-    val bucketedTableName = "bucketed_table"
-    withTable(bucketedTableName) {
-      withSQLConf(
-        SQLConf.SHUFFLE_PARTITIONS.key -> "4",
-        SQLConf.SHUFFLE_MAX_NUM_POSTSHUFFLE_PARTITIONS.key -> "5",
-        SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
-        spark.range(10).write.bucketBy(4, "id").sortBy("id").saveAsTable(bucketedTableName)
-        val bucketedTable = spark.table(bucketedTableName)
-
-        Seq(false, true).foreach { isAdaptive =>
-          withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> s"$isAdaptive") {
-            assert(
-              findTopLevelShuffleExchangeExec(bucketedTable.join(spark.range(8), "id")).size === 1)
-            assert(
-              findTopLevelShuffleExchangeExec(bucketedTable.join(bucketedTable, "id")).size === 0)
-          }
-        }
-      }
-    }
-  }
 }
BucketedReadSuite.scala
@@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.expressions
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
 import org.apache.spark.sql.execution.{DataSourceScanExec, SortExec}
+import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
 import org.apache.spark.sql.execution.datasources.BucketingUtils
 import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
 import org.apache.spark.sql.execution.joins.SortMergeJoinExec
@@ -382,8 +383,16 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils {
           joined.sort("bucketed_table1.k", "bucketed_table2.k"),
           df1.join(df2, joinCondition(df1, df2), joinType).sort("df1.k", "df2.k"))
 
-        assert(joined.queryExecution.executedPlan.isInstanceOf[SortMergeJoinExec])
-        val joinOperator = joined.queryExecution.executedPlan.asInstanceOf[SortMergeJoinExec]
+        val joinOperator = if (joined.sqlContext.conf.adaptiveExecutionEnabled) {
+          val executedPlan =
+            joined.queryExecution.executedPlan.asInstanceOf[AdaptiveSparkPlanExec].executedPlan
+          assert(executedPlan.isInstanceOf[SortMergeJoinExec])
+          executedPlan.asInstanceOf[SortMergeJoinExec]
+        } else {
+          val executedPlan = joined.queryExecution.executedPlan
+          assert(executedPlan.isInstanceOf[SortMergeJoinExec])
+          executedPlan.asInstanceOf[SortMergeJoinExec]
+        }
 
         // check existence of shuffle
         assert(
@@ -795,43 +804,22 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils {
     }
   }
 
-  test("Support spark.sql.shuffle.withoutShuffleSideRatio") {
-    // numBuckets >= spark.sql.shuffle.partitions
-    withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "5") {
-      val bucketSpec = Some(BucketSpec(6, Seq("i", "j"), Nil))
-      val bucketedTableTestSpecLeft = BucketedTableTestSpec(bucketSpec, expectedShuffle = false)
-      val bucketedTableTestSpecRight = BucketedTableTestSpec(None, expectedShuffle = true)
-      testBucketing(
-        bucketedTableTestSpecLeft = bucketedTableTestSpecLeft,
-        bucketedTableTestSpecRight = bucketedTableTestSpecRight,
-        joinCondition = joinCondition(Seq("i", "j"))
-      )
-    }
-
-    // numBuckets < spark.sql.shuffle.partitions
-    withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "5") {
-      val bucketSpec = Some(BucketSpec(4, Seq("i", "j"), Nil))
-      val bucketedTableTestSpecLeft = BucketedTableTestSpec(bucketSpec, expectedShuffle = true)
-      val bucketedTableTestSpecRight = BucketedTableTestSpec(None, expectedShuffle = true)
-      testBucketing(
-        bucketedTableTestSpecLeft = bucketedTableTestSpecLeft,
-        bucketedTableTestSpecRight = bucketedTableTestSpecRight,
-        joinCondition = joinCondition(Seq("i", "j"))
-      )
-    }
-
-    // numBuckets < spark.sql.shuffle.partitions and withoutShuffleSideRatio = 0.1
+  test("Read bucketed tables obeys numShufflePartitions") {
     withSQLConf(
       SQLConf.SHUFFLE_PARTITIONS.key -> "5",
-      SQLConf.SHUFFLE_WITHOUT_SHUFFLE_SIDE_RATIO.key -> "0.1") {
-      val bucketSpec = Some(BucketSpec(4, Seq("i", "j"), Nil))
-      val bucketedTableTestSpecLeft = BucketedTableTestSpec(bucketSpec, expectedShuffle = false)
-      val bucketedTableTestSpecRight = BucketedTableTestSpec(None, expectedShuffle = true)
-      testBucketing(
-        bucketedTableTestSpecLeft = bucketedTableTestSpecLeft,
-        bucketedTableTestSpecRight = bucketedTableTestSpecRight,
-        joinCondition = joinCondition(Seq("i", "j"))
-      )
+      SQLConf.SHUFFLE_MAX_NUM_POSTSHUFFLE_PARTITIONS.key -> "7") {
+      val bucketSpec = Some(BucketSpec(6, Seq("i", "j"), Nil))
+      Seq(false, true).foreach { enableAdaptive =>
+        withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> s"$enableAdaptive") {
+          val bucketedTableTestSpecLeft = BucketedTableTestSpec(bucketSpec, expectedShuffle = false)
+          val bucketedTableTestSpecRight = BucketedTableTestSpec(None, expectedShuffle = true)
+          testBucketing(
+            bucketedTableTestSpecLeft = bucketedTableTestSpecLeft,
+            bucketedTableTestSpecRight = bucketedTableTestSpecRight,
+            joinCondition = joinCondition(Seq("i", "j"))
+          )
+        }
+      }
     }
   }
 }
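For context, a hedged end-to-end sketch of the behavior the new tests exercise. It assumes a local Spark build that includes this commit; the object name and the table name bucketed_table are illustrative only:

import org.apache.spark.sql.SparkSession

object BucketedJoinPartitionsExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("bucketed-join-partitions")
      .getOrCreate()

    // Fewer buckets (4) than spark.sql.shuffle.partitions (5).
    spark.conf.set("spark.sql.shuffle.partitions", "5")
    spark.sql("DROP TABLE IF EXISTS bucketed_table")
    spark.range(10).write.bucketBy(4, "id").sortBy("id").saveAsTable("bucketed_table")

    val joined = spark.table("bucketed_table").join(spark.range(8), "id")
    // Per the new EnsureRequirements rule above, the target partition count is
    // max(4 buckets, 5 shuffle partitions) = 5, so the bucketed side is expected
    // to be re-shuffled to 5 partitions instead of pulling the join down to 4.
    // Inspect the plan to confirm:
    joined.explain()

    spark.stop()
  }
}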
