diff --git a/python/deltalake/_internal.pyi b/python/deltalake/_internal.pyi
index 7a5d84a0cb..d1fc616359 100644
--- a/python/deltalake/_internal.pyi
+++ b/python/deltalake/_internal.pyi
@@ -88,6 +88,13 @@ class RawDeltaTable:
     def create_checkpoint(self) -> None: ...
     def get_add_actions(self, flatten: bool) -> pa.RecordBatch: ...
     def delete(self, predicate: Optional[str]) -> str: ...
+    def update(
+        self,
+        updates: Dict[str, str],
+        predicate: Optional[str],
+        writer_properties: Optional[Dict[str, int]],
+        safe_cast: bool = False,
+    ) -> str: ...
     def get_active_partitions(
         self, partitions_filters: Optional[FilterType] = None
     ) -> Any: ...
diff --git a/python/deltalake/table.py b/python/deltalake/table.py
index aa91f33b15..80a48f619e 100644
--- a/python/deltalake/table.py
+++ b/python/deltalake/table.py
@@ -473,6 +473,52 @@ def vacuum(
             enforce_retention_duration,
         )
 
+    def update(
+        self,
+        updates: Dict[str, str],
+        predicate: Optional[str] = None,
+        writer_properties: Optional[Dict[str, int]] = None,
+        error_on_type_mismatch: bool = True,
+    ) -> Dict[str, Any]:
+        """UPDATE records in the Delta Table that match an optional predicate.
+
+        :param updates: a mapping of column name to update SQL expression.
+        :param predicate: a logical expression, defaults to None
+        :param writer_properties: pass writer properties to the Rust parquet writer, see options
+            https://arrow.apache.org/rust/parquet/file/properties/struct.WriterProperties.html.
+            Only the fields data_page_size_limit, dictionary_page_size_limit,
+            data_page_row_count_limit, write_batch_size and max_row_group_size are supported.
+        :param error_on_type_mismatch: specify if update will return an error if data types are
+            mismatching, defaults to True
+        :return: the metrics from update
+
+        Examples:
+
+        Update some row values with a SQL predicate. This is equivalent to
+        ``UPDATE table SET deleted = true WHERE id = '5'``:
+
+        >>> from deltalake import DeltaTable
+        >>> dt = DeltaTable("tmp")
+        >>> dt.update(
+        ...     predicate="id = '5'",
+        ...     updates={"deleted": "True"},
+        ... )
+
+        Update all row values. This is equivalent to
+        ``UPDATE table SET id = concat(id, '_old')``:
+
+        >>> from deltalake import DeltaTable
+        >>> dt = DeltaTable("tmp")
+        >>> dt.update(updates={"id": "concat(id, '_old')"})
+
+        """
+
+        metrics = self._table.update(
+            updates, predicate, writer_properties, safe_cast=not error_on_type_mismatch
+        )
+        return json.loads(metrics)
+
     @property
     def optimize(
         self,
diff --git a/python/docs/source/usage.rst b/python/docs/source/usage.rst
index 6e7e806f4d..569546dce8 100644
--- a/python/docs/source/usage.rst
+++ b/python/docs/source/usage.rst
@@ -484,6 +484,53 @@ the data passed to it differs from the existing table's schema. If you wish to
 alter the schema as part of an overwrite pass in ``overwrite_schema=True``.
 
+Updating Delta Tables
+---------------------
+
+.. py:currentmodule:: deltalake.table
+
+Row values in an existing delta table can be updated with the :meth:`DeltaTable.update` command. An
+update dictionary has to be passed, where the key is the column you wish to update and the value is
+a SQL expression in string format.
+
+Update all the rows for the column "processed" to the value True:
+
+.. code-block:: python
+
+    >>> import pandas as pd
+    >>> from deltalake import write_deltalake, DeltaTable
+    >>> df = pd.DataFrame({'x': [1, 2, 3], 'processed': [False, False, False]})
+    >>> write_deltalake('path/to/table', df)
+    >>> dt = DeltaTable('path/to/table')
+    >>> dt.update({"processed": "True"})
+    >>> dt.to_pandas()
+       x  processed
+    0  1       True
+    1  2       True
+    2  3       True
+
+.. note::
+    :meth:`DeltaTable.update` predicates and updates are all in string format. The predicates
+    and expressions are parsed into Apache DataFusion expressions.
+
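+Because each value in ``updates`` is parsed as an expression rather than a literal, string
+constants need an extra level of quoting. A minimal sketch (the ``id`` column is illustrative
+and not part of the table above):
+
+.. code-block:: python
+
+    >>> dt.update({"id": "'new_id'"})            # set a string literal
+    >>> dt.update({"id": "concat(id, '_old')"})  # expressions may reference columns
+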
+Apply a soft deletion based on a predicate, so update all the rows for the column "deleted" to
+the value True where x = 3:
+
+.. code-block:: python
+
+    >>> import pandas as pd
+    >>> from deltalake import write_deltalake, DeltaTable
+    >>> df = pd.DataFrame({'x': [1, 2, 3], 'deleted': [False, False, False]})
+    >>> write_deltalake('path/to/table', df)
+    >>> dt = DeltaTable('path/to/table')
+    >>> dt.update(
+    ...     updates={"deleted": "True"},
+    ...     predicate='x = 3',
+    ... )
+    >>> dt.to_pandas()
+       x  deleted
+    0  1    False
+    1  2    False
+    2  3     True
+
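+The Rust parquet writer can be tuned by passing ``writer_properties``. A sketch, assuming the
+defaults are fine for everything else (only the keys listed in :meth:`DeltaTable.update`, such
+as ``max_row_group_size``, are supported):
+
+.. code-block:: python
+
+    >>> dt.update(
+    ...     updates={"deleted": "True"},
+    ...     writer_properties={"max_row_group_size": 4096},
+    ... )
+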
 
 Overwriting a partition
 ~~~~~~~~~~~~~~~~~~~~~~~
diff --git a/python/src/lib.rs b/python/src/lib.rs
index 6118b4d81d..72be5f63c3 100644
--- a/python/src/lib.rs
+++ b/python/src/lib.rs
@@ -25,7 +25,9 @@ use deltalake::operations::delete::DeleteBuilder;
 use deltalake::operations::optimize::{OptimizeBuilder, OptimizeType};
 use deltalake::operations::restore::RestoreBuilder;
 use deltalake::operations::transaction::commit;
+use deltalake::operations::update::UpdateBuilder;
 use deltalake::operations::vacuum::VacuumBuilder;
+use deltalake::parquet::file::properties::WriterProperties;
 use deltalake::partitions::PartitionFilter;
 use deltalake::protocol::{
     self, Action, ColumnCountStat, ColumnValueStat, DeltaOperation, SaveMode, Stats,
@@ -280,6 +282,59 @@ impl RawDeltaTable {
         Ok(metrics.files_deleted)
     }
 
+    /// Run the UPDATE command on the Delta Table
+    #[pyo3(signature = (updates, predicate=None, writer_properties=None, safe_cast = false))]
+    pub fn update(
+        &mut self,
+        updates: HashMap<String, String>,
+        predicate: Option<String>,
+        writer_properties: Option<HashMap<String, usize>>,
+        safe_cast: bool,
+    ) -> PyResult<String> {
+        let mut cmd = UpdateBuilder::new(self._table.object_store(), self._table.state.clone())
+            .with_safe_cast(safe_cast);
+
+        if let Some(writer_props) = writer_properties {
+            let mut properties = WriterProperties::builder();
+            let data_page_size_limit = writer_props.get("data_page_size_limit");
+            let dictionary_page_size_limit = writer_props.get("dictionary_page_size_limit");
+            let data_page_row_count_limit = writer_props.get("data_page_row_count_limit");
+            let write_batch_size = writer_props.get("write_batch_size");
+            let max_row_group_size = writer_props.get("max_row_group_size");
+
+            if let Some(data_page_size) = data_page_size_limit {
+                properties = properties.set_data_page_size_limit(*data_page_size);
+            }
+            if let Some(dictionary_page_size) = dictionary_page_size_limit {
+                properties = properties.set_dictionary_page_size_limit(*dictionary_page_size);
+            }
+            if let Some(data_page_row_count) = data_page_row_count_limit {
+                properties = properties.set_data_page_row_count_limit(*data_page_row_count);
+            }
+            if let Some(batch_size) = write_batch_size {
+                properties = properties.set_write_batch_size(*batch_size);
+            }
+            if let Some(row_group_size) = max_row_group_size {
+                properties = properties.set_max_row_group_size(*row_group_size);
+            }
+            cmd = cmd.with_writer_properties(properties.build());
+        }
+
+        for (col_name, expression) in updates {
+            cmd = cmd.with_update(col_name.clone(), expression.clone());
+        }
+
+        if let Some(update_predicate) = predicate {
+            cmd = cmd.with_predicate(update_predicate);
+        }
+
+        let (table, metrics) = rt()?
+            .block_on(cmd.into_future())
+            .map_err(PythonError::from)?;
+        self._table.state = table.state;
+        Ok(serde_json::to_string(&metrics).unwrap())
+    }
+
     /// Run the optimize command on the Delta Table: merge small files into a large file by bin-packing.
     #[pyo3(signature = (partition_filters = None, target_size = None, max_concurrent_tasks = None, min_commit_interval = None))]
     pub fn compact_optimize(
diff --git a/python/tests/test_update.py b/python/tests/test_update.py
new file mode 100644
index 0000000000..defdd1a396
--- /dev/null
+++ b/python/tests/test_update.py
@@ -0,0 +1,109 @@
+import pathlib
+
+import pyarrow as pa
+import pytest
+
+from deltalake import DeltaTable, write_deltalake
+
+
+@pytest.fixture()
+def sample_table():
+    nrows = 5
+    return pa.table(
+        {
+            "id": pa.array(["1", "2", "3", "4", "5"]),
+            "price": pa.array(list(range(nrows)), pa.int64()),
+            "sold": pa.array(list(range(nrows)), pa.int64()),
+            "deleted": pa.array([False] * nrows),
+        }
+    )
+
+
+def test_update_with_predicate(tmp_path: pathlib.Path, sample_table: pa.Table):
+    write_deltalake(tmp_path, sample_table, mode="append")
+
+    dt = DeltaTable(tmp_path)
+
+    nrows = 5
+    expected = pa.table(
+        {
+            "id": pa.array(["1", "2", "3", "4", "5"]),
+            "price": pa.array(list(range(nrows)), pa.int64()),
+            "sold": pa.array(list(range(nrows)), pa.int64()),
+            "deleted": pa.array([False, False, False, False, True]),
+        }
+    )
+
+    dt.update(updates={"deleted": "True"}, predicate="price > 3")
+
+    result = dt.to_pyarrow_table()
+    last_action = dt.history(1)[0]
+
+    assert last_action["operation"] == "UPDATE"
+    assert result == expected
+
+
+def test_update_wo_predicate(tmp_path: pathlib.Path, sample_table: pa.Table):
+    write_deltalake(tmp_path, sample_table, mode="append")
+
+    dt = DeltaTable(tmp_path)
+
+    nrows = 5
+    expected = pa.table(
+        {
+            "id": pa.array(["1", "2", "3", "4", "5"]),
+            "price": pa.array(list(range(nrows)), pa.int64()),
+            "sold": pa.array(list(range(nrows)), pa.int64()),
+            "deleted": pa.array([True] * 5),
+        }
+    )
+
+    dt.update(updates={"deleted": "True"})
+
+    result = dt.to_pyarrow_table()
+    last_action = dt.history(1)[0]
+
+    assert last_action["operation"] == "UPDATE"
+    assert result == expected
+
+
+def test_update_wrong_types_cast(tmp_path: pathlib.Path, sample_table: pa.Table):
+    write_deltalake(tmp_path, sample_table, mode="append")
+
+    dt = DeltaTable(tmp_path)
+
+    with pytest.raises(Exception) as excinfo:
+        dt.update(updates={"deleted": "'hello_world'"})
+
+    assert (
+        str(excinfo.value)
+        == "Cast error: Cannot cast value 'hello_world' to value of Boolean type"
+    )
+
+
+def test_update_wo_predicate_multiple_updates(
+    tmp_path: pathlib.Path, sample_table: pa.Table
+):
+    write_deltalake(tmp_path, sample_table, mode="append")
+
+    dt = DeltaTable(tmp_path)
+
+    expected = pa.table(
+        {
+            "id": pa.array(["1_1", "2_1", "3_1", "4_1", "5_1"]),
+            "price": pa.array([0, 1, 2, 3, 4], pa.int64()),
+            "sold": pa.array([0, 1, 4, 9, 16], pa.int64()),
+            "deleted": pa.array([True] * 5),
+        }
+    )
+
+    dt.update(
+        updates={"deleted": "True", "sold": "sold * price", "id": "concat(id,'_1')"},
+        error_on_type_mismatch=False,
+    )
+
+    result = dt.to_pyarrow_table()
+    last_action = dt.history(1)[0]
+
+    assert last_action["operation"] == "UPDATE"
+    assert result == expected
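+
+
+# Sketch of an additional check (not part of the original patch): it assumes the
+# writer_properties keys documented on DeltaTable.update, such as max_row_group_size,
+# are accepted and do not change the result of the update itself.
+def test_update_with_writer_properties(tmp_path: pathlib.Path, sample_table: pa.Table):
+    write_deltalake(tmp_path, sample_table, mode="append")
+
+    dt = DeltaTable(tmp_path)
+
+    dt.update(
+        updates={"deleted": "True"},
+        writer_properties={"max_row_group_size": 4096},
+    )
+
+    result = dt.to_pyarrow_table()
+    assert result["deleted"].to_pylist() == [True] * 5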
diff --git a/rust/src/operations/update.rs b/rust/src/operations/update.rs
index 3891c04fd9..a8b2820c33 100644
--- a/rust/src/operations/update.rs
+++ b/rust/src/operations/update.rs
@@ -40,6 +40,7 @@ use datafusion_physical_expr::{
 };
 use futures::future::BoxFuture;
 use parquet::file::properties::WriterProperties;
+use serde::Serialize;
 use serde_json::{Map, Value};
 
 use crate::{
@@ -80,7 +81,7 @@ pub struct UpdateBuilder {
     safe_cast: bool,
 }
 
-#[derive(Default)]
+#[derive(Default, Serialize)]
 /// Metrics collected during the Update operation
 pub struct UpdateMetrics {
     /// Number of files added.
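
Note: with UpdateMetrics deriving Serialize, the binding returns the metrics as a JSON string
and DeltaTable.update decodes it into a plain dict. A minimal sketch of consuming the return
value; the num_updated_rows key is an assumption based on the UpdateMetrics struct, whose
fields are truncated in this diff:

    from deltalake import DeltaTable

    dt = DeltaTable("path/to/table")
    metrics = dt.update(updates={"deleted": "True"})
    # Key name assumed from UpdateMetrics; may differ in the released API.
    print(metrics.get("num_updated_rows"))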