diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
index e2e07c1b804a3..842cd1b7d58fb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
@@ -145,10 +145,10 @@ private[sql] class SQLConf extends Serializable {
     getConf(AUTO_BROADCASTJOIN_THRESHOLD, (10 * 1024 * 1024).toString).toInt
 
   /**
-   * By default it will choose sort merge join.
+   * By default it will not choose sort merge join.
    */
   private[spark] def autoSortMergeJoin: Boolean =
-    getConf(AUTO_SORTMERGEJOIN, true.toString).toBoolean
+    getConf(AUTO_SORTMERGEJOIN, false.toString).toBoolean
 
   /**
    * The default size in bytes to assign to a logical operator's estimation statistics. By default,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
index ddddd8c9a14ca..32689f5ca2591 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
@@ -140,14 +140,11 @@ case class SortMergeJoin(
                 stop = ordering.compare(leftKey, rightKey) == 0
               }
              if (rightMatches.size > 0) {
-              stop = false
               leftMatches = new CompactBuffer[Row]()
               val leftMatch = leftKey.copy()
-              while (!stop && leftElement != null) {
+              while (ordering.compare(leftKey, leftMatch) == 0 && leftElement != null) {
                 leftMatches += leftElement
                 fetchLeft()
-                // exit loop when run out of left matches
-                stop = ordering.compare(leftKey, leftMatch) != 0
               }
             }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index bba2f223c55dc..a2c9778883389 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -51,6 +51,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
       case j: CartesianProduct => j
       case j: BroadcastNestedLoopJoin => j
       case j: BroadcastLeftSemiJoinHash => j
+      case j: ShuffledHashJoin => j
       case j: SortMergeJoin => j
     }
 
@@ -63,6 +64,8 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
 
   test("join operator selection") {
     cacheManager.clearCache()
+    val AUTO_SORTMERGEJOIN: Boolean = conf.autoSortMergeJoin
+    conf.setConf("spark.sql.autoSortMergeJoin", "false")
     Seq(
       ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash]),
       ("SELECT * FROM testData LEFT SEMI JOIN testData2", classOf[LeftSemiJoinBNL]),
@@ -76,9 +79,9 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
       ("SELECT * FROM testData FULL OUTER JOIN testData2 WHERE key = 2", classOf[CartesianProduct]),
       ("SELECT * FROM testData JOIN testData2 WHERE key > a", classOf[CartesianProduct]),
       ("SELECT * FROM testData FULL OUTER JOIN testData2 WHERE key > a", classOf[CartesianProduct]),
-      ("SELECT * FROM testData JOIN testData2 ON key = a", classOf[SortMergeJoin]),
-      ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", classOf[SortMergeJoin]),
-      ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[SortMergeJoin]),
+      ("SELECT * FROM testData JOIN testData2 ON key = a", classOf[ShuffledHashJoin]),
+      ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", classOf[ShuffledHashJoin]),
+      ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[ShuffledHashJoin]),
       ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[HashOuterJoin]),
       ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2",
         classOf[HashOuterJoin]),
@@ -92,6 +95,13 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
       ("SELECT * FROM testData full JOIN testData2 ON (key * a != key + a)",
         classOf[BroadcastNestedLoopJoin])
     ).foreach { case (query, joinClass) => assertJoin(query, joinClass) }
+    conf.setConf("spark.sql.autoSortMergeJoin", "true")
+    Seq(
+      ("SELECT * FROM testData JOIN testData2 ON key = a", classOf[SortMergeJoin]),
+      ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", classOf[SortMergeJoin]),
+      ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[SortMergeJoin])
+    ).foreach { case (query, joinClass) => assertJoin(query, joinClass) }
+    conf.setConf("spark.sql.autoSortMergeJoin", AUTO_SORTMERGEJOIN.toString)
   }
 
   test("broadcasted hash join operator selection") {
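
Illustrative sketch (not part of the patch): a minimal, standalone Scala analogue of the loop shape the SortMergeJoin hunk switches to. The run of left rows whose key equals the buffered match key is collected by putting the key comparison directly in the while condition, rather than updating a separate mutable stop flag on each iteration. All names here (MergeMatchSketch, the tuple data) are hypothetical and stand in for Spark's Row / CompactBuffer machinery.

// Sketch under simplifying assumptions: (key, value) pairs stand in for Rows,
// and the input is already sorted by key, as both sides of a sort merge join are.
object MergeMatchSketch {
  def main(args: Array[String]): Unit = {
    val left = Seq((1, "a"), (2, "b"), (2, "c"), (3, "d"))
    val ordering = implicitly[Ordering[Int]]
    val matchKey = 2

    // Advance to the first left row whose key compares equal to the match key.
    var i = 0
    while (i < left.length && ordering.compare(left(i)._1, matchKey) != 0) i += 1

    // Buffer rows while the key still compares equal and input remains; the exit
    // condition lives entirely in the loop guard, so no stop flag is needed.
    val leftMatches = scala.collection.mutable.ArrayBuffer[(Int, String)]()
    while (i < left.length && ordering.compare(left(i)._1, matchKey) == 0) {
      leftMatches += left(i)
      i += 1
    }

    println(leftMatches) // ArrayBuffer((2,b), (2,c))
  }
}

Keeping the comparison in the loop guard puts the exit condition in one place, which is the same simplification the hunk above applies to the real leftKey/leftMatch rows.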