diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
index b9fa43f1f9fbd..39c0e102b69b2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -254,7 +254,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
           val writer = ws.createWriter(jobId, df.logicalPlan.schema, mode, options)
           if (writer.isPresent) {
             runCommand(df.sparkSession, "save") {
-              WriteToDataSourceV2(writer.get(), df.logicalPlan)
+              WriteToDataSourceV2(writer.get(), df.planWithBarrier)
             }
           }
 
@@ -275,7 +275,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
           sparkSession = df.sparkSession,
           className = source,
           partitionColumns = partitioningColumns.getOrElse(Nil),
-          options = extraOptions.toMap).planForWriting(mode, AnalysisBarrier(df.logicalPlan))
+          options = extraOptions.toMap).planForWriting(mode, df.planWithBarrier)
       }
     }
 
@@ -323,7 +323,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
       InsertIntoTable(
         table = UnresolvedRelation(tableIdent),
         partition = Map.empty[String, Option[String]],
-        query = df.logicalPlan,
+        query = df.planWithBarrier,
         overwrite = mode == SaveMode.Overwrite,
         ifPartitionNotExists = false)
     }
@@ -459,7 +459,9 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
       partitionColumnNames = partitioningColumns.getOrElse(Nil),
       bucketSpec = getBucketSpec)
 
-    runCommand(df.sparkSession, "saveAsTable")(CreateTable(tableDesc, mode, Some(df.logicalPlan)))
+    runCommand(df.sparkSession, "saveAsTable") {
+      CreateTable(tableDesc, mode, Some(df.planWithBarrier))
+    }
   }
 
   /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
index 04bf8c6dd917f..c7f7e4d755cfd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
@@ -29,7 +29,7 @@ import org.apache.hadoop.mapred.{FileInputFormat, JobConf}
 
 import org.apache.spark.sql.{AnalysisException, Row, SparkSession}
 import org.apache.spark.sql.catalyst.TableIdentifier
-import org.apache.spark.sql.catalyst.analysis.{NoSuchTableException, Resolver}
+import org.apache.spark.sql.catalyst.analysis.{EliminateBarriers, NoSuchTableException, Resolver}
 import org.apache.spark.sql.catalyst.catalog._
 import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
@@ -891,8 +891,9 @@ object DDLUtils {
    * Throws exception if outputPath tries to overwrite inputpath.
    */
   def verifyNotReadPath(query: LogicalPlan, outputPath: Path) : Unit = {
-    val inputPaths = query.collect {
-      case LogicalRelation(r: HadoopFsRelation, _, _, _) => r.location.rootPaths
+    val inputPaths = EliminateBarriers(query).collect {
+      case LogicalRelation(r: HadoopFsRelation, _, _, _) =>
+        r.location.rootPaths
     }.flatten
 
     if (inputPaths.contains(outputPath)) {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
index d8074571ffc65..30dca9497ddde 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
@@ -19,11 +19,16 @@ package org.apache.spark.sql
 
 import org.apache.spark.sql.api.java._
 import org.apache.spark.sql.catalyst.plans.logical.Project
-import org.apache.spark.sql.execution.command.ExplainCommand
+import org.apache.spark.sql.execution.QueryExecution
+import org.apache.spark.sql.execution.columnar.InMemoryRelation
+import org.apache.spark.sql.execution.command.{CreateDataSourceTableAsSelectCommand, ExplainCommand}
+import org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand
 import org.apache.spark.sql.functions.{lit, udf}
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.test.SQLTestData._
 import org.apache.spark.sql.types.{DataTypes, DoubleType}
+import org.apache.spark.sql.util.QueryExecutionListener
+
 
 private case class FunctionResult(f1: String, f2: String)
 
@@ -325,6 +330,41 @@ class UDFSuite extends QueryTest with SharedSQLContext {
     }
   }
 
+  test("cached Data should be used in the write path") {
+    withTable("t") {
+      withTempPath { path =>
+        var numTotalCachedHit = 0
+        val listener = new QueryExecutionListener {
+          override def onFailure(f: String, qe: QueryExecution, e: Exception): Unit = {}
+
+          override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {
+            qe.withCachedData match {
+              case c: CreateDataSourceTableAsSelectCommand
+                  if c.query.isInstanceOf[InMemoryRelation] =>
+                numTotalCachedHit += 1
+              case i: InsertIntoHadoopFsRelationCommand
+                  if i.query.isInstanceOf[InMemoryRelation] =>
+                numTotalCachedHit += 1
+              case _ =>
+            }
+          }
+        }
+        spark.listenerManager.register(listener)
+
+        val udf1 = udf({ (x: Int, y: Int) => x + y })
+        val df = spark.range(0, 3).toDF("a")
+          .withColumn("b", udf1($"a", lit(10)))
+        df.cache()
+        df.write.saveAsTable("t")
+        assert(numTotalCachedHit == 1, "expected to be cached in saveAsTable")
+        df.write.insertInto("t")
+        assert(numTotalCachedHit == 2, "expected to be cached in insertInto")
+        df.write.save(path.getCanonicalPath)
+        assert(numTotalCachedHit == 3, "expected to be cached in save for native")
+      }
+    }
+  }
+
   test("SPARK-24891 Fix HandleNullInputsForUDF rule") {
     val udf1 = udf({(x: Int, y: Int) => x + y})
     val df = spark.range(0, 3).toDF("a")
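
For reviewers who want to try the behavior outside the test harness, below is a minimal standalone sketch of the scenario the new UDFSuite test exercises: cache a DataFrame that contains a UDF column, then write it through saveAsTable, insertInto, and a file-based save. The session setup, object name, table name "t", and output directory are illustrative and not taken from the patch; only the cache-then-write pattern mirrors the test.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, lit, udf}

// Standalone sketch (not part of the patch): mirrors the scenario covered by
// the new "cached Data should be used in the write path" test. The table name
// "t" and the output directory are illustrative placeholders.
object CachedWriteSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("cached-write-sketch")
      .getOrCreate()

    val add = udf((x: Int, y: Int) => x + y)
    val df = spark.range(0, 3).toDF("a").withColumn("b", add(col("a"), lit(10)))

    // Materialize the cache before writing; with planWithBarrier in the write
    // path, the write commands should pick up the cached InMemoryRelation
    // instead of recomputing the UDF over the source.
    df.cache()
    df.count()

    df.write.mode("overwrite").saveAsTable("t")
    df.write.insertInto("t")
    df.write.mode("overwrite").parquet("/tmp/cached_write_sketch")

    spark.stop()
  }
}

To confirm the cache hit, a QueryExecutionListener can be registered as in the test, or the write can be inspected with explain() to check that the plan reads from InMemoryRelation.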