From 9848ffb32408bde2ac8427718afe362e1c97e13c Mon Sep 17 00:00:00 2001
From: "R. Tyler Croy"
Date: Fri, 3 Jan 2025 02:10:20 +0000
Subject: [PATCH] feat: allow multiple Python threads to work with a single
 DeltaTable instance

This change introduces an internal Mutex inside RawDeltaTable which allows the
PyO3 bindings to share the same Python object between threads at the Python
layer. Without it, PyO3 raises `RuntimeError: Already borrowed` for any
function call that takes a mutable reference to `self`. The internal Mutex
ensures that every function signature can operate safely with only shared
`&self` references.

For most operations, which do not modify the underlying state, the Rust-level
Mutex is a simple passthrough. As far as I can tell, the critical sections
that acquire the lock to mutate state only run after I/O-bound operations have
completed, so I don't anticipate deadlocks or performance problems.

Some error-handling cleanup is still needed to make blending DeltaError with
the lock's PoisonError more ergonomic; for now there is a lot of ugly error
mapping.

Fixes #2958

Signed-off-by: R. Tyler Croy
Sponsored-by: Neuralink Corp.
---
 python/Cargo.toml             |   3 +-
 python/src/error.rs           |  10 +
 python/src/filesystem.rs      |   4 +-
 python/src/lib.rs             | 636 ++++++++++++++++++++--------------
 python/src/query.rs           |   8 +-
 python/tests/test_threaded.py |  59 ++++
 python/tests/test_writer.py   |  34 --
 7 files changed, 460 insertions(+), 294 deletions(-)
 create mode 100644 python/tests/test_threaded.py

diff --git a/python/Cargo.toml b/python/Cargo.toml
index 415dbc4b0b..e0a0f1f86b 100644
--- a/python/Cargo.toml
+++ b/python/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "deltalake-python"
-version = "0.23.1"
+version = "0.23.2"
 authors = ["Qingping Hou ", "Will Jones "]
 homepage = "https://github.com/delta-io/delta-rs"
 license = "Apache-2.0"
@@ -33,6 +33,7 @@ env_logger = "0"
 lazy_static = "1"
 regex = { workspace = true }
 thiserror = { workspace = true }
+tracing = { workspace = true }

 # runtime
 futures = { workspace = true }
diff --git a/python/src/error.rs b/python/src/error.rs
index b1d22fc7ca..0fbd7ffb12 100644
--- a/python/src/error.rs
+++ b/python/src/error.rs
@@ -2,6 +2,7 @@ use arrow_schema::ArrowError;
 use deltalake::datafusion::error::DataFusionError;
 use deltalake::protocol::ProtocolError;
 use deltalake::{errors::DeltaTableError, ObjectStoreError};
+use pyo3::exceptions::PyRuntimeError;
 use pyo3::exceptions::{
     PyException, PyFileNotFoundError, PyIOError, PyNotImplementedError, PyValueError,
 };
@@ -96,6 +97,14 @@ pub enum PythonError {
     Protocol(#[from] ProtocolError),
     #[error("Error in data fusion")]
     DataFusion(#[from] DataFusionError),
+    #[error("Lock acquisition error")]
+    ThreadingError(String),
+}
+
+impl<T> Into<PythonError> for std::sync::PoisonError<T> {
+    fn into(self) -> PythonError {
+        PythonError::ThreadingError(self.to_string())
+    }
 }

 impl From<PythonError> for pyo3::PyErr {
@@ -106,6 +115,7 @@ impl From<PythonError> for pyo3::PyErr {
             PythonError::Arrow(err) => arrow_to_py(err),
             PythonError::Protocol(err) => checkpoint_to_py(err),
             PythonError::DataFusion(err) => datafusion_to_py(err),
+            PythonError::ThreadingError(err) => PyRuntimeError::new_err(err),
         }
     }
 }
diff --git a/python/src/filesystem.rs b/python/src/filesystem.rs
index ee5261ab09..b88bff7877 100644
--- a/python/src/filesystem.rs
+++ b/python/src/filesystem.rs
@@ -71,11 +71,11 @@ impl DeltaFileSystemHandler {
         options: Option<HashMap<String, String>>,
         known_sizes: Option<HashMap<String, i64>>,
     ) -> PyResult<Self> {
-        let storage = table._table.object_store();
+        let storage = table.object_store()?;
         Ok(Self {
             inner: storage,
             config: FsConfig {
-                root_url: table._table.table_uri(),
+                root_url: table.with_table(|t| Ok(t.table_uri()))?,
                 options: options.unwrap_or_default(),
             },
             known_sizes,
diff --git a/python/src/lib.rs b/python/src/lib.rs
index 259b3d8a05..7c86aeec9e 100644
--- a/python/src/lib.rs
+++ b/python/src/lib.rs
@@ -10,7 +10,7 @@ use std::collections::{HashMap, HashSet};
 use std::ffi::CString;
 use std::future::IntoFuture;
 use std::str::FromStr;
-use std::sync::Arc;
+use std::sync::{Arc, Mutex};
 use std::time;
 use std::time::{SystemTime, UNIX_EPOCH};
@@ -31,6 +31,7 @@ use deltalake::errors::DeltaTableError;
 use deltalake::kernel::{
     scalars::ScalarExt, Action, Add, Invariant, LogicalFile, Remove, StructType, Transaction,
 };
+use deltalake::logstore::LogStoreRef;
 use deltalake::operations::add_column::AddColumnBuilder;
 use deltalake::operations::add_feature::AddTableFeatureBuilder;
 use deltalake::operations::collect_sendable_stream;
@@ -53,11 +54,13 @@ use deltalake::parquet::errors::ParquetError;
 use deltalake::parquet::file::properties::WriterProperties;
 use deltalake::partitions::PartitionFilter;
 use deltalake::protocol::{DeltaOperation, SaveMode};
-use deltalake::storage::IORuntime;
+use deltalake::storage::{IORuntime, ObjectStoreRef};
+use deltalake::table::state::DeltaTableState;
 use deltalake::DeltaTableBuilder;
 use deltalake::{DeltaOps, DeltaResult};
 use error::DeltaError;
 use futures::future::join_all;
+use tracing::log::*;

 use pyo3::exceptions::{PyRuntimeError, PyValueError};
 use pyo3::prelude::*;
@@ -82,7 +85,9 @@ enum PartitionFilterValue {

 #[pyclass(module = "deltalake._internal")]
 struct RawDeltaTable {
-    _table: deltalake::DeltaTable,
+    /// The internal reference to the table is guarded by a Mutex to allow for re-using the same
+    /// [DeltaTable] instance across multiple Python threads
+    _table: Arc<Mutex<deltalake::DeltaTable>>,
     // storing the config additionally on the table helps us make pickling work.
     _config: FsConfig,
 }
@@ -105,6 +110,52 @@ struct RawDeltaTableMetaData {

 type StringVec = Vec<String>;

+/// Segmented impl for RawDeltaTable to avoid these methods being exposed via the pymethods macro.
+///
+/// In essence all these functions should be considered internal to the Rust code and not exposed
+/// up to the Python layer
+impl RawDeltaTable {
+    /// Internal helper method which allows for acquiring the lock on the underlying
+    /// [deltalake::DeltaTable] and then executing the given function parameter with the guarded
+    /// reference
+    ///
+    /// This should only be used for read-only accesses (duh); callers that need to modify the
+    /// underlying instance should acquire the lock themselves.
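+    ///
+    /// For example, a read-only accessor in this impl can be written as:
+    ///
+    /// ```ignore
+    /// let version = self.with_table(|t| Ok(t.version()))?;
+    /// ```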
+    ///
+    fn with_table<T>(&self, func: impl Fn(&deltalake::DeltaTable) -> PyResult<T>) -> PyResult<T> {
+        match self._table.lock() {
+            Ok(table) => func(&table),
+            Err(e) => Err(PyRuntimeError::new_err(e.to_string())),
+        }
+    }
+
+    fn object_store(&self) -> PyResult<ObjectStoreRef> {
+        self.with_table(|t| Ok(t.object_store().clone()))
+    }
+
+    fn cloned_state(&self) -> PyResult<DeltaTableState> {
+        self.with_table(|t| {
+            t.snapshot()
+                .cloned()
+                .map_err(PythonError::from)
+                .map_err(PyErr::from)
+        })
+    }
+
+    fn log_store(&self) -> PyResult<LogStoreRef> {
+        self.with_table(|t| Ok(t.log_store().clone()))
+    }
+
+    fn set_state(&self, state: Option<DeltaTableState>) -> PyResult<()> {
+        let mut original = self
+            ._table
+            .lock()
+            .map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
+        (*original).state = state;
+        Ok(())
+    }
+}
+
 #[pymethods]
 impl RawDeltaTable {
     #[new]
@@ -138,7 +189,7 @@
         let table = rt().block_on(builder.load()).map_err(PythonError::from)?;

         Ok(RawDeltaTable {
-            _table: table,
+            _table: Arc::new(Mutex::new(table)),
             _config: FsConfig {
                 root_url: table_uri.into(),
                 options,
@@ -168,19 +219,24 @@
     }

     pub fn table_uri(&self) -> PyResult<String> {
-        Ok(self._table.table_uri())
+        self.with_table(|t| Ok(t.table_uri()))
     }

     pub fn version(&self) -> PyResult<i64> {
-        Ok(self._table.version())
+        self.with_table(|t| Ok(t.version()))
     }

     pub fn has_files(&self) -> PyResult<bool> {
-        Ok(self._table.config.require_files)
+        self.with_table(|t| Ok(t.config.require_files))
     }

     pub fn metadata(&self) -> PyResult<RawDeltaTableMetaData> {
-        let metadata = self._table.metadata().map_err(PythonError::from)?;
+        let metadata = self.with_table(|t| {
+            t.metadata()
+                .cloned()
+                .map_err(PythonError::from)
+                .map_err(PyErr::from)
+        })?;
         Ok(RawDeltaTableMetaData {
             id: metadata.id.clone(),
             name: metadata.name.clone(),
@@ -192,7 +248,12 @@
     }

     pub fn protocol_versions(&self) -> PyResult<(i32, i32, Option<StringVec>, Option<StringVec>)> {
-        let table_protocol = self._table.protocol().map_err(PythonError::from)?;
+        let table_protocol = self.with_table(|t| {
+            t.protocol()
+                .cloned()
+                .map_err(PythonError::from)
+                .map_err(PyErr::from)
+        })?;
         Ok((
             table_protocol.min_reader_version,
             table_protocol.min_writer_version,
@@ -225,67 +286,105 @@
     pub fn check_can_write_timestamp_ntz(&self, schema: PyArrowType<ArrowSchema>) -> PyResult<()> {
         let schema: StructType = (&schema.0).try_into().map_err(PythonError::from)?;
-        Ok(PROTOCOL
-            .check_can_write_timestamp_ntz(
-                self._table.snapshot().map_err(PythonError::from)?,
-                &schema,
-            )
-            .map_err(|e| DeltaTableError::Generic(e.to_string()))
-            .map_err(PythonError::from)?)
+        // Need to hold the lock to access the shared reference to &DeltaTableState
+        match self._table.lock() {
+            Ok(table) => Ok(PROTOCOL
+                .check_can_write_timestamp_ntz(
+                    table.snapshot().map_err(PythonError::from)?,
+                    &schema,
+                )
+                .map_err(|e| DeltaTableError::Generic(e.to_string()))
+                .map_err(PythonError::from)?),
+            Err(e) => Err(PyRuntimeError::new_err(e.to_string())),
+        }
     }

-    pub fn load_version(&mut self, py: Python, version: i64) -> PyResult<()> {
+    /// Load the internal [RawDeltaTable] with the table state from the specified `version`
+    ///
+    /// This will acquire the internal lock since it is a mutating operation!
+    pub fn load_version(&self, py: Python, version: i64) -> PyResult<()> {
         py.allow_threads(|| {
-            Ok(rt()
-                .block_on(self._table.load_version(version))
-                .map_err(PythonError::from)?)
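+            // The lock is acquired inside the async block and held while the
+            // load runs, so concurrent loads from other Python threads are
+            // serialized instead of tripping PyO3's borrow checking.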
+ Ok(rt().block_on(async { + let mut table = self + ._table + .lock() + .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; + (*table) + .load_version(version) + .await + .map_err(PythonError::from) + .map_err(PyErr::from) + })?) }) } - pub fn get_latest_version(&mut self, py: Python) -> PyResult { + /// Retrieve the latest version from the internally loaded table state + pub fn get_latest_version(&self, py: Python) -> PyResult { py.allow_threads(|| { - Ok(rt() - .block_on(self._table.get_latest_version()) - .map_err(PythonError::from)?) + Ok(rt().block_on(async { + match self._table.lock() { + Ok(table) => table + .get_latest_version() + .await + .map_err(PythonError::from) + .map_err(PyErr::from), + Err(e) => Err(PyRuntimeError::new_err(e.to_string())), + } + })?) }) } - pub fn get_earliest_version(&mut self, py: Python) -> PyResult { + pub fn get_earliest_version(&self, py: Python) -> PyResult { py.allow_threads(|| { - Ok(rt() - .block_on(self._table.get_earliest_version()) - .map_err(PythonError::from)?) + Ok(rt().block_on(async { + match self._table.lock() { + Ok(table) => table + .get_earliest_version() + .await + .map_err(PythonError::from) + .map_err(PyErr::from), + Err(e) => Err(PyRuntimeError::new_err(e.to_string())), + } + })?) }) } - pub fn get_num_index_cols(&mut self) -> PyResult { - Ok(self - ._table - .snapshot() - .map_err(PythonError::from)? - .config() - .num_indexed_cols()) + pub fn get_num_index_cols(&self) -> PyResult { + self.with_table(|t| { + Ok(t.snapshot() + .map_err(PythonError::from)? + .config() + .num_indexed_cols()) + }) } - pub fn get_stats_columns(&mut self) -> PyResult>> { - Ok(self - ._table - .snapshot() - .map_err(PythonError::from)? - .config() - .stats_columns() - .map(|v| v.iter().map(|v| v.to_string()).collect::>())) + pub fn get_stats_columns(&self) -> PyResult>> { + self.with_table(|t| { + Ok(t.snapshot() + .map_err(PythonError::from)? + .config() + .stats_columns() + .map(|v| v.iter().map(|s| s.to_string()).collect::>())) + }) } - pub fn load_with_datetime(&mut self, py: Python, ds: &str) -> PyResult<()> { + pub fn load_with_datetime(&self, py: Python, ds: &str) -> PyResult<()> { py.allow_threads(|| { let datetime = DateTime::::from(DateTime::::parse_from_rfc3339(ds).map_err( |err| PyValueError::new_err(format!("Failed to parse datetime string: {err}")), )?); - Ok(rt() - .block_on(self._table.load_with_datetime(datetime)) - .map_err(PythonError::from)?) + Ok(rt().block_on(async { + let mut table = self + ._table + .lock() + .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; + (*table) + .load_with_datetime(datetime) + .await + .map_err(PythonError::from) + .map_err(PyErr::from) + })?) }) } @@ -302,19 +401,23 @@ impl RawDeltaTable { if let Some(filters) = partition_filters { let filters = convert_partition_filters(filters).map_err(PythonError::from)?; Ok(self - ._table - .get_files_by_partitions(&filters) - .map_err(PythonError::from)? + .with_table(|t| { + t.get_files_by_partitions(&filters) + .map_err(PythonError::from) + .map_err(PyErr::from) + })? .into_iter() .map(|p| p.to_string()) .collect()) } else { - Ok(self - ._table - .get_files_iter() - .map_err(PythonError::from)? - .map(|f| f.to_string()) - .collect()) + match self._table.lock() { + Ok(table) => Ok(table + .get_files_iter() + .map_err(PythonError::from)? 
+ .map(|f| f.to_string()) + .collect()), + Err(e) => Err(PyRuntimeError::new_err(e.to_string())), + } } }) } @@ -324,36 +427,43 @@ impl RawDeltaTable { &self, partition_filters: Option>, ) -> PyResult> { - if !self._table.config.require_files { + if !self.with_table(|t| Ok(t.config.require_files))? { return Err(DeltaError::new_err("Table is initiated without files.")); } if let Some(filters) = partition_filters { let filters = convert_partition_filters(filters).map_err(PythonError::from)?; - Ok(self - ._table - .get_file_uris_by_partitions(&filters) - .map_err(PythonError::from)?) + self.with_table(|t| { + t.get_file_uris_by_partitions(&filters) + .map_err(PythonError::from) + .map_err(PyErr::from) + }) } else { - Ok(self - ._table - .get_file_uris() - .map_err(PythonError::from)? - .collect()) + self.with_table(|t| { + Ok(t.get_file_uris() + .map_err(PythonError::from) + .map_err(PyErr::from)? + .collect::>()) + }) } } #[getter] pub fn schema<'py>(&self, py: Python<'py>) -> PyResult> { - let schema: &StructType = self._table.get_schema().map_err(PythonError::from)?; - schema_to_pyobject(schema.to_owned(), py) + let schema: StructType = self.with_table(|t| { + t.get_schema() + .map_err(PythonError::from) + .map_err(PyErr::from) + .map(|s| s.to_owned()) + })?; + schema_to_pyobject(schema, py) } /// Run the Vacuum command on the Delta Table: list and delete files no longer referenced /// by the Delta table and are older than the retention threshold. #[pyo3(signature = (dry_run, retention_hours = None, enforce_retention_duration = true, commit_properties=None, post_commithook_properties=None))] pub fn vacuum( - &mut self, + &self, py: Python, dry_run: bool, retention_hours: Option, @@ -362,12 +472,17 @@ impl RawDeltaTable { post_commithook_properties: Option, ) -> PyResult> { let (table, metrics) = py.allow_threads(|| { - let mut cmd = VacuumBuilder::new( - self._table.log_store(), - self._table.snapshot().map_err(PythonError::from)?.clone(), - ) - .with_enforce_retention_duration(enforce_retention_duration) - .with_dry_run(dry_run); + let snapshot = match self._table.lock() { + Ok(table) => table + .snapshot() + .cloned() + .map_err(PythonError::from) + .map_err(PyErr::from), + Err(e) => Err(PyRuntimeError::new_err(e.to_string())), + }?; + let mut cmd = VacuumBuilder::new(self.log_store()?, snapshot) + .with_enforce_retention_duration(enforce_retention_duration) + .with_dry_run(dry_run); if let Some(retention_period) = retention_hours { cmd = cmd.with_retention_period(Duration::hours(retention_period as i64)); } @@ -377,9 +492,11 @@ impl RawDeltaTable { { cmd = cmd.with_commit_properties(commit_properties); } - rt().block_on(cmd.into_future()).map_err(PythonError::from) + rt().block_on(cmd.into_future()) + .map_err(PythonError::from) + .map_err(PyErr::from) })?; - self._table.state = table.state; + self.set_state(table.state)?; Ok(metrics.files_deleted) } @@ -387,7 +504,7 @@ impl RawDeltaTable { #[pyo3(signature = (updates, predicate=None, writer_properties=None, safe_cast = false, commit_properties = None, post_commithook_properties=None))] #[allow(clippy::too_many_arguments)] pub fn update( - &mut self, + &self, py: Python, updates: HashMap, predicate: Option, @@ -397,11 +514,8 @@ impl RawDeltaTable { post_commithook_properties: Option, ) -> PyResult { let (table, metrics) = py.allow_threads(|| { - let mut cmd = UpdateBuilder::new( - self._table.log_store(), - self._table.snapshot().map_err(PythonError::from)?.clone(), - ) - .with_safe_cast(safe_cast); + let mut cmd = 
UpdateBuilder::new(self.log_store()?, self.cloned_state()?) + .with_safe_cast(safe_cast); if let Some(writer_props) = writer_properties { cmd = cmd.with_writer_properties( @@ -423,9 +537,11 @@ impl RawDeltaTable { cmd = cmd.with_commit_properties(commit_properties); } - rt().block_on(cmd.into_future()).map_err(PythonError::from) + rt().block_on(cmd.into_future()) + .map_err(PythonError::from) + .map_err(PyErr::from) })?; - self._table.state = table.state; + self.set_state(table.state)?; Ok(serde_json::to_string(&metrics).unwrap()) } @@ -442,7 +558,7 @@ impl RawDeltaTable { ))] #[allow(clippy::too_many_arguments)] pub fn compact_optimize( - &mut self, + &self, py: Python, partition_filters: Option>, target_size: Option, @@ -453,11 +569,8 @@ impl RawDeltaTable { post_commithook_properties: Option, ) -> PyResult { let (table, metrics) = py.allow_threads(|| { - let mut cmd = OptimizeBuilder::new( - self._table.log_store(), - self._table.snapshot().map_err(PythonError::from)?.clone(), - ) - .with_max_concurrent_tasks(max_concurrent_tasks.unwrap_or_else(num_cpus::get)); + let mut cmd = OptimizeBuilder::new(self.log_store()?, self.cloned_state()?) + .with_max_concurrent_tasks(max_concurrent_tasks.unwrap_or_else(num_cpus::get)); if let Some(size) = target_size { cmd = cmd.with_target_size(size); } @@ -482,9 +595,11 @@ impl RawDeltaTable { .map_err(PythonError::from)?; cmd = cmd.with_filters(&converted_filters); - rt().block_on(cmd.into_future()).map_err(PythonError::from) + rt().block_on(cmd.into_future()) + .map_err(PythonError::from) + .map_err(PyErr::from) })?; - self._table.state = table.state; + self.set_state(table.state)?; Ok(serde_json::to_string(&metrics).unwrap()) } @@ -500,7 +615,7 @@ impl RawDeltaTable { commit_properties=None, post_commithook_properties=None))] pub fn z_order_optimize( - &mut self, + &self, py: Python, z_order_columns: Vec, partition_filters: Option>, @@ -513,13 +628,10 @@ impl RawDeltaTable { post_commithook_properties: Option, ) -> PyResult { let (table, metrics) = py.allow_threads(|| { - let mut cmd = OptimizeBuilder::new( - self._table.log_store(), - self._table.snapshot().map_err(PythonError::from)?.clone(), - ) - .with_max_concurrent_tasks(max_concurrent_tasks.unwrap_or_else(num_cpus::get)) - .with_max_spill_size(max_spill_size) - .with_type(OptimizeType::ZOrder(z_order_columns)); + let mut cmd = OptimizeBuilder::new(self.log_store()?, self.cloned_state()?) 
+ .with_max_concurrent_tasks(max_concurrent_tasks.unwrap_or_else(num_cpus::get)) + .with_max_spill_size(max_spill_size) + .with_type(OptimizeType::ZOrder(z_order_columns)); if let Some(size) = target_size { cmd = cmd.with_target_size(size); } @@ -544,25 +656,24 @@ impl RawDeltaTable { .map_err(PythonError::from)?; cmd = cmd.with_filters(&converted_filters); - rt().block_on(cmd.into_future()).map_err(PythonError::from) + rt().block_on(cmd.into_future()) + .map_err(PythonError::from) + .map_err(PyErr::from) })?; - self._table.state = table.state; + self.set_state(table.state)?; Ok(serde_json::to_string(&metrics).unwrap()) } #[pyo3(signature = (fields, commit_properties=None, post_commithook_properties=None))] pub fn add_columns( - &mut self, + &self, py: Python, fields: Vec, commit_properties: Option, post_commithook_properties: Option, ) -> PyResult<()> { let table = py.allow_threads(|| { - let mut cmd = AddColumnBuilder::new( - self._table.log_store(), - self._table.snapshot().map_err(PythonError::from)?.clone(), - ); + let mut cmd = AddColumnBuilder::new(self.log_store()?, self.cloned_state()?); let new_fields = fields .iter() @@ -576,15 +687,17 @@ impl RawDeltaTable { { cmd = cmd.with_commit_properties(commit_properties); } - rt().block_on(cmd.into_future()).map_err(PythonError::from) + rt().block_on(cmd.into_future()) + .map_err(PythonError::from) + .map_err(PyErr::from) })?; - self._table.state = table.state; + self.set_state(table.state)?; Ok(()) } #[pyo3(signature = (feature, allow_protocol_versions_increase, commit_properties=None, post_commithook_properties=None))] pub fn add_feature( - &mut self, + &self, py: Python, feature: Vec, allow_protocol_versions_increase: bool, @@ -592,37 +705,33 @@ impl RawDeltaTable { post_commithook_properties: Option, ) -> PyResult<()> { let table = py.allow_threads(|| { - let mut cmd = AddTableFeatureBuilder::new( - self._table.log_store(), - self._table.snapshot().map_err(PythonError::from)?.clone(), - ) - .with_features(feature) - .with_allow_protocol_versions_increase(allow_protocol_versions_increase); + let mut cmd = AddTableFeatureBuilder::new(self.log_store()?, self.cloned_state()?) 
+ .with_features(feature) + .with_allow_protocol_versions_increase(allow_protocol_versions_increase); if let Some(commit_properties) = maybe_create_commit_properties(commit_properties, post_commithook_properties) { cmd = cmd.with_commit_properties(commit_properties); } - rt().block_on(cmd.into_future()).map_err(PythonError::from) + rt().block_on(cmd.into_future()) + .map_err(PythonError::from) + .map_err(PyErr::from) })?; - self._table.state = table.state; + self.set_state(table.state)?; Ok(()) } #[pyo3(signature = (constraints, commit_properties=None, post_commithook_properties=None))] pub fn add_constraints( - &mut self, + &self, py: Python, constraints: HashMap, commit_properties: Option, post_commithook_properties: Option, ) -> PyResult<()> { let table = py.allow_threads(|| { - let mut cmd = ConstraintBuilder::new( - self._table.log_store(), - self._table.snapshot().map_err(PythonError::from)?.clone(), - ); + let mut cmd = ConstraintBuilder::new(self.log_store()?, self.cloned_state()?); for (col_name, expression) in constraints { cmd = cmd.with_constraint(col_name.clone(), expression.clone()); @@ -634,15 +743,17 @@ impl RawDeltaTable { cmd = cmd.with_commit_properties(commit_properties); } - rt().block_on(cmd.into_future()).map_err(PythonError::from) + rt().block_on(cmd.into_future()) + .map_err(PythonError::from) + .map_err(PyErr::from) })?; - self._table.state = table.state; + self.set_state(table.state)?; Ok(()) } #[pyo3(signature = (name, raise_if_not_exists, commit_properties=None, post_commithook_properties=None))] pub fn drop_constraints( - &mut self, + &self, py: Python, name: String, raise_if_not_exists: bool, @@ -650,12 +761,9 @@ impl RawDeltaTable { post_commithook_properties: Option, ) -> PyResult<()> { let table = py.allow_threads(|| { - let mut cmd = DropConstraintBuilder::new( - self._table.log_store(), - self._table.snapshot().map_err(PythonError::from)?.clone(), - ) - .with_constraint(name) - .with_raise_if_not_exists(raise_if_not_exists); + let mut cmd = DropConstraintBuilder::new(self.log_store()?, self.cloned_state()?) + .with_constraint(name) + .with_raise_if_not_exists(raise_if_not_exists); if let Some(commit_properties) = maybe_create_commit_properties(commit_properties, post_commithook_properties) @@ -663,16 +771,18 @@ impl RawDeltaTable { cmd = cmd.with_commit_properties(commit_properties); } - rt().block_on(cmd.into_future()).map_err(PythonError::from) + rt().block_on(cmd.into_future()) + .map_err(PythonError::from) + .map_err(PyErr::from) })?; - self._table.state = table.state; + self.set_state(table.state)?; Ok(()) } #[pyo3(signature = (starting_version = 0, ending_version = None, starting_timestamp = None, ending_timestamp = None, columns = None, allow_out_of_range = false))] #[allow(clippy::too_many_arguments)] pub fn load_cdf( - &mut self, + &self, py: Python, starting_version: i64, ending_version: Option, @@ -682,11 +792,8 @@ impl RawDeltaTable { allow_out_of_range: bool, ) -> PyResult> { let ctx = SessionContext::new(); - let mut cdf_read = CdfLoadBuilder::new( - self._table.log_store(), - self._table.snapshot().map_err(PythonError::from)?.clone(), - ) - .with_starting_version(starting_version); + let mut cdf_read = CdfLoadBuilder::new(self.log_store()?, self.cloned_state()?) 
+ .with_starting_version(starting_version); if let Some(ev) = ending_version { cdf_read = cdf_read.with_ending_version(ev); @@ -767,8 +874,8 @@ impl RawDeltaTable { ) -> PyResult { py.allow_threads(|| { Ok(PyMergeBuilder::new( - self._table.log_store(), - self._table.snapshot().map_err(PythonError::from)?.clone(), + self.log_store()?, + self.cloned_state()?, source.0, predicate, source_alias, @@ -786,13 +893,13 @@ impl RawDeltaTable { merge_builder ))] pub fn merge_execute( - &mut self, + &self, py: Python, merge_builder: &mut PyMergeBuilder, ) -> PyResult { py.allow_threads(|| { let (table, metrics) = merge_builder.execute().map_err(PythonError::from)?; - self._table.state = table.state; + self.set_state(table.state)?; Ok(metrics) }) } @@ -800,16 +907,13 @@ impl RawDeltaTable { // Run the restore command on the Delta Table: restore table to a given version or datetime #[pyo3(signature = (target, *, ignore_missing_files = false, protocol_downgrade_allowed = false, commit_properties=None))] pub fn restore( - &mut self, + &self, target: Option<&Bound<'_, PyAny>>, ignore_missing_files: bool, protocol_downgrade_allowed: bool, commit_properties: Option, ) -> PyResult { - let mut cmd = RestoreBuilder::new( - self._table.log_store(), - self._table.snapshot().map_err(PythonError::from)?.clone(), - ); + let mut cmd = RestoreBuilder::new(self.log_store()?, self.cloned_state()?); if let Some(val) = target { if let Ok(version) = val.extract::() { cmd = cmd.with_version_to_restore(version) @@ -833,32 +937,45 @@ impl RawDeltaTable { let (table, metrics) = rt() .block_on(cmd.into_future()) .map_err(PythonError::from)?; - self._table.state = table.state; + self.set_state(table.state)?; Ok(serde_json::to_string(&metrics).unwrap()) } /// Run the History command on the Delta Table: Returns provenance information, including the operation, user, and so on, for each write to a table. #[pyo3(signature = (limit=None))] - pub fn history(&mut self, limit: Option) -> PyResult> { - let history = rt() - .block_on(self._table.history(limit)) - .map_err(PythonError::from)?; + pub fn history(&self, limit: Option) -> PyResult> { + let history = rt().block_on(async { + match self._table.lock() { + Ok(table) => table + .history(limit) + .await + .map_err(PythonError::from) + .map_err(PyErr::from), + Err(e) => Err(PyRuntimeError::new_err(e.to_string())), + } + })?; Ok(history .iter() .map(|c| serde_json::to_string(c).unwrap()) .collect()) } - pub fn update_incremental(&mut self) -> PyResult<()> { + pub fn update_incremental(&self) -> PyResult<()> { #[allow(deprecated)] Ok(rt() - .block_on(self._table.update_incremental(None)) + .block_on(async { + let mut table = self + ._table + .lock() + .map_err(|e| DeltaTableError::Generic(e.to_string()))?; + (*table).update_incremental(None).await + }) .map_err(PythonError::from)?) } #[pyo3(signature = (schema, partition_filters=None))] pub fn dataset_partitions<'py>( - &mut self, + &self, py: Python<'py>, schema: PyArrowType, partition_filters: Option>, @@ -869,9 +986,7 @@ impl RawDeltaTable { )), None => None, }; - self._table - .snapshot() - .map_err(PythonError::from)? + self.cloned_state()? .log_data() .into_iter() .filter_map(|f| { @@ -894,17 +1009,21 @@ impl RawDeltaTable { partitions_filters: Option>, py: Python<'py>, ) -> PyResult> { - let column_names: HashSet<&str> = self - ._table - .get_schema() - .map_err(|_| DeltaProtocolError::new_err("table does not yet have a schema"))? 
- .fields() - .map(|field| field.name().as_str()) - .collect(); - let partition_columns: HashSet<&str> = self - ._table - .metadata() - .map_err(PythonError::from)? + let schema = self.with_table(|t| { + t.get_schema() + .cloned() + .map_err(PythonError::from) + .map_err(PyErr::from) + })?; + let metadata = self.with_table(|t| { + t.metadata() + .cloned() + .map_err(PythonError::from) + .map_err(PyErr::from) + })?; + let column_names: HashSet<&str> = + schema.fields().map(|field| field.name().as_str()).collect(); + let partition_columns: HashSet<&str> = metadata .partition_columns .iter() .map(|col| col.as_str()) @@ -946,10 +1065,8 @@ impl RawDeltaTable { let partition_columns: Vec<&str> = partition_columns.into_iter().collect(); - let adds = self - ._table - .snapshot() - .map_err(PythonError::from)? + let state = self.cloned_state()?; + let adds = state .get_active_add_actions_by_partitions(&converted_filters) .map_err(PythonError::from)? .collect::, _>>() @@ -984,7 +1101,7 @@ impl RawDeltaTable { #[allow(clippy::too_many_arguments)] #[pyo3(signature = (add_actions, mode, partition_by, schema, partitions_filters=None, commit_properties=None, post_commithook_properties=None))] fn create_write_transaction( - &mut self, + &self, py: Python, add_actions: Vec, mode: &str, @@ -999,7 +1116,12 @@ impl RawDeltaTable { let schema: StructType = (&schema.0).try_into().map_err(PythonError::from)?; - let existing_schema = self._table.get_schema().map_err(PythonError::from)?; + let existing_schema = self.with_table(|t| { + t.get_schema() + .cloned() + .map_err(PythonError::from) + .map_err(PyErr::from) + })?; let mut actions: Vec = add_actions .iter() @@ -1012,10 +1134,8 @@ impl RawDeltaTable { convert_partition_filters(partitions_filters.unwrap_or_default()) .map_err(PythonError::from)?; - let add_actions = self - ._table - .snapshot() - .map_err(PythonError::from)? 
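+                    // Work from a snapshot cloned out of the lock so the add
+                    // actions can be scanned without holding the table lock.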
+ let state = self.cloned_state()?; + let add_actions = state .get_active_add_actions_by_partitions(&converted_filters) .map_err(PythonError::from)?; @@ -1053,9 +1173,13 @@ impl RawDeltaTable { } // Update metadata with new schema - if &schema != existing_schema { - let mut metadata = - self._table.metadata().map_err(PythonError::from)?.clone(); + if schema != existing_schema { + let mut metadata = self.with_table(|t| { + t.metadata() + .cloned() + .map_err(PythonError::from) + .map_err(PyErr::from) + })?; metadata.schema_string = serde_json::to_string(&schema) .map_err(DeltaTableError::from) .map_err(PythonError::from)?; @@ -1064,7 +1188,7 @@ impl RawDeltaTable { } _ => { // This should be unreachable from Python - if &schema != existing_schema { + if schema != existing_schema { DeltaProtocolError::new_err("Cannot change schema except in overwrite."); } } @@ -1096,11 +1220,7 @@ impl RawDeltaTable { rt().block_on( CommitBuilder::from(properties) .with_actions(actions) - .build( - Some(self._table.snapshot().map_err(PythonError::from)?), - self._table.log_store(), - operation, - ) + .build(Some(&self.cloned_state()?), self.log_store()?, operation) .into_future(), ) .map_err(PythonError::from)?; @@ -1111,7 +1231,7 @@ impl RawDeltaTable { pub fn get_py_storage_backend(&self) -> PyResult { Ok(filesystem::DeltaFileSystemHandler { - inner: self._table.object_store(), + inner: self.object_store()?, config: self._config.clone(), known_sizes: None, }) @@ -1119,10 +1239,15 @@ impl RawDeltaTable { pub fn create_checkpoint(&self, py: Python) -> PyResult<()> { py.allow_threads(|| { - Ok::<_, pyo3::PyErr>( - rt().block_on(create_checkpoint(&self._table)) - .map_err(PythonError::from)?, - ) + rt().block_on(async { + match self._table.lock() { + Ok(table) => create_checkpoint(&table) + .await + .map_err(PythonError::from) + .map_err(PyErr::from), + Err(e) => Err(PyRuntimeError::new_err(e.to_string())), + } + }) })?; Ok(()) @@ -1130,10 +1255,15 @@ impl RawDeltaTable { pub fn cleanup_metadata(&self, py: Python) -> PyResult<()> { py.allow_threads(|| { - Ok::<_, pyo3::PyErr>( - rt().block_on(cleanup_metadata(&self._table)) - .map_err(PythonError::from)?, - ) + rt().block_on(async { + match self._table.lock() { + Ok(table) => cleanup_metadata(&table) + .await + .map_err(PythonError::from) + .map_err(PyErr::from), + Err(e) => Err(PyRuntimeError::new_err(e.to_string())), + } + }) })?; Ok(()) @@ -1143,29 +1273,29 @@ impl RawDeltaTable { if !self.has_files()? { return Err(DeltaError::new_err("Table is instantiated without files.")); } - Ok(PyArrowType( - self._table - .snapshot() + Ok(PyArrowType(self.with_table(|t| { + Ok(t.snapshot() .map_err(PythonError::from)? .add_actions_table(flatten) - .map_err(PythonError::from)?, - )) + .map_err(PythonError::from)?) + })?)) } pub fn get_add_file_sizes(&self) -> PyResult> { - Ok(self - ._table - .snapshot() - .map_err(PythonError::from)? - .eager_snapshot() - .files() - .map(|f| (f.path().to_string(), f.size())) - .collect::>()) + self.with_table(|t| { + Ok(t.snapshot() + .map_err(PythonError::from)? + .snapshot() + .files() + .map(|f| (f.path().to_string(), f.size())) + .collect::>()) + }) } + /// Run the delete command on the delta table: delete records following a predicate and return the delete metrics. 
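+    ///
+    /// Like the other mutating commands, this clones the snapshot up front, runs
+    /// the delete without holding the table lock, and re-acquires the lock only
+    /// to store the updated state.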
#[pyo3(signature = (predicate = None, writer_properties=None, commit_properties=None, post_commithook_properties=None))] pub fn delete( - &mut self, + &self, py: Python, predicate: Option, writer_properties: Option, @@ -1173,10 +1303,7 @@ impl RawDeltaTable { post_commithook_properties: Option, ) -> PyResult { let (table, metrics) = py.allow_threads(|| { - let mut cmd = DeleteBuilder::new( - self._table.log_store(), - self._table.snapshot().map_err(PythonError::from)?.clone(), - ); + let mut cmd = DeleteBuilder::new(self.log_store()?, self.cloned_state()?); if let Some(predicate) = predicate { cmd = cmd.with_predicate(predicate); } @@ -1191,25 +1318,24 @@ impl RawDeltaTable { cmd = cmd.with_commit_properties(commit_properties); } - rt().block_on(cmd.into_future()).map_err(PythonError::from) + rt().block_on(cmd.into_future()) + .map_err(PythonError::from) + .map_err(PyErr::from) })?; - self._table.state = table.state; + self.set_state(table.state)?; Ok(serde_json::to_string(&metrics).unwrap()) } #[pyo3(signature = (properties, raise_if_not_exists, commit_properties=None))] pub fn set_table_properties( - &mut self, + &self, properties: HashMap, raise_if_not_exists: bool, commit_properties: Option, ) -> PyResult<()> { - let mut cmd = SetTablePropertiesBuilder::new( - self._table.log_store(), - self._table.snapshot().map_err(PythonError::from)?.clone(), - ) - .with_properties(properties) - .with_raise_if_not_exists(raise_if_not_exists); + let mut cmd = SetTablePropertiesBuilder::new(self.log_store()?, self.cloned_state()?) + .with_properties(properties) + .with_raise_if_not_exists(raise_if_not_exists); if let Some(commit_properties) = maybe_create_commit_properties(commit_properties, None) { cmd = cmd.with_commit_properties(commit_properties); @@ -1218,7 +1344,7 @@ impl RawDeltaTable { let table = rt() .block_on(cmd.into_future()) .map_err(PythonError::from)?; - self._table.state = table.state; + self.set_state(table.state)?; Ok(()) } @@ -1226,16 +1352,13 @@ impl RawDeltaTable { /// have been deleted or are malformed #[pyo3(signature = (dry_run = true, commit_properties = None, post_commithook_properties=None))] pub fn repair( - &mut self, + &self, dry_run: bool, commit_properties: Option, post_commithook_properties: Option, ) -> PyResult { - let mut cmd = FileSystemCheckBuilder::new( - self._table.log_store(), - self._table.snapshot().map_err(PythonError::from)?.clone(), - ) - .with_dry_run(dry_run); + let mut cmd = FileSystemCheckBuilder::new(self.log_store()?, self.cloned_state()?) 
+ .with_dry_run(dry_run); if let Some(commit_properties) = maybe_create_commit_properties(commit_properties, post_commithook_properties) @@ -1246,16 +1369,23 @@ impl RawDeltaTable { let (table, metrics) = rt() .block_on(cmd.into_future()) .map_err(PythonError::from)?; - self._table.state = table.state; + self.set_state(table.state)?; Ok(serde_json::to_string(&metrics).unwrap()) } pub fn transaction_versions(&self) -> HashMap { - self._table - .get_app_transaction_version() - .into_iter() - .map(|(app_id, transaction)| (app_id, PyTransaction::from(transaction))) - .collect() + let version = self.with_table(|t| Ok(t.get_app_transaction_version())); + + match version { + Ok(version) => version + .into_iter() + .map(|(app_id, transaction)| (app_id, PyTransaction::from(transaction))) + .collect(), + Err(e) => { + warn!("Cannot fetch transaction version due to {e:?}"); + HashMap::default() + } + } } fn __datafusion_table_provider__<'py>( @@ -1264,7 +1394,8 @@ impl RawDeltaTable { ) -> PyResult> { let name = CString::new("datafusion_table_provider").unwrap(); - let provider = FFI_TableProvider::new(Arc::new(self._table.clone()), false); + let table = self.with_table(|t| Ok(Arc::new(t.clone())))?; + let provider = FFI_TableProvider::new(table, false); PyCapsule::new_bound(py, provider, Some(name.clone())) } @@ -1786,7 +1917,7 @@ fn write_to_deltalake( let options = storage_options.clone().unwrap_or_default(); let table = if let Some(table) = table { - DeltaOps(table._table.clone()) + table.with_table(|t| Ok(DeltaOps::from(t.clone())))? } else { rt().block_on(DeltaOps::try_from_uri_with_storage_options( &table_uri, options, @@ -2015,17 +2146,16 @@ fn get_num_idx_cols_and_stats_columns( table: Option<&RawDeltaTable>, configuration: Option>>, ) -> PyResult<(i32, Option>)> { - let config = table - .as_ref() - .map(|table| table._table.snapshot()) - .transpose() - .map_err(PythonError::from)? - .map(|snapshot| snapshot.table_config()); - - Ok(deltalake::operations::get_num_idx_cols_and_stats_columns( - config, - configuration.unwrap_or_default(), - )) + match table.as_ref() { + Some(table) => Ok(deltalake::operations::get_num_idx_cols_and_stats_columns( + Some(table.cloned_state()?.table_config()), + configuration.unwrap_or_default(), + )), + None => Ok(deltalake::operations::get_num_idx_cols_and_stats_columns( + None, + configuration.unwrap_or_default(), + )), + } } #[pyclass(name = "DeltaDataChecker", module = "deltalake._internal")] diff --git a/python/src/query.rs b/python/src/query.rs index 55889c567f..ce09cf46a8 100644 --- a/python/src/query.rs +++ b/python/src/query.rs @@ -35,15 +35,15 @@ impl PyQueryBuilder { /// Once called, the provided `delta_table` will be referencable in SQL queries so long as /// another table of the same name is not registered over it. 
     pub fn register(&self, table_name: &str, delta_table: &RawDeltaTable) -> PyResult<()> {
-        let snapshot = delta_table._table.snapshot().map_err(PythonError::from)?;
-        let log_store = delta_table._table.log_store();
+        let snapshot = delta_table.cloned_state()?;
+        let log_store = delta_table.log_store()?;

         let scan_config = DeltaScanConfigBuilder::default()
-            .build(snapshot)
+            .build(&snapshot)
             .map_err(PythonError::from)?;

         let provider = Arc::new(
-            DeltaTableProvider::try_new(snapshot.clone(), log_store, scan_config)
+            DeltaTableProvider::try_new(snapshot, log_store, scan_config)
                 .map_err(PythonError::from)?,
         );
diff --git a/python/tests/test_threaded.py b/python/tests/test_threaded.py
new file mode 100644
index 0000000000..b9a4a97908
--- /dev/null
+++ b/python/tests/test_threaded.py
@@ -0,0 +1,59 @@
+#!/usr/bin/env python3
+#
+# This file contains all the tests of the deltalake python package in a
+# multithreaded environment
+
+import pathlib
+import threading
+from concurrent.futures import ThreadPoolExecutor
+
+import pyarrow as pa
+import pytest
+
+from deltalake import DeltaTable, write_deltalake
+from deltalake.exceptions import CommitFailedError
+
+
+def test_concurrency(existing_table: DeltaTable, sample_data: pa.Table):
+    exception = None
+
+    def comp():
+        nonlocal exception
+        dt = DeltaTable(existing_table.table_uri)
+        for _ in range(5):
+            # We should always be able to get a consistent table state
+            data = DeltaTable(dt.table_uri).to_pyarrow_table()
+            # If two overwrites delete the same file and then add their own
+            # concurrently, then this will fail.
+            assert data.num_rows == sample_data.num_rows
+            try:
+                write_deltalake(dt.table_uri, sample_data, mode="overwrite")
+            except Exception as e:
+                exception = e
+
+    n_threads = 2
+    threads = [threading.Thread(target=comp) for _ in range(n_threads)]
+
+    for t in threads:
+        t.start()
+    for t in threads:
+        t.join()
+
+    assert isinstance(exception, CommitFailedError)
+    assert (
+        "a concurrent transaction deleted the same data your transaction deletes"
+        in str(exception)
+    )
+
+
+@pytest.mark.polars
+def test_multithreaded_write(sample_data: pa.Table, tmp_path: pathlib.Path):
+    import polars as pl
+
+    table = pl.DataFrame({"a": [1, 2, 3]}).to_arrow()
+    write_deltalake(tmp_path, table, mode="overwrite")
+
+    dt = DeltaTable(tmp_path)
+
+    with ThreadPoolExecutor() as exe:
+        list(exe.map(lambda _: write_deltalake(dt, table, mode="append"), range(5)))
diff --git a/python/tests/test_writer.py b/python/tests/test_writer.py
index e58cb545ea..11320743e0 100644
--- a/python/tests/test_writer.py
+++ b/python/tests/test_writer.py
@@ -3,7 +3,6 @@
 import os
 import pathlib
 import random
-import threading
 from datetime import date, datetime, timezone
 from decimal import Decimal
 from math import inf
@@ -18,7 +17,6 @@
 from deltalake import DeltaTable, Schema, write_deltalake
 from deltalake.exceptions import (
-    CommitFailedError,
     DeltaError,
     DeltaProtocolError,
     SchemaMismatchError,
@@ -1432,38 +1430,6 @@ def test_uint_arrow_types(tmp_path: pathlib.Path):
     write_deltalake(tmp_path, table)


-def test_concurrency(existing_table: DeltaTable, sample_data: pa.Table):
-    exception = None
-
-    def comp():
-        nonlocal exception
-        dt = DeltaTable(existing_table.table_uri)
-        for _ in range(5):
-            # We should always be able to get a consistent table state
-            data = DeltaTable(dt.table_uri).to_pyarrow_table()
-            # If two overwrites delete the same file and then add their own
-            # concurrently, then this will fail.
- assert data.num_rows == sample_data.num_rows - try: - write_deltalake(dt.table_uri, sample_data, mode="overwrite") - except Exception as e: - exception = e - - n_threads = 2 - threads = [threading.Thread(target=comp) for _ in range(n_threads)] - - for t in threads: - t.start() - for t in threads: - t.join() - - assert isinstance(exception, CommitFailedError) - assert ( - "a concurrent transaction deleted the same data your transaction deletes" - in str(exception) - ) - - def test_issue_1651_roundtrip_timestamp(tmp_path: pathlib.Path): data = pa.table( {