[SPARK-24867] [SQL] Add AnalysisBarrier to DataFrameWriter #21821

Status: Closed
Wants to merge 5 commits
Changes from 1 commit
sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala

@@ -254,7 +254,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
           val writer = ws.createWriter(jobId, df.logicalPlan.schema, mode, options)
           if (writer.isPresent) {
             runCommand(df.sparkSession, "save") {
-              WriteToDataSourceV2(writer.get(), df.logicalPlan)
+              WriteToDataSourceV2(writer.get(), df.planWithBarrier)
             }
           }

Inline review comment by the PR author on the changed line: "This change is not needed but it is safe to have." (A sketch of what planWithBarrier expands to follows this hunk.)
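For readers outside the Spark codebase, here is a minimal sketch of what planWithBarrier refers to, assuming the Spark 2.3.x internals this PR builds on: Dataset keeps its already-analyzed plan and exposes it wrapped in AnalysisBarrier, so df.planWithBarrier builds the same tree as AnalysisBarrier(df.logicalPlan), which is exactly the substitution made in the next hunk. The object name below is illustrative and not Spark code.

    // Illustrative only: mirrors what Dataset does internally in Spark 2.3.x,
    // where planWithBarrier is (roughly) AnalysisBarrier(logicalPlan).
    import org.apache.spark.sql.catalyst.plans.logical.{AnalysisBarrier, LogicalPlan}

    object PlanWithBarrierSketch {
      // Wrap an already-analyzed plan so the analyzer skips re-resolving it when
      // the plan is embedded into a write command.
      def planWithBarrier(analyzed: LogicalPlan): LogicalPlan = AnalysisBarrier(analyzed)
    }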

@@ -275,7 +275,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
         sparkSession = df.sparkSession,
         className = source,
         partitionColumns = partitioningColumns.getOrElse(Nil),
-        options = extraOptions.toMap).planForWriting(mode, AnalysisBarrier(df.logicalPlan))
+        options = extraOptions.toMap).planForWriting(mode, df.planWithBarrier)
     }
   }

@@ -323,7 +323,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
       InsertIntoTable(
         table = UnresolvedRelation(tableIdent),
         partition = Map.empty[String, Option[String]],
-        query = df.logicalPlan,
+        query = df.planWithBarrier,
         overwrite = mode == SaveMode.Overwrite,
         ifPartitionNotExists = false)
     }
Expand Down Expand Up @@ -459,7 +459,9 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
partitionColumnNames = partitioningColumns.getOrElse(Nil),
bucketSpec = getBucketSpec)

runCommand(df.sparkSession, "saveAsTable")(CreateTable(tableDesc, mode, Some(df.logicalPlan)))
runCommand(df.sparkSession, "saveAsTable") {
CreateTable(tableDesc, mode, Some(df.planWithBarrier))
}
}

/**
Expand Down
45 changes: 43 additions & 2 deletions sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
@@ -19,11 +19,17 @@ package org.apache.spark.sql

 import org.apache.spark.sql.api.java._
 import org.apache.spark.sql.catalyst.plans.logical.Project
-import org.apache.spark.sql.execution.command.ExplainCommand
-import org.apache.spark.sql.functions.udf
+import org.apache.spark.sql.execution.QueryExecution
+import org.apache.spark.sql.execution.columnar.InMemoryRelation
+import org.apache.spark.sql.execution.command.{CreateDataSourceTableAsSelectCommand, ExplainCommand}
+import org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand
+import org.apache.spark.sql.functions.{lit, udf}
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.test.SQLTestData._
 import org.apache.spark.sql.types.{DataTypes, DoubleType}
+import org.apache.spark.sql.util.QueryExecutionListener
+
+import scala.collection.mutable.ArrayBuffer

 private case class FunctionResult(f1: String, f2: String)

@@ -324,4 +330,39 @@ class UDFSuite extends QueryTest with SharedSQLContext {
       assert(outputStream.toString.contains("UDF:f(a._1 AS `_1`)"))
     }
   }
+
+  test("cached Data should be used in the write path") {
+    withTable("t") {
+      withTempPath { path =>
+        var numTotalCachedHit = 0
+        val listener = new QueryExecutionListener {
+          override def onFailure(f: String, qe: QueryExecution, e: Exception): Unit = {}
+
+          override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {
+            qe.withCachedData match {
+              case c: CreateDataSourceTableAsSelectCommand
+                  if c.query.isInstanceOf[InMemoryRelation] =>
+                numTotalCachedHit += 1
+              case i: InsertIntoHadoopFsRelationCommand
+                  if i.query.isInstanceOf[InMemoryRelation] =>
+                numTotalCachedHit += 1
+              case _ =>
+            }
+          }
+        }
+        spark.listenerManager.register(listener)
+
+        val udf1 = udf({ (x: Int, y: Int) => x + y })
+        val df = spark.range(0, 3).toDF("a")
+          .withColumn("b", udf1($"a", lit(10)))
+        df.cache()
+        df.write.saveAsTable("t")
+        assert(numTotalCachedHit == 1, "expected to be cached in saveAsTable")
+        df.write.insertInto("t")
+        assert(numTotalCachedHit == 2, "expected to be cached in insertInto")
+        df.write.save(path.getCanonicalPath)
+        assert(numTotalCachedHit == 3, "expected to be cached in save for native")
+      }
+    }
+  }
 }
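As a follow-up illustration of the pattern the new test relies on, here is a small, self-contained sketch, assuming Spark 2.3.x public APIs; the object name, table name "t", and the printed output are illustrative and not part of the PR. QueryExecutionListener.onSuccess receives the QueryExecution of each action, and qe.withCachedData is the analyzed plan after cached plans have been substituted with InMemoryRelation nodes, which is how the test detects that a write command reused the cache.

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.execution.QueryExecution
    import org.apache.spark.sql.util.QueryExecutionListener

    object CacheHitListenerExample {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[*]").appName("cache-hit").getOrCreate()

        val listener = new QueryExecutionListener {
          override def onFailure(funcName: String, qe: QueryExecution, e: Exception): Unit = {}
          override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
            // The plan after cache substitution; an InMemoryRelation under a write
            // command here means the cached data was reused for that write.
            println(s"$funcName:\n${qe.withCachedData}")
          }
        }
        spark.listenerManager.register(listener)

        val df = spark.range(0, 3).toDF("a")
        df.cache()
        df.count()                                    // materialize the cache
        df.write.mode("overwrite").saveAsTable("t")   // the write path should now see the InMemoryRelation

        spark.listenerManager.unregister(listener)
        spark.stop()
      }
    }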