[SPARK-23214][SQL] cached data should not carry extra hint info #20394
CacheManager.scala

@@ -169,14 +169,17 @@ class CacheManager extends Logging {
   /** Replaces segments of the given logical plan with cached versions where possible. */
   def useCachedData(plan: LogicalPlan): LogicalPlan = {
     val newPlan = plan transformDown {
+      // Do not lookup the cache by hint node. Hint node is special, we should ignore it when
+      // canonicalizing plans, so that plans which are same except hint can hit the same cache.
+      // However, we also want to keep the hint info after cache lookup. Here we skip the hint
+      // node, so that the returned caching plan won't replace the hint node and drop the hint info
+      // from the original plan.
+      case hint: ResolvedHint => hint
+
       case currentFragment =>
-        lookupCachedData(currentFragment).map { cached =>
-          val cachedPlan = cached.cachedRepresentation.withOutput(currentFragment.output)
-          currentFragment match {
-            case hint: ResolvedHint => ResolvedHint(cachedPlan, hint.hints)
-            case _ => cachedPlan
-          }
-        }.getOrElse(currentFragment)
+        lookupCachedData(currentFragment)
+          .map(_.cachedRepresentation.withOutput(currentFragment.output))
+          .getOrElse(currentFragment)
     }

     newPlan transformAllExpressions {

Review comment: Thank you for pinging me, @cloud-fan. I see.
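For context, a minimal sketch of the user-visible behavior this logic supports, adapted from the regression test added later in this PR. It assumes a spark-shell session where `spark` is a SparkSession and is not part of the diff itself:

import org.apache.spark.sql.functions.broadcast
import spark.implicits._

// Keep the planner from broadcasting small tables automatically.
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", -1)

val df1 = Seq((1, "4"), (2, "2")).toDF("key", "value")
val df2 = Seq((1, "1"), (2, "2")).toDF("key", "value")

// Cache a plan whose root node is a broadcast hint.
broadcast(df2).cache()

// The unhinted df2 still hits the cache, because ResolvedHint nodes are
// ignored when plans are canonicalized. With this fix, the cached data no
// longer smuggles the hint back in, so the join below should not be planned
// as a broadcast join.
df1.join(df2, Seq("key"), "inner").explain()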
InMemoryRelation.scala

@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical
-import org.apache.spark.sql.catalyst.plans.logical.Statistics
+import org.apache.spark.sql.catalyst.plans.logical.{HintInfo, Statistics}
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.LongAccumulator
@@ -62,8 +62,8 @@ case class InMemoryRelation(
     @transient child: SparkPlan,
     tableName: Option[String])(
     @transient var _cachedColumnBuffers: RDD[CachedBatch] = null,
-    val batchStats: LongAccumulator = child.sqlContext.sparkContext.longAccumulator,
-    statsOfPlanToCache: Statistics = null)
+    val sizeInBytesStats: LongAccumulator = child.sqlContext.sparkContext.longAccumulator,
+    statsOfPlanToCache: Statistics)
   extends logical.LeafNode with MultiInstanceRelation {

   override protected def innerChildren: Seq[SparkPlan] = Seq(child)
@@ -73,11 +73,16 @@ case class InMemoryRelation(
   @transient val partitionStatistics = new PartitionStatistics(output)

   override def computeStats(): Statistics = {
-    if (batchStats.value == 0L) {
-      // Underlying columnar RDD hasn't been materialized, use the stats from the plan to cache
-      statsOfPlanToCache
+    if (sizeInBytesStats.value == 0L) {
+      // Underlying columnar RDD hasn't been materialized, use the stats from the plan to cache.
+      // Note that we should drop the hint info here. We may cache a plan whose root node is a hint
+      // node. When we lookup the cache with a semantically same plan without hint info, the plan
+      // returned by cache lookup should not have hint info. If we lookup the cache with a
+      // semantically same plan with a different hint info, `CacheManager.useCachedData` will take
+      // care of it and retain the hint info in the lookup input plan.
+      statsOfPlanToCache.copy(hints = HintInfo())
     } else {
-      Statistics(sizeInBytes = batchStats.value.longValue)
+      Statistics(sizeInBytes = sizeInBytesStats.value.longValue)
     }
   }

Review comment: I am not sure I agree with this. If we cache a plan with a hint, then it is reasonable to expect that the hint is still in the plan. We do the same with temporary views.

Review comment: This is a new behavior we introduced in 2.3. I will first keep the behavior unchanged and merge it to 2.3. We can have more discussion in the next release.
@@ -122,7 +127,7 @@ case class InMemoryRelation(
             rowCount += 1
           }

-          batchStats.add(totalSize)
+          sizeInBytesStats.add(totalSize)

           val stats = InternalRow.fromSeq(
             columnBuilders.flatMap(_.columnStats.collectedStatistics))
@@ -144,7 +149,7 @@ case class InMemoryRelation(
   def withOutput(newOutput: Seq[Attribute]): InMemoryRelation = {
     InMemoryRelation(
       newOutput, useCompression, batchSize, storageLevel, child, tableName)(
-      _cachedColumnBuffers, batchStats, statsOfPlanToCache)
+      _cachedColumnBuffers, sizeInBytesStats, statsOfPlanToCache)
   }

   override def newInstance(): this.type = {
@@ -156,12 +161,12 @@ case class InMemoryRelation(
       child,
       tableName)(
       _cachedColumnBuffers,
-      batchStats,
+      sizeInBytesStats,
       statsOfPlanToCache).asInstanceOf[this.type]
   }

   def cachedColumnBuffers: RDD[CachedBatch] = _cachedColumnBuffers

   override protected def otherCopyArgs: Seq[AnyRef] =
-    Seq(_cachedColumnBuffers, batchStats, statsOfPlanToCache)
+    Seq(_cachedColumnBuffers, sizeInBytesStats, statsOfPlanToCache)
 }
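To make the computeStats change concrete, here is a hedged sketch of what dropping the hints does, assuming the Spark 2.3 shapes of Statistics and HintInfo that this diff relies on; the sizeInBytes value is illustrative:

import org.apache.spark.sql.catalyst.plans.logical.{HintInfo, Statistics}

// Stats captured while caching broadcast(df2): the root of the cached plan
// was a hint node, so the recorded statistics carry broadcast = true.
val statsOfPlanToCache = Statistics(sizeInBytes = 1024, hints = HintInfo(broadcast = true))

// Before the columnar buffers are materialized, InMemoryRelation now reports
// the captured sizes with a neutral HintInfo, so an unhinted cache hit does
// not inherit the broadcast hint through computeStats().
val reported = statsOfPlanToCache.copy(hints = HintInfo())
assert(!reported.hints.broadcast)

If the lookup plan itself carries a hint, CacheManager.useCachedData keeps that hint node intact around the cached relation, so no information is lost in that direction.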
BroadcastJoinSuite.scala

@@ -22,7 +22,8 @@ import scala.reflect.ClassTag
 import org.apache.spark.AccumulatorSuite
 import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession}
 import org.apache.spark.sql.catalyst.expressions.{BitwiseAnd, BitwiseOr, Cast, Literal, ShiftLeft}
-import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan, WholeStageCodegenExec}
+import org.apache.spark.sql.execution.{SparkPlan, WholeStageCodegenExec}
+import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
 import org.apache.spark.sql.execution.exchange.EnsureRequirements
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
@@ -70,8 +71,8 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils {
   private def testBroadcastJoin[T: ClassTag](
       joinType: String,
       forceBroadcast: Boolean = false): SparkPlan = {
-    val df1 = spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value")
-    val df2 = spark.createDataFrame(Seq((1, "1"), (2, "2"))).toDF("key", "value")
+    val df1 = Seq((1, "4"), (2, "2")).toDF("key", "value")
+    val df2 = Seq((1, "1"), (2, "2")).toDF("key", "value")

     // Comparison at the end is for broadcast left semi join
     val joinExpression = df1("key") === df2("key") && df1("value") > df2("value")

Review comment: Some code style fixing.
@@ -109,61 +110,89 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils {
     }
   }

-  test("broadcast hint is retained after using the cached data") {
+  test("SPARK-23192: broadcast hint should be retained after using the cached data") {
     withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
-      val df1 = spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value")
-      val df2 = spark.createDataFrame(Seq((1, "1"), (2, "2"))).toDF("key", "value")
-      df2.cache()
-      val df3 = df1.join(broadcast(df2), Seq("key"), "inner")
-      val numBroadCastHashJoin = df3.queryExecution.executedPlan.collect {
-        case b: BroadcastHashJoinExec => b
-      }.size
-      assert(numBroadCastHashJoin === 1)
+      try {
+        val df1 = Seq((1, "4"), (2, "2")).toDF("key", "value")
+        val df2 = Seq((1, "1"), (2, "2")).toDF("key", "value")
+        df2.cache()
+        val df3 = df1.join(broadcast(df2), Seq("key"), "inner")
+        val numBroadCastHashJoin = df3.queryExecution.executedPlan.collect {
+          case b: BroadcastHashJoinExec => b
+        }.size
+        assert(numBroadCastHashJoin === 1)
+      } finally {
+        spark.catalog.clearCache()
+      }
     }
   }
+
+  test("SPARK-23214: cached data should not carry extra hint info") {
+    withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+      try {
+        val df1 = Seq((1, "4"), (2, "2")).toDF("key", "value")
+        val df2 = Seq((1, "1"), (2, "2")).toDF("key", "value")
+        broadcast(df2).cache()
+
+        val df3 = df1.join(df2, Seq("key"), "inner")
+        val numCachedPlan = df3.queryExecution.executedPlan.collect {
+          case i: InMemoryTableScanExec => i
+        }.size
+        // df2 should be cached.
+        assert(numCachedPlan === 1)
+
+        val numBroadCastHashJoin = df3.queryExecution.executedPlan.collect {
+          case b: BroadcastHashJoinExec => b
+        }.size
+        // df2 should not be broadcasted.
+        assert(numBroadCastHashJoin === 0)
+      } finally {
+        spark.catalog.clearCache()
+      }
+    }
+  }

   test("broadcast hint isn't propagated after a join") {
     withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
-      val df1 = spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value")
-      val df2 = spark.createDataFrame(Seq((1, "1"), (2, "2"))).toDF("key", "value")
+      val df1 = Seq((1, "4"), (2, "2")).toDF("key", "value")
+      val df2 = Seq((1, "1"), (2, "2")).toDF("key", "value")
       val df3 = df1.join(broadcast(df2), Seq("key"), "inner").drop(df2("key"))

-      val df4 = spark.createDataFrame(Seq((1, "5"), (2, "5"))).toDF("key", "value")
+      val df4 = Seq((1, "5"), (2, "5")).toDF("key", "value")
       val df5 = df4.join(df3, Seq("key"), "inner")

-      val plan =
-        EnsureRequirements(spark.sessionState.conf).apply(df5.queryExecution.sparkPlan)
+      val plan = EnsureRequirements(spark.sessionState.conf).apply(df5.queryExecution.sparkPlan)

       assert(plan.collect { case p: BroadcastHashJoinExec => p }.size === 1)
       assert(plan.collect { case p: SortMergeJoinExec => p }.size === 1)
     }
   }

   private def assertBroadcastJoin(df : Dataset[Row]) : Unit = {
-    val df1 = spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value")
+    val df1 = Seq((1, "4"), (2, "2")).toDF("key", "value")
     val joined = df1.join(df, Seq("key"), "inner")

-    val plan =
-      EnsureRequirements(spark.sessionState.conf).apply(joined.queryExecution.sparkPlan)
+    val plan = EnsureRequirements(spark.sessionState.conf).apply(joined.queryExecution.sparkPlan)

     assert(plan.collect { case p: BroadcastHashJoinExec => p }.size === 1)
   }

   test("broadcast hint programming API") {
     withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
-      val df2 = spark.createDataFrame(Seq((1, "1"), (2, "2"), (3, "2"))).toDF("key", "value")
+      val df2 = Seq((1, "1"), (2, "2"), (3, "2")).toDF("key", "value")
       val broadcasted = broadcast(df2)
-      val df3 = spark.createDataFrame(Seq((2, "2"), (3, "3"))).toDF("key", "value")
+      val df3 = Seq((2, "2"), (3, "3")).toDF("key", "value")

-      val cases = Seq(broadcasted.limit(2),
-        broadcasted.filter("value < 10"),
-        broadcasted.sample(true, 0.5),
-        broadcasted.distinct(),
-        broadcasted.groupBy("value").agg(min($"key").as("key")),
-        // except and intersect are semi/anti-joins which won't return more data then
-        // their left argument, so the broadcast hint should be propagated here
-        broadcasted.except(df3),
-        broadcasted.intersect(df3))
+      val cases = Seq(
+        broadcasted.limit(2),
+        broadcasted.filter("value < 10"),
+        broadcasted.sample(true, 0.5),
+        broadcasted.distinct(),
+        broadcasted.groupBy("value").agg(min($"key").as("key")),
+        // except and intersect are semi/anti-joins which won't return more data then
+        // their left argument, so the broadcast hint should be propagated here
+        broadcasted.except(df3),
+        broadcasted.intersect(df3))

       cases.foreach(assertBroadcastJoin)
     }
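The repeated style change in this file swaps spark.createDataFrame(Seq(...)) for the terser toDF idiom, which the suite presumably gets from its testImplicits. A standalone sketch of the same idiom, with illustrative session setup that is not part of the diff:

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").appName("toDF-idiom").getOrCreate()
import spark.implicits._  // enables Seq[(Int, String)].toDF

// Equivalent to spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value"):
val df = Seq((1, "4"), (2, "2")).toDF("key", "value")
df.show()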
@@ -240,9 +269,8 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils {
   test("Shouldn't change broadcast join buildSide if user clearly specified") {
     withTempView("t1", "t2") {
-      spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value").createTempView("t1")
-      spark.createDataFrame(Seq((1, "1"), (2, "12.3"), (2, "123"))).toDF("key", "value")
-        .createTempView("t2")
+      Seq((1, "4"), (2, "2")).toDF("key", "value").createTempView("t1")
+      Seq((1, "1"), (2, "12.3"), (2, "123")).toDF("key", "value").createTempView("t2")

       val t1Size = spark.table("t1").queryExecution.analyzed.children.head.stats.sizeInBytes
       val t2Size = spark.table("t2").queryExecution.analyzed.children.head.stats.sizeInBytes
@@ -292,9 +320,8 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils {
   test("Shouldn't bias towards build right if user didn't specify") {
     withTempView("t1", "t2") {
-      spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value").createTempView("t1")
-      spark.createDataFrame(Seq((1, "1"), (2, "12.3"), (2, "123"))).toDF("key", "value")
-        .createTempView("t2")
+      Seq((1, "4"), (2, "2")).toDF("key", "value").createTempView("t1")
+      Seq((1, "1"), (2, "12.3"), (2, "123")).toDF("key", "value").createTempView("t2")

       val t1Size = spark.table("t1").queryExecution.analyzed.children.head.stats.sizeInBytes
       val t2Size = spark.table("t2").queryExecution.analyzed.children.head.stats.sizeInBytes
Review comment: A small clean up for #20365.