From 1186ef5e38a34ff77fa62521de0da73666b0de96 Mon Sep 17 00:00:00 2001
From: gatorsmile
Date: Tue, 23 Jan 2018 08:41:24 -0800
Subject: [PATCH 1/2] fix

---
 .../apache/spark/sql/execution/CacheManager.scala | 12 ++++++++----
 .../sql/execution/joins/BroadcastJoinSuite.scala  | 13 +++++++++++++
 2 files changed, 21 insertions(+), 4 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
index b05fe49a6ac3b..432eb59d6fe57 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
@@ -26,7 +26,7 @@ import org.apache.hadoop.fs.{FileSystem, Path}
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.{Dataset, SparkSession}
 import org.apache.spark.sql.catalyst.expressions.SubqueryExpression
-import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics}
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ResolvedHint}
 import org.apache.spark.sql.execution.columnar.InMemoryRelation
 import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation}
 import org.apache.spark.storage.StorageLevel
@@ -170,9 +170,13 @@ class CacheManager extends Logging {
   def useCachedData(plan: LogicalPlan): LogicalPlan = {
     val newPlan = plan transformDown {
       case currentFragment =>
-        lookupCachedData(currentFragment)
-          .map(_.cachedRepresentation.withOutput(currentFragment.output))
-          .getOrElse(currentFragment)
+        lookupCachedData(currentFragment).map { cached =>
+          val cachedPlan = cached.cachedRepresentation.withOutput(currentFragment.output)
+          currentFragment match {
+            case hint: ResolvedHint => ResolvedHint(cachedPlan, hint.hints)
+            case _ => cachedPlan
+          }
+        }.getOrElse(currentFragment)
     }
 
     newPlan transformAllExpressions {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
index 0bcd54e1fceab..2b05e472e27d7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
@@ -109,6 +109,19 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils {
     }
   }
 
+  test("broadcast hint is lost") {
+    withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+      val df1 = spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value")
+      val df2 = spark.createDataFrame(Seq((1, "1"), (2, "2"))).toDF("key", "value")
+      df2.cache()
+      val df3 = df1.join(broadcast(df2), Seq("key"), "inner")
+      val numBroadCastHashJoin = df3.queryExecution.executedPlan.collect {
+        case b: BroadcastHashJoinExec => b
+      }.size
+      assert(numBroadCastHashJoin === 1)
+    }
+  }
+
   test("broadcast hint isn't propagated after a join") {
     withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
       val df1 = spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value")

From 72097921f33492160a2784e108d2eb61fa543672 Mon Sep 17 00:00:00 2001
From: gatorsmile
Date: Tue, 23 Jan 2018 12:57:52 -0800
Subject: [PATCH 2/2] rename

---
 .../apache/spark/sql/execution/joins/BroadcastJoinSuite.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
index 2b05e472e27d7..1704bc8376f0d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
@@ -109,7 +109,7 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils {
     }
   }
 
-  test("broadcast hint is lost") {
+  test("broadcast hint is retained after using the cached data") {
     withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
       val df1 = spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value")
       val df2 = spark.createDataFrame(Seq((1, "1"), (2, "2"))).toDF("key", "value")
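
Note (not part of the patch): a minimal standalone sketch of the scenario the new test exercises. The object name, local master, and app name below are illustrative assumptions, not taken from the patch. With size-based auto broadcast disabled, the explicit broadcast() hint is the only thing that can produce a BroadcastHashJoin; before this fix, caching df2 caused CacheManager.useCachedData to swap in the cached InMemoryRelation and drop the ResolvedHint, so the join planned as a sort-merge join instead.

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.functions.broadcast

    object BroadcastHintDemo {  // hypothetical driver, for illustration only
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder()
          .master("local[*]")              // assumption: local run
          .appName("broadcast-hint-demo")  // assumption: illustrative name
          // Disable size-based auto broadcast so only the explicit hint matters,
          // mirroring SQLConf.AUTO_BROADCASTJOIN_THRESHOLD -> "-1" in the test.
          .config("spark.sql.autoBroadcastJoinThreshold", "-1")
          .getOrCreate()
        import spark.implicits._

        val df1 = Seq((1, "4"), (2, "2")).toDF("key", "value")
        val df2 = Seq((1, "1"), (2, "2")).toDF("key", "value")

        // Caching df2 makes useCachedData replace it with the cached
        // InMemoryRelation; with this patch the ResolvedHint is re-applied
        // on top of the cached plan instead of being silently discarded.
        df2.cache()

        val joined = df1.join(broadcast(df2), Seq("key"), "inner")
        joined.explain()  // with the fix: BroadcastHashJoin; before it: SortMergeJoin

        spark.stop()
      }
    }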