[SPARK-23214][SQL] cached data should not carry extra hint info #20394

Closed · wants to merge 2 commits
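
As a minimal sketch of the scenario this change targets (it mirrors the regression test added to BroadcastJoinSuite below; the SparkSession setup, object name, and data are illustrative, not part of the patch):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.broadcast

object Spark23214Sketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("SPARK-23214 sketch").getOrCreate()
    import spark.implicits._

    // Disable size-based auto broadcast so only an explicit hint could trigger a broadcast join.
    spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "-1")

    val df1 = Seq((1, "4"), (2, "2")).toDF("key", "value")
    val df2 = Seq((1, "1"), (2, "2")).toDF("key", "value")

    // Cache df2 through a plan whose root is a broadcast hint node.
    broadcast(df2).cache()

    // This join reuses the cached data but does not itself ask for a broadcast.
    // After this patch the cached relation is reused without inheriting the broadcast
    // hint, so the planner should not pick a broadcast hash join here.
    val joined = df1.join(df2, Seq("key"), "inner")
    joined.explain()

    spark.catalog.clearCache()
    spark.stop()
  }
}
```

The new test below asserts exactly this: one InMemoryTableScanExec and zero BroadcastHashJoinExec in the executed plan.
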
CacheManager.scala
@@ -169,14 +169,17 @@ class CacheManager extends Logging {
/** Replaces segments of the given logical plan with cached versions where possible. */
def useCachedData(plan: LogicalPlan): LogicalPlan = {
val newPlan = plan transformDown {
// Do not lookup the cache by hint node. Hint node is special, we should ignore it when
// canonicalizing plans, so that plans which are same except hint can hit the same cache.
// However, we also want to keep the hint info after cache lookup. Here we skip the hint
// node, so that the returned caching plan won't replace the hint node and drop the hint info
// from the original plan.
case hint: ResolvedHint => hint

Contributor Author commented:
A small cleanup for #20365


case currentFragment =>
lookupCachedData(currentFragment).map { cached =>
val cachedPlan = cached.cachedRepresentation.withOutput(currentFragment.output)
currentFragment match {
case hint: ResolvedHint => ResolvedHint(cachedPlan, hint.hints)
case _ => cachedPlan
}
}.getOrElse(currentFragment)
lookupCachedData(currentFragment)
.map(_.cachedRepresentation.withOutput(currentFragment.output))
.getOrElse(currentFragment)

Member @dongjoon-hyun commented on Jan 25, 2018:

Thank you for pinging me, @cloud-fan. I see.

}

newPlan transformAllExpressions {

InMemoryRelation.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.plans.logical.Statistics
import org.apache.spark.sql.catalyst.plans.logical.{HintInfo, Statistics}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.LongAccumulator
@@ -62,8 +62,8 @@ case class InMemoryRelation(
@transient child: SparkPlan,
tableName: Option[String])(
@transient var _cachedColumnBuffers: RDD[CachedBatch] = null,
val batchStats: LongAccumulator = child.sqlContext.sparkContext.longAccumulator,
statsOfPlanToCache: Statistics = null)
val sizeInBytesStats: LongAccumulator = child.sqlContext.sparkContext.longAccumulator,
statsOfPlanToCache: Statistics)
extends logical.LeafNode with MultiInstanceRelation {

override protected def innerChildren: Seq[SparkPlan] = Seq(child)
@@ -73,11 +73,16 @@
@transient val partitionStatistics = new PartitionStatistics(output)

override def computeStats(): Statistics = {
if (batchStats.value == 0L) {
// Underlying columnar RDD hasn't been materialized, use the stats from the plan to cache
statsOfPlanToCache
if (sizeInBytesStats.value == 0L) {
// Underlying columnar RDD hasn't been materialized, use the stats from the plan to cache.
// Note that we should drop the hint info here. We may cache a plan whose root node is a hint
// node. When we lookup the cache with a semantically same plan without hint info, the plan
// returned by cache lookup should not have hint info. If we lookup the cache with a
// semantically same plan with a different hint info, `CacheManager.useCachedData` will take
// care of it and retain the hint info in the lookup input plan.
statsOfPlanToCache.copy(hints = HintInfo())

Contributor commented:
I am not sure I agree with this. If we cache a plan with a hint, then it is reasonable to expect that the hint is still in the plan. We do the same with temporary views.
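
For concreteness, a rough spark-shell style sketch of the two cases being compared (names are illustrative and `spark.implicits._` is assumed to be in scope):

```scala
import org.apache.spark.sql.functions.broadcast

val small = Seq((1, "a"), (2, "b")).toDF("key", "value")

// Temporary view over a hinted plan: the view stores the analyzed plan,
// ResolvedHint included, so later queries over the view still see the hint.
broadcast(small).createOrReplaceTempView("small_hinted")

// Cached hinted plan: with this change, a cache lookup by a semantically equal
// plan that does not itself carry the hint returns the cached data without the
// broadcast hint attached.
broadcast(small).cache()
```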

Member commented:
This is a new behavior we introduced in 2.3. I will first keep the behavior unchanged and merge it to 2.3.

We can have more discussion in the next release.

} else {
Statistics(sizeInBytes = batchStats.value.longValue)
Statistics(sizeInBytes = sizeInBytesStats.value.longValue)
}
}

@@ -122,7 +127,7 @@ case class InMemoryRelation(
rowCount += 1
}

batchStats.add(totalSize)
sizeInBytesStats.add(totalSize)

val stats = InternalRow.fromSeq(
columnBuilders.flatMap(_.columnStats.collectedStatistics))
@@ -144,7 +149,7 @@
def withOutput(newOutput: Seq[Attribute]): InMemoryRelation = {
InMemoryRelation(
newOutput, useCompression, batchSize, storageLevel, child, tableName)(
_cachedColumnBuffers, batchStats, statsOfPlanToCache)
_cachedColumnBuffers, sizeInBytesStats, statsOfPlanToCache)
}

override def newInstance(): this.type = {
@@ -156,12 +161,12 @@
child,
tableName)(
_cachedColumnBuffers,
batchStats,
sizeInBytesStats,
statsOfPlanToCache).asInstanceOf[this.type]
}

def cachedColumnBuffers: RDD[CachedBatch] = _cachedColumnBuffers

override protected def otherCopyArgs: Seq[AnyRef] =
Seq(_cachedColumnBuffers, batchStats, statsOfPlanToCache)
Seq(_cachedColumnBuffers, sizeInBytesStats, statsOfPlanToCache)
}

CachedTableSuite.scala
@@ -368,12 +368,12 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
val toBeCleanedAccIds = new HashSet[Long]

val accId1 = spark.table("t1").queryExecution.withCachedData.collect {
case i: InMemoryRelation => i.batchStats.id
case i: InMemoryRelation => i.sizeInBytesStats.id
}.head
toBeCleanedAccIds += accId1

val accId2 = spark.table("t1").queryExecution.withCachedData.collect {
case i: InMemoryRelation => i.batchStats.id
case i: InMemoryRelation => i.sizeInBytesStats.id
}.head
toBeCleanedAccIds += accId2

InMemoryColumnarQuerySuite.scala
@@ -336,7 +336,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
checkAnswer(cached, expectedAnswer)

// Check that the right size was calculated.
assert(cached.batchStats.value === expectedAnswer.size * INT.defaultSize)
assert(cached.sizeInBytesStats.value === expectedAnswer.size * INT.defaultSize)
}

test("access primitive-type columns in CachedBatch without whole stage codegen") {

BroadcastJoinSuite.scala
@@ -22,7 +22,8 @@ import scala.reflect.ClassTag
import org.apache.spark.AccumulatorSuite
import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession}
import org.apache.spark.sql.catalyst.expressions.{BitwiseAnd, BitwiseOr, Cast, Literal, ShiftLeft}
import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan, WholeStageCodegenExec}
import org.apache.spark.sql.execution.{SparkPlan, WholeStageCodegenExec}
import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
import org.apache.spark.sql.execution.exchange.EnsureRequirements
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
@@ -70,8 +71,8 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils {
private def testBroadcastJoin[T: ClassTag](
joinType: String,
forceBroadcast: Boolean = false): SparkPlan = {
val df1 = spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value")
val df2 = spark.createDataFrame(Seq((1, "1"), (2, "2"))).toDF("key", "value")
val df1 = Seq((1, "4"), (2, "2")).toDF("key", "value")
val df2 = Seq((1, "1"), (2, "2")).toDF("key", "value")

Contributor Author commented:
Some code style fixes.


// Comparison at the end is for broadcast left semi join
val joinExpression = df1("key") === df2("key") && df1("value") > df2("value")
@@ -109,61 +110,89 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils {
}
}

test("broadcast hint is retained after using the cached data") {
test("SPARK-23192: broadcast hint should be retained after using the cached data") {
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
val df1 = spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value")
val df2 = spark.createDataFrame(Seq((1, "1"), (2, "2"))).toDF("key", "value")
df2.cache()
val df3 = df1.join(broadcast(df2), Seq("key"), "inner")
val numBroadCastHashJoin = df3.queryExecution.executedPlan.collect {
case b: BroadcastHashJoinExec => b
}.size
assert(numBroadCastHashJoin === 1)
try {
val df1 = Seq((1, "4"), (2, "2")).toDF("key", "value")
val df2 = Seq((1, "1"), (2, "2")).toDF("key", "value")
df2.cache()
val df3 = df1.join(broadcast(df2), Seq("key"), "inner")
val numBroadCastHashJoin = df3.queryExecution.executedPlan.collect {
case b: BroadcastHashJoinExec => b
}.size
assert(numBroadCastHashJoin === 1)
} finally {
spark.catalog.clearCache()
}
}
}

test("SPARK-23214: cached data should not carry extra hint info") {
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
try {
val df1 = Seq((1, "4"), (2, "2")).toDF("key", "value")
val df2 = Seq((1, "1"), (2, "2")).toDF("key", "value")
broadcast(df2).cache()

val df3 = df1.join(df2, Seq("key"), "inner")
val numCachedPlan = df3.queryExecution.executedPlan.collect {
case i: InMemoryTableScanExec => i
}.size
// df2 should be cached.
assert(numCachedPlan === 1)

val numBroadCastHashJoin = df3.queryExecution.executedPlan.collect {
case b: BroadcastHashJoinExec => b
}.size
// df2 should not be broadcasted.
assert(numBroadCastHashJoin === 0)
} finally {
spark.catalog.clearCache()
}
}
}

test("broadcast hint isn't propagated after a join") {
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
val df1 = spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value")
val df2 = spark.createDataFrame(Seq((1, "1"), (2, "2"))).toDF("key", "value")
val df1 = Seq((1, "4"), (2, "2")).toDF("key", "value")
val df2 = Seq((1, "1"), (2, "2")).toDF("key", "value")
val df3 = df1.join(broadcast(df2), Seq("key"), "inner").drop(df2("key"))

val df4 = spark.createDataFrame(Seq((1, "5"), (2, "5"))).toDF("key", "value")
val df4 = Seq((1, "5"), (2, "5")).toDF("key", "value")
val df5 = df4.join(df3, Seq("key"), "inner")

val plan =
EnsureRequirements(spark.sessionState.conf).apply(df5.queryExecution.sparkPlan)
val plan = EnsureRequirements(spark.sessionState.conf).apply(df5.queryExecution.sparkPlan)

assert(plan.collect { case p: BroadcastHashJoinExec => p }.size === 1)
assert(plan.collect { case p: SortMergeJoinExec => p }.size === 1)
}
}

private def assertBroadcastJoin(df : Dataset[Row]) : Unit = {
val df1 = spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value")
val df1 = Seq((1, "4"), (2, "2")).toDF("key", "value")
val joined = df1.join(df, Seq("key"), "inner")

val plan =
EnsureRequirements(spark.sessionState.conf).apply(joined.queryExecution.sparkPlan)
val plan = EnsureRequirements(spark.sessionState.conf).apply(joined.queryExecution.sparkPlan)

assert(plan.collect { case p: BroadcastHashJoinExec => p }.size === 1)
}

test("broadcast hint programming API") {
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
val df2 = spark.createDataFrame(Seq((1, "1"), (2, "2"), (3, "2"))).toDF("key", "value")
val df2 = Seq((1, "1"), (2, "2"), (3, "2")).toDF("key", "value")
val broadcasted = broadcast(df2)
val df3 = spark.createDataFrame(Seq((2, "2"), (3, "3"))).toDF("key", "value")

val cases = Seq(broadcasted.limit(2),
broadcasted.filter("value < 10"),
broadcasted.sample(true, 0.5),
broadcasted.distinct(),
broadcasted.groupBy("value").agg(min($"key").as("key")),
// except and intersect are semi/anti-joins which won't return more data then
// their left argument, so the broadcast hint should be propagated here
broadcasted.except(df3),
broadcasted.intersect(df3))
val df3 = Seq((2, "2"), (3, "3")).toDF("key", "value")

val cases = Seq(
broadcasted.limit(2),
broadcasted.filter("value < 10"),
broadcasted.sample(true, 0.5),
broadcasted.distinct(),
broadcasted.groupBy("value").agg(min($"key").as("key")),
// except and intersect are semi/anti-joins which won't return more data then
// their left argument, so the broadcast hint should be propagated here
broadcasted.except(df3),
broadcasted.intersect(df3))

cases.foreach(assertBroadcastJoin)
}
@@ -240,9 +269,8 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils {
test("Shouldn't change broadcast join buildSide if user clearly specified") {

withTempView("t1", "t2") {
spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value").createTempView("t1")
spark.createDataFrame(Seq((1, "1"), (2, "12.3"), (2, "123"))).toDF("key", "value")
.createTempView("t2")
Seq((1, "4"), (2, "2")).toDF("key", "value").createTempView("t1")
Seq((1, "1"), (2, "12.3"), (2, "123")).toDF("key", "value").createTempView("t2")

val t1Size = spark.table("t1").queryExecution.analyzed.children.head.stats.sizeInBytes
val t2Size = spark.table("t2").queryExecution.analyzed.children.head.stats.sizeInBytes
@@ -292,9 +320,8 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils {
test("Shouldn't bias towards build right if user didn't specify") {

withTempView("t1", "t2") {
spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value").createTempView("t1")
spark.createDataFrame(Seq((1, "1"), (2, "12.3"), (2, "123"))).toDF("key", "value")
.createTempView("t2")
Seq((1, "4"), (2, "2")).toDF("key", "value").createTempView("t1")
Seq((1, "1"), (2, "12.3"), (2, "123")).toDF("key", "value").createTempView("t2")

val t1Size = spark.table("t1").queryExecution.analyzed.children.head.stats.sizeInBytes
val t2Size = spark.table("t2").queryExecution.analyzed.children.head.stats.sizeInBytes