Skip to content

Commit

Permalink
[SPARK-24495][SQL] EnsureRequirement returns wrong plan when reorderi…
Browse files Browse the repository at this point in the history
…ng equal keys

`EnsureRequirements`, in its `reorder` method, currently assumes that each key appears only once in the join condition. This is not always the case, and when the assumption is violated, the rule returns a wrong plan, which makes the query produce a wrong result.

Added unit tests.

Author: Marco Gaido <[email protected]>

Closes #21529 from mgaido91/SPARK-24495.

(cherry picked from commit fdadc4b)
Signed-off-by: Xiao Li <[email protected]>
  • Loading branch information
mgaido91 authored and gatorsmile committed Jun 14, 2018
1 parent a2f65eb commit e6bf325
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.apache.spark.sql.execution.exchange

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.catalyst.expressions._
Expand Down Expand Up @@ -227,9 +228,16 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
currentOrderOfKeys: Seq[Expression]): (Seq[Expression], Seq[Expression]) = {
val leftKeysBuffer = ArrayBuffer[Expression]()
val rightKeysBuffer = ArrayBuffer[Expression]()
val pickedIndexes = mutable.Set[Int]()
val keysAndIndexes = currentOrderOfKeys.zipWithIndex

expectedOrderOfKeys.foreach(expression => {
val index = currentOrderOfKeys.indexWhere(e => e.semanticEquals(expression))
val index = keysAndIndexes.find { case (e, idx) =>
// As the same key may be used many times, we need to skip the occurrences of it that we
// have already picked.
e.semanticEquals(expression) && !pickedIndexes.contains(idx)
}.map(_._2).get
pickedIndexes += index
leftKeysBuffer.append(leftKeys(index))
rightKeysBuffer.append(rightKeys(index))
})
Expand Down Expand Up @@ -270,7 +278,7 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
* partitioning of the join nodes' children.
*/
private def reorderJoinPredicates(plan: SparkPlan): SparkPlan = {
plan.transformUp {
plan match {
case BroadcastHashJoinExec(leftKeys, rightKeys, joinType, buildSide, condition, left,
right) =>
val (reorderedLeftKeys, reorderedRightKeys) =
Expand All @@ -288,6 +296,8 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
val (reorderedLeftKeys, reorderedRightKeys) =
reorderJoinKeys(leftKeys, rightKeys, left.outputPartitioning, right.outputPartitioning)
SortMergeJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, condition, left, right)

case other => other
}
}

Expand Down
11 changes: 11 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -882,4 +882,15 @@ class JoinSuite extends QueryTest with SharedSQLContext {
checkAnswer(df, Row(3, 8, 7, 2) :: Row(3, 8, 4, 2) :: Nil)
}
}

// Regression test for SPARK-24495: the join condition compares `id` with both `b1` and `b2`.
// Since `b2` is the negation of `b1`, only the row with id 0 can satisfy both equalities.
test("SPARK-24495: Join may return wrong result when having duplicated equal-join keys") {
  // SQL settings that are in effect only while the body below runs.
  val settings = Seq(
    SQLConf.SHUFFLE_PARTITIONS.key -> "1",
    SQLConf.CONSTRAINT_PROPAGATION_ENABLED.key -> "false",
    SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1")

  withSQLConf(settings: _*) {
    val ids = spark.range(0, 100, 1, 2)
    val idsWithNegation = spark.range(100).select($"id".as("b1"), (- $"id").as("b2"))
    // The same left-side column takes part in two equality predicates.
    val joinCondition = $"id" === $"b1" && $"id" === $"b2"
    val joined = ids.join(idsWithNegation, joinCondition).select($"b1", $"b2", $"id")
    checkAnswer(joined, Row(0, 0, 0))
  }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -629,6 +629,23 @@ class PlannerSuite extends SharedSQLContext {
requiredOrdering = Seq(orderingA, orderingB),
shouldHaveSort = true)
}

// Regression test for SPARK-24495: the left side of the join uses `exprA` in both key
// positions, paired with two different right-side keys (`exprB` and `exprC`). After applying
// EnsureRequirements, the join is expected to still carry the same key pairs in the same order.
test("SPARK-24495: EnsureRequirements can return wrong plan when reusing the same key in join") {
  val plan1 = DummySparkPlan(outputOrdering = Seq(orderingA),
    outputPartitioning = HashPartitioning(exprA :: exprA :: Nil, 5))
  val plan2 = DummySparkPlan(outputOrdering = Seq(orderingB),
    outputPartitioning = HashPartitioning(exprB :: Nil, 5))
  val smjExec = SortMergeJoinExec(
    exprA :: exprA :: Nil, exprB :: exprC :: Nil, Inner, None, plan1, plan2)

  val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(smjExec)
  outputPlan match {
    case SortMergeJoinExec(leftKeys, rightKeys, _, _, _, _) =>
      assert(leftKeys == Seq(exprA, exprA))
      assert(rightKeys == Seq(exprB, exprC))
    case other =>
      // Report the plan we actually got so a failure here can be diagnosed from the test log.
      fail(s"Expected a SortMergeJoinExec at the root of the plan, but got:\n$other")
  }
}
}

// Used for unit-testing EnsureRequirements
Expand Down

0 comments on commit e6bf325

Please sign in to comment.