From dc7ba14d8b95599d3c2803e25d861aeada642a32 Mon Sep 17 00:00:00 2001 From: Guo Chenzhao Date: Tue, 21 Aug 2018 16:53:01 +0800 Subject: [PATCH] Support Left Anti Join in data skew feature (#62) --- .../execution/adaptive/HandleSkewedJoin.scala | 5 +- .../execution/adaptive/QueryStageSuite.scala | 149 +++++++++++++++++- 2 files changed, 151 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/HandleSkewedJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/HandleSkewedJoin.scala index ebea292f98248..f0d0bbb1823fa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/HandleSkewedJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/HandleSkewedJoin.scala @@ -29,7 +29,8 @@ import org.apache.spark.sql.internal.SQLConf case class HandleSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] { - private val supportedJoinTypes = Inner :: Cross :: LeftSemi :: LeftOuter:: RightOuter :: Nil + private val supportedJoinTypes = + Inner :: Cross :: LeftSemi :: LeftAnti :: LeftOuter :: RightOuter :: Nil private def isSizeSkewed(size: Long, medianSize: Long): Boolean = { size > medianSize * conf.adaptiveSkewedFactor && @@ -116,7 +117,7 @@ case class HandleSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] { private def supportSplitOnLeftPartition(joinType: JoinType) = joinType != RightOuter private def supportSplitOnRightPartition(joinType: JoinType) = { - joinType != LeftOuter && joinType != LeftSemi + joinType != LeftOuter && joinType != LeftSemi && joinType != LeftAnti } private def handleSkewedJoin( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/QueryStageSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/QueryStageSuite.scala index f35e235d18de8..43ce1dd384e5c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/QueryStageSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/QueryStageSuite.scala @@ -544,7 +544,7 @@ class QueryStageSuite extends SparkFunSuite with BeforeAndAfterAll { val expectedAnswerForRightOuter = spark .range(0, 100) - .flatMap(i => Seq.fill(100)(i)) + .flatMap(i => Seq.fill(100)(i)) .selectExpr("0 as key", "value") checkAnswer( rightOuterJoin, @@ -578,6 +578,153 @@ } } + test("adaptive skewed join: left semi/anti join and skewed on right side") { + val spark = defaultSparkSession + spark.conf.set(SQLConf.ADAPTIVE_EXECUTION_JOIN_ENABLED.key, "false") + spark.conf.set(SQLConf.ADAPTIVE_EXECUTION_SKEWED_JOIN_ENABLED.key, "true") + spark.conf.set(SQLConf.ADAPTIVE_EXECUTION_SKEWED_PARTITION_ROW_COUNT_THRESHOLD.key, 10) + withSparkSession(spark) { spark: SparkSession => + val df1 = + spark + .range(0, 10, 1, 2) + .selectExpr("id % 5 as key1", "id as value1") + val df2 = + spark + .range(0, 1000, 1, numInputPartitions) + .selectExpr("id % 1 as key2", "id as value2") + + val leftSemiJoin = + df1.join(df2, col("key1") === col("key2"), "left_semi").select(col("key1"), col("value1")) + val leftAntiJoin = + df1.join(df2, col("key1") === col("key2"), "left_anti").select(col("key1"), col("value1")) + + // Before Execution, there is one SortMergeJoin + val smjBeforeExecutionForLeftSemi = leftSemiJoin.queryExecution.executedPlan.collect { + case smj: SortMergeJoinExec => smj + } + assert(smjBeforeExecutionForLeftSemi.length === 1) + + val smjBeforeExecutionForLeftAnti = leftAntiJoin.queryExecution.executedPlan.collect { + case smj: SortMergeJoinExec => smj + } + assert(smjBeforeExecutionForLeftAnti.length === 1) + + // Check the answer. 
+ val expectedAnswerForLeftSemi = + spark + .range(0, 10) + .filter(_ % 5 == 0) + .selectExpr("id % 5 as key", "id as value") + checkAnswer( + leftSemiJoin, + expectedAnswerForLeftSemi.collect()) + + val expectedAnswerForLeftAnti = + spark + .range(0, 10) + .filter(_ % 5 != 0) + .selectExpr("id % 5 as key", "id as value") + checkAnswer( + leftAntiJoin, + expectedAnswerForLeftAnti.collect()) + + // For the left semi join case: during execution, the SMJ can not be translated to any sub + // joins due to the skewed side is on the right but the join type is left semi + // (not correspond with each other) + val smjAfterExecutionForLeftSemi = leftSemiJoin.queryExecution.executedPlan.collect { + case smj: SortMergeJoinExec => smj + } + assert(smjAfterExecutionForLeftSemi.length === 1) + + // For the left anti join case: during execution, the SMJ can not be translated to any sub + // joins due to the skewed side is on the right but the join type is left anti + // (not correspond with each other) + val smjAfterExecutionForLeftAnti = leftAntiJoin.queryExecution.executedPlan.collect { + case smj: SortMergeJoinExec => smj + } + assert(smjAfterExecutionForLeftAnti.length === 1) + + } + } + + test("adaptive skewed join: left semi/anti join and skewed on left side") { + val spark = defaultSparkSession + spark.conf.set(SQLConf.ADAPTIVE_EXECUTION_JOIN_ENABLED.key, "false") + spark.conf.set(SQLConf.ADAPTIVE_EXECUTION_SKEWED_JOIN_ENABLED.key, "true") + spark.conf.set(SQLConf.ADAPTIVE_EXECUTION_SKEWED_PARTITION_ROW_COUNT_THRESHOLD.key, 10) + val MAX_SPLIT = 5 + spark.conf.set(SQLConf.ADAPTIVE_EXECUTION_SKEWED_PARTITION_MAX_SPLITS.key, MAX_SPLIT) + withSparkSession(spark) { spark: SparkSession => + val df1 = + spark + .range(0, 1000, 1, numInputPartitions) + .selectExpr("id % 1 as key1", "id as value1") + val df2 = + spark + .range(0, 10, 1, 2) + .selectExpr("id % 5 as key2", "id as value2") + + val leftSemiJoin = + df1.join(df2, col("key1") === col("key2"), 
"left_semi").select(col("key1"), col("value1")) + val leftAntiJoin = + df1.join(df2, col("key1") === col("key2"), "left_anti").select(col("key1"), col("value1")) + + // Before Execution, there is one SortMergeJoin + val smjBeforeExecutionForLeftSemi = leftSemiJoin.queryExecution.executedPlan.collect { + case smj: SortMergeJoinExec => smj + } + assert(smjBeforeExecutionForLeftSemi.length === 1) + + val smjBeforeExecutionForLeftAnti = leftAntiJoin.queryExecution.executedPlan.collect { + case smj: SortMergeJoinExec => smj + } + assert(smjBeforeExecutionForLeftAnti.length === 1) + + // Check the answer. + val expectedAnswerForLeftSemi = + spark + .range(0, 1000) + .selectExpr("id % 1 as key", "id as value") + checkAnswer( + leftSemiJoin, + expectedAnswerForLeftSemi.collect()) + + val expectedAnswerForLeftAnti = Seq.empty + checkAnswer( + leftAntiJoin, + expectedAnswerForLeftAnti) + + // For the left semi join case: during execution, the SMJ is changed to Union of SMJ + 5 SMJ + // joins due to the skewed side is on the left and the join type is left semi + // (correspond with each other) + val smjAfterExecutionForLeftSemi = leftSemiJoin.queryExecution.executedPlan.collect { + case smj: SortMergeJoinExec => smj + } + assert(smjAfterExecutionForLeftSemi.length === MAX_SPLIT + 1) + + // For the left anti join case: during execution, the SMJ is changed to Union of SMJ + 5 SMJ + // joins due to the skewed side is on the left and the join type is left anti + // (correspond with each other) + val smjAfterExecutionForLeftAnti = leftAntiJoin.queryExecution.executedPlan.collect { + case smj: SortMergeJoinExec => smj + } + assert(smjAfterExecutionForLeftAnti.length === MAX_SPLIT + 1) + + val queryStageInputs = leftSemiJoin.queryExecution.executedPlan.collect { + case q: ShuffleQueryStageInput => q + } + assert(queryStageInputs.length === 2) + assert(queryStageInputs(0).skewedPartitions === queryStageInputs(1).skewedPartitions) + assert(queryStageInputs(0).skewedPartitions === 
Some(Set(0))) + + val skewedQueryStageInputs = leftSemiJoin.queryExecution.executedPlan.collect { + case q: SkewedShuffleQueryStageInput => q + } + assert(skewedQueryStageInputs.length === MAX_SPLIT * 2) + + } + } + test("row count statistics, compressed") { val spark = defaultSparkSession withSparkSession(spark) { spark: SparkSession =>