szehon-ho committed Sep 13, 2024
1 parent 4b89b74 commit 603f6c9
Showing 2 changed files with 40 additions and 13 deletions.
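At a glance, the commit replaces the builder-style UpdateWriter API (update(...).where(...).execute()) with eager update overloads on Dataset that return Unit. A before/after sketch reconstructed from the removed and added lines below; the table name "source" is illustrative:

// Before this commit: lazy builder, nothing runs until execute()
spark.table("source")
  .update(Map("salary" -> lit(200)))
  .where($"salary" === 100)
  .execute()

// After this commit: one eager call; the condition is an optional second argument
spark.table("source")
  .update(Map("salary" -> lit(200)), $"salary" === 100)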
48 changes: 39 additions & 9 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -60,6 +60,7 @@ import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2ScanRelation, FileTable}
import org.apache.spark.sql.execution.python.EvaluatePython
import org.apache.spark.sql.execution.stat.StatFunctions
+import org.apache.spark.sql.functions.expr
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.streaming.DataStreamWriter
import org.apache.spark.sql.types._
@@ -4136,27 +4137,41 @@ class Dataset[T] private[sql](
    new MergeIntoWriter[T](table, this, condition)
  }

+  /**
+   * Update rows in a table.
+   *
+   * Scala Example:
+   * {{{
+   *   spark.table("source")
+   *     .update(Map("salary" -> lit(200)))
+   * }}}
+   *
+   * @param assignments A Map of column names to Column expressions representing the updates
+   *                    to be applied.
+   * @since 4.0.0
+   */
+  def update(assignments: Map[String, Column]): Unit = {
+    updateInternal(assignments)
+  }

  /**
   * Update rows in a table that match a condition.
   *
   * Scala Example:
   * {{{
-   *   spark.table("source").update(Map("salary" -> lit(200)))
-   *     .where($"salary" === 100)
-   *     .execute()
+   *   spark.table("source")
+   *     .update(Map("salary" -> lit(200)), $"salary" === 100)
   * }}}
   *
   * @param assignments A Map of column names to Column expressions representing the updates
   *                    to be applied.
+   * @param condition the update condition
   * @since 4.0.0
   */
-  def update(assignments: Map[String, Column]): UpdateWriter[T] = {
-    if (isStreaming) {
-      throw new AnalysisException(
-        errorClass = "CALL_ON_STREAMING_DATASET_UNSUPPORTED",
-        messageParameters = Map("methodName" -> toSQLId("update")))
-    }
-    new UpdateWriter[T](this, assignments)
+  def update(assignments: Map[String, Column], condition: Column): Unit = {
+    updateInternal(assignments, Some(condition))
  }

/**
@@ -4535,4 +4550,19 @@ class Dataset[T] private[sql](
  private[sql] def toArrowBatchRdd: RDD[Array[Byte]] = {
    toArrowBatchRdd(queryExecution.executedPlan)
  }

+  private def updateInternal(assignments: Map[String, Column],
+      condition: Option[Column] = None): Unit = {
+    if (isStreaming) {
+      throw new AnalysisException(
+        errorClass = "CALL_ON_STREAMING_DATASET_UNSUPPORTED",
+        messageParameters = Map("methodName" -> toSQLId("update")))
+    }
+    val update = UpdateTable(
+      logicalPlan,
+      assignments.map(x => Assignment(expr(x._1).expr, x._2.expr)).toSeq,
+      condition.map(_.expr))
+    val qe = sparkSession.sessionState.executePlan(update)
+    qe.assertCommandExecuted()
+  }
}
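For orientation: both public overloads funnel into updateInternal, which builds an UpdateTable logical plan, executes it via sessionState.executePlan, and asserts the command ran eagerly. Note that each assignment key is parsed with functions.expr (expr(x._1).expr), so keys are column expressions rather than bare strings. A hedged equivalence sketch, assuming a v2 table "source" that supports row-level updates; the two calls should resolve to the same UpdateTable plan, since UpdateTable is also what the parser produces for a SQL UPDATE:

// DataFrame API form added by this commit
spark.table("source").update(Map("salary" -> lit(200)), $"pk" === 1)

// Equivalent SQL form (assuming the catalog supports UPDATE on "source")
spark.sql("UPDATE source SET salary = 200 WHERE pk = 1")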
5 changes: 1 addition & 4 deletions UpdateDataFrameSuite.scala
@@ -32,9 +32,7 @@ class UpdateDataFrameSuite extends RowLevelOperationSuiteBase {
|""".stripMargin)

    spark.table(tableNameAsString)
-      .update(Map("salary" -> lit(-1)))
-      .where($"pk" >= 2)
-      .execute()
+      .update(Map("salary" -> lit(-1)), $"pk" >= 2)

    checkAnswer(
      sql(s"SELECT * FROM $tableNameAsString"),
@@ -53,7 +51,6 @@ class UpdateDataFrameSuite extends RowLevelOperationSuiteBase {

    spark.table(tableNameAsString)
      .update(Map("dep" -> lit("software")))
-      .execute()

    checkAnswer(
      sql(s"SELECT * FROM $tableNameAsString"),
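One code path the updated tests above do not exercise: updateInternal rejects streaming Datasets before building any plan. A minimal sketch of that guarded behavior, using the built-in rate source for illustration (streamingDf is a hypothetical name):

// Expected to throw AnalysisException with error class
// CALL_ON_STREAMING_DATASET_UNSUPPORTED, per the guard in updateInternal.
val streamingDf = spark.readStream.format("rate").load()
streamingDf.update(Map("value" -> lit(0L)))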
