diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
index d68aeb275afda..dbdda27d533f4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
@@ -26,7 +26,7 @@ import org.apache.hadoop.fs.{FileSystem, Path}
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.{Dataset, SparkSession}
 import org.apache.spark.sql.catalyst.expressions.SubqueryExpression
-import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ResolvedHint}
+import org.apache.spark.sql.catalyst.plans.logical.{AnalysisBarrier, LogicalPlan, ResolvedHint}
 import org.apache.spark.sql.execution.columnar.InMemoryRelation
 import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation}
 import org.apache.spark.storage.StorageLevel
@@ -97,7 +97,7 @@ class CacheManager extends Logging {
       val inMemoryRelation = InMemoryRelation(
         sparkSession.sessionState.conf.useCompression,
         sparkSession.sessionState.conf.columnBatchSize, storageLevel,
-        sparkSession.sessionState.executePlan(planToCache).executedPlan,
+        sparkSession.sessionState.executePlan(AnalysisBarrier(planToCache)).executedPlan,
         tableName,
         planToCache.stats)
       cachedData.add(CachedData(planToCache, inMemoryRelation))
@@ -146,7 +146,7 @@ class CacheManager extends Logging {
           useCompression = cd.cachedRepresentation.useCompression,
           batchSize = cd.cachedRepresentation.batchSize,
           storageLevel = cd.cachedRepresentation.storageLevel,
-          child = spark.sessionState.executePlan(cd.plan).executedPlan,
+          child = spark.sessionState.executePlan(AnalysisBarrier(cd.plan)).executedPlan,
           tableName = cd.cachedRepresentation.tableName,
           statsOfPlanToCache = cd.plan.stats)
         needToRecache += cd.copy(cachedRepresentation = newCache)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala
index e0561ee2797a5..f6c760ed492a3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql
 
+import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.storage.StorageLevel
@@ -96,4 +97,19 @@ class DatasetCacheSuite extends QueryTest with SharedSQLContext {
     agged.unpersist()
     assert(agged.storageLevel == StorageLevel.NONE, "The Dataset agged should not be cached.")
   }
+
+  test("SPARK-24613 Cache with UDF could not be matched with subsequent dependent caches") {
+    val udf1 = udf({x: Int => x + 1})
+    val df = spark.range(0, 10).toDF("a").withColumn("b", udf1($"a"))
+    val df2 = df.agg(sum(df("b")))
+
+    df.cache()
+    df.count()
+    df2.cache()
+
+    val plan = df2.queryExecution.withCachedData
+    assert(plan.isInstanceOf[InMemoryRelation])
+    val internalPlan = plan.asInstanceOf[InMemoryRelation].child
+    assert(internalPlan.find(_.isInstanceOf[InMemoryTableScanExec]).isDefined)
+  }
 }
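
For context, a minimal sketch of the scenario the new DatasetCacheSuite test exercises, assuming an active SparkSession named `spark` (e.g. in spark-shell); the identifiers mirror the test code:

    import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec}
    import org.apache.spark.sql.functions._
    import spark.implicits._

    // A DataFrame whose plan contains a Scala UDF, and a dependent aggregation on top of it.
    val udf1 = udf({x: Int => x + 1})
    val df = spark.range(0, 10).toDF("a").withColumn("b", udf1($"a"))
    val df2 = df.agg(sum(df("b")))

    df.cache()   // cache the base DataFrame
    df.count()   // materialize the cache
    df2.cache()  // cache the dependent aggregation

    // With AnalysisBarrier wrapping the plan handed to executePlan, df2's cached
    // physical plan should scan df's InMemoryRelation via InMemoryTableScanExec
    // instead of recomputing the UDF from the original range.
    val plan = df2.queryExecution.withCachedData
    assert(plan.isInstanceOf[InMemoryRelation])
    val internalPlan = plan.asInstanceOf[InMemoryRelation].child
    assert(internalPlan.find(_.isInstanceOf[InMemoryTableScanExec]).isDefined)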