diff --git a/Cargo.toml b/Cargo.toml index cc8eead83c..be3ab70677 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -48,12 +48,13 @@ parquet = { version = "53" } datafusion = { version = "43" } datafusion-expr = { version = "43" } datafusion-common = { version = "43" } -datafusion-proto = { version = "43" } -datafusion-sql = { version = "43" } -datafusion-physical-expr = { version = "43" } -datafusion-physical-plan = { version = "43" } +datafusion-ffi = { version = "43" } datafusion-functions = { version = "43" } datafusion-functions-aggregate = { version = "43" } +datafusion-physical-expr = { version = "43" } +datafusion-physical-plan = { version = "43" } +datafusion-proto = { version = "43" } +datafusion-sql = { version = "43" } # serde serde = { version = "1.0.194", features = ["derive"] } diff --git a/python/Cargo.toml b/python/Cargo.toml index 8f18b8fb2e..2f85cf45fb 100644 --- a/python/Cargo.toml +++ b/python/Cargo.toml @@ -20,6 +20,9 @@ delta_kernel.workspace = true # arrow arrow-schema = { workspace = true, features = ["serde"] } +# datafusion +datafusion-ffi = { workspace = true } + # serde serde = { workspace = true } serde_json = { workspace = true } diff --git a/python/deltalake/_internal.pyi b/python/deltalake/_internal.pyi index 66b5dc8f8f..026d84d08d 100644 --- a/python/deltalake/_internal.pyi +++ b/python/deltalake/_internal.pyi @@ -222,6 +222,7 @@ class RawDeltaTable: ending_timestamp: Optional[str] = None, ) -> pyarrow.RecordBatchReader: ... def transaction_versions(self) -> Dict[str, Transaction]: ... + def __datafusion_table_provider__(self) -> Any: ... def rust_core_version() -> str: ... def write_new_deltalake( diff --git a/python/deltalake/table.py b/python/deltalake/table.py index e54a1c3f8c..247a2b9527 100644 --- a/python/deltalake/table.py +++ b/python/deltalake/table.py @@ -1423,6 +1423,43 @@ def repair( def transaction_versions(self) -> Dict[str, Transaction]: return self._table.transaction_versions() + def __datafusion_table_provider__(self) -> Any: + """Return the DataFusion table provider PyCapsule interface. + + To support DataFusion features such as push down filtering, this function will return a PyCapsule + interface that conforms to the FFI Table Provider required by DataFusion. From an end user perspective + you should not need to call this function directly. Instead you can use ``register_table_provider`` in + the DataFusion SessionContext. + + Returns: + A PyCapsule DataFusion TableProvider interface. + + Example: + ```python + from deltalake import DeltaTable, write_deltalake + from datafusion import SessionContext + import pyarrow as pa + data = pa.table({"x": [1, 2, 3], "y": [4, 5, 6]}) + write_deltalake("tmp", data) + dt = DeltaTable("tmp") + ctx = SessionContext() + ctx.register_table_provider("test", table) + ctx.table("test").show() + ``` + Results in + ``` + DataFrame() + +----+----+----+ + | c3 | c1 | c2 | + +----+----+----+ + | 4 | 6 | a | + | 6 | 5 | b | + | 5 | 4 | c | + +----+----+----+ + ``` + """ + return self._table.__datafusion_table_provider__() + class TableMerger: """API for various table `MERGE` commands.""" diff --git a/python/src/lib.rs b/python/src/lib.rs index 361f094f38..5a9a3ce237 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -6,13 +6,16 @@ mod schema; mod utils; use std::collections::{HashMap, HashSet}; +use std::ffi::CString; use std::future::IntoFuture; use std::str::FromStr; +use std::sync::Arc; use std::time; use std::time::{SystemTime, UNIX_EPOCH}; use arrow::pyarrow::PyArrowType; use chrono::{DateTime, Duration, FixedOffset, Utc}; +use datafusion_ffi::table_provider::FFI_TableProvider; use delta_kernel::expressions::Scalar; use delta_kernel::schema::StructField; use deltalake::arrow::compute::concat_batches; @@ -58,7 +61,7 @@ use futures::future::join_all; use pyo3::exceptions::{PyRuntimeError, PyValueError}; use pyo3::prelude::*; use pyo3::pybacked::PyBackedStr; -use pyo3::types::{PyDict, PyFrozenSet}; +use pyo3::types::{PyCapsule, PyDict, PyFrozenSet}; use serde_json::{Map, Value}; use crate::error::DeltaProtocolError; @@ -1240,6 +1243,17 @@ impl RawDeltaTable { .map(|(app_id, transaction)| (app_id, PyTransaction::from(transaction))) .collect() } + + fn __datafusion_table_provider__<'py>( + &self, + py: Python<'py>, + ) -> PyResult> { + let name = CString::new("datafusion_table_provider").unwrap(); + + let provider = FFI_TableProvider::new(Arc::new(self._table.clone()), false); + + PyCapsule::new_bound(py, provider, Some(name.clone())) + } } fn set_post_commithook_properties(