From 57795da9d9cc86a460a5888713adfb3d0584b4cc Mon Sep 17 00:00:00 2001 From: Abhi Agarwal Date: Tue, 11 Jun 2024 16:19:03 -0400 Subject: [PATCH] chore: bump to datafusion 39, arrow 52, pyo3 0.21 (#2581) # Description Updates the arrow and datafusion dependencies to 52 and 39(-rc1) respectively. This is necessary for updating pyo3. While most changes with trivial, some required big rewrites. Namely, the logic for the Updates operation had to be rewritten (and simplified) to accommodate some new sanity checks inside datafusion: (https://github.com/apache/datafusion/pull/10088). Depends on delta-kernel having its arrow and object-store version bumped as well. This PR doesn't include any major changes for pyo3, I'll open a separate PR depending on this PR. # Related Issue(s) # Documentation --------- Co-authored-by: R. Tyler Croy --- Cargo.toml | 44 +++--- crates/aws/src/lib.rs | 6 +- crates/aws/src/storage.rs | 23 ++-- crates/aws/tests/common.rs | 6 +- crates/aws/tests/repair_s3_rename_test.rs | 22 ++- crates/azure/tests/integration.rs | 5 +- crates/benchmarks/src/bin/merge.rs | 3 +- crates/core/src/delta_datafusion/cdf/scan.rs | 2 +- .../src/delta_datafusion/cdf/scan_utils.rs | 1 + crates/core/src/delta_datafusion/expr.rs | 66 +++------ .../delta_datafusion/find_files/logical.rs | 13 +- .../src/delta_datafusion/find_files/mod.rs | 9 +- .../delta_datafusion/find_files/physical.rs | 2 +- crates/core/src/delta_datafusion/logical.rs | 15 ++- crates/core/src/delta_datafusion/mod.rs | 123 ++++------------- crates/core/src/delta_datafusion/physical.rs | 4 +- .../core/src/kernel/snapshot/log_segment.rs | 21 ++- crates/core/src/operations/cdc.rs | 4 +- crates/core/src/operations/delete.rs | 6 +- crates/core/src/operations/merge/barrier.rs | 25 ++-- crates/core/src/operations/merge/mod.rs | 10 +- crates/core/src/operations/transaction/mod.rs | 11 +- .../core/src/operations/transaction/state.rs | 12 +- crates/core/src/operations/update.rs | 125 +++++++----------- crates/core/src/operations/write.rs | 8 +- crates/core/src/operations/writer.rs | 2 +- crates/core/src/protocol/checkpoints.rs | 6 +- crates/core/src/storage/file.rs | 22 +-- crates/core/src/storage/retry_ext.rs | 5 +- crates/core/src/writer/json.rs | 4 +- crates/core/src/writer/record_batch.rs | 4 +- crates/core/tests/fs_common/mod.rs | 21 ++- crates/gcp/src/storage.rs | 21 ++- crates/gcp/tests/context.rs | 2 +- crates/mount/src/file.rs | 22 +-- crates/sql/src/logical_plan.rs | 44 ++++-- crates/sql/src/planner.rs | 19 +-- crates/test/src/lib.rs | 2 +- python/Cargo.toml | 6 +- python/src/filesystem.rs | 68 ++++------ python/src/lib.rs | 2 - python/tests/test_update.py | 2 +- python/tests/test_writer.py | 13 ++ 43 files changed, 368 insertions(+), 463 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 4cc6bad1b9..094eed958a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,33 +26,33 @@ debug = true debug = "line-tables-only" [workspace.dependencies] -delta_kernel = { version = "0.1" } +delta_kernel = { version = "0.1.1" } # delta_kernel = { path = "../delta-kernel-rs/kernel" } # arrow -arrow = { version = "51" } -arrow-arith = { version = "51" } -arrow-array = { version = "51", features = ["chrono-tz"] } -arrow-buffer = { version = "51" } -arrow-cast = { version = "51" } -arrow-ipc = { version = "51" } -arrow-json = { version = "51" } -arrow-ord = { version = "51" } -arrow-row = { version = "51" } -arrow-schema = { version = "51" } -arrow-select = { version = "51" } -object_store = { version = "0.9" } -parquet = { version = "51" } +arrow = { version = "52" } +arrow-arith = { version = "52" } +arrow-array = { version = "52", features = ["chrono-tz"] } +arrow-buffer = { version = "52" } +arrow-cast = { version = "52" } +arrow-ipc = { version = "52" } +arrow-json = { version = "52" } +arrow-ord = { version = "52" } +arrow-row = { version = "52" } +arrow-schema = { version = "52" } +arrow-select = { version = "52" } +object_store = { version = "0.10.1" } +parquet = { version = "52" } # datafusion -datafusion = { version = "37.1" } -datafusion-expr = { version = "37.1" } -datafusion-common = { version = "37.1" } -datafusion-proto = { version = "37.1" } -datafusion-sql = { version = "37.1" } -datafusion-physical-expr = { version = "37.1" } -datafusion-functions = { version = "37.1" } -datafusion-functions-array = { version = "37.1" } +datafusion = { version = "39" } +datafusion-expr = { version = "39" } +datafusion-common = { version = "39" } +datafusion-proto = { version = "39" } +datafusion-sql = { version = "39" } +datafusion-physical-expr = { version = "39" } +datafusion-functions = { version = "39" } +datafusion-functions-array = { version = "39" } # serde serde = { version = "1.0.194", features = ["derive"] } diff --git a/crates/aws/src/lib.rs b/crates/aws/src/lib.rs index acca602951..a0a99c01f0 100644 --- a/crates/aws/src/lib.rs +++ b/crates/aws/src/lib.rs @@ -189,15 +189,15 @@ impl DynamoDbLockClient { if dynamodb_override_endpoint exists/AWS_ENDPOINT_URL_DYNAMODB is specified by user use dynamodb_override_endpoint to create dynamodb client */ - let dynamodb_sdk_config = match dynamodb_override_endpoint { + + match dynamodb_override_endpoint { Some(dynamodb_endpoint_url) => sdk_config .to_owned() .to_builder() .endpoint_url(dynamodb_endpoint_url) .build(), None => sdk_config.to_owned(), - }; - dynamodb_sdk_config + } } /// Create the lock table where DynamoDb stores the commit information for all delta tables. diff --git a/crates/aws/src/storage.rs b/crates/aws/src/storage.rs index 7485b21761..4625bb6be9 100644 --- a/crates/aws/src/storage.rs +++ b/crates/aws/src/storage.rs @@ -5,8 +5,8 @@ use aws_config::provider_config::ProviderConfig; use aws_config::{Region, SdkConfig}; use bytes::Bytes; use deltalake_core::storage::object_store::{ - aws::AmazonS3ConfigKey, parse_url_opts, GetOptions, GetResult, ListResult, MultipartId, - ObjectMeta, ObjectStore, PutOptions, PutResult, Result as ObjectStoreResult, + aws::AmazonS3ConfigKey, parse_url_opts, GetOptions, GetResult, ListResult, ObjectMeta, + ObjectStore, PutOptions, PutResult, Result as ObjectStoreResult, }; use deltalake_core::storage::{ limit_store_handler, str_is_truthy, ObjectStoreFactory, ObjectStoreRef, StorageOptions, @@ -14,13 +14,13 @@ use deltalake_core::storage::{ use deltalake_core::{DeltaResult, ObjectStoreError, Path}; use futures::stream::BoxStream; use futures::Future; +use object_store::{MultipartUpload, PutMultipartOpts, PutPayload}; use std::collections::HashMap; use std::fmt::Debug; use std::ops::Range; use std::str::FromStr; use std::sync::Arc; use std::time::Duration; -use tokio::io::AsyncWrite; use url::Url; use crate::errors::DynamoDbConfigError; @@ -334,14 +334,14 @@ impl std::fmt::Debug for S3StorageBackend { #[async_trait::async_trait] impl ObjectStore for S3StorageBackend { - async fn put(&self, location: &Path, bytes: Bytes) -> ObjectStoreResult { + async fn put(&self, location: &Path, bytes: PutPayload) -> ObjectStoreResult { self.inner.put(location, bytes).await } async fn put_opts( &self, location: &Path, - bytes: Bytes, + bytes: PutPayload, options: PutOptions, ) -> ObjectStoreResult { self.inner.put_opts(location, bytes, options).await @@ -402,19 +402,16 @@ impl ObjectStore for S3StorageBackend { } } - async fn put_multipart( - &self, - location: &Path, - ) -> ObjectStoreResult<(MultipartId, Box)> { + async fn put_multipart(&self, location: &Path) -> ObjectStoreResult> { self.inner.put_multipart(location).await } - async fn abort_multipart( + async fn put_multipart_opts( &self, location: &Path, - multipart_id: &MultipartId, - ) -> ObjectStoreResult<()> { - self.inner.abort_multipart(location, multipart_id).await + options: PutMultipartOpts, + ) -> ObjectStoreResult> { + self.inner.put_multipart_opts(location, options).await } } diff --git a/crates/aws/tests/common.rs b/crates/aws/tests/common.rs index 01aa505b1b..dfa2a9cd51 100644 --- a/crates/aws/tests/common.rs +++ b/crates/aws/tests/common.rs @@ -87,7 +87,7 @@ impl S3Integration { "dynamodb", "create-table", "--table-name", - &table_name, + table_name, "--provisioned-throughput", "ReadCapacityUnits=1,WriteCapacityUnits=1", "--attribute-definitions", @@ -112,7 +112,7 @@ impl S3Integration { } fn wait_for_table(table_name: &str) -> std::io::Result<()> { - let args = ["dynamodb", "describe-table", "--table-name", &table_name]; + let args = ["dynamodb", "describe-table", "--table-name", table_name]; loop { let output = Command::new("aws") .args(args) @@ -145,7 +145,7 @@ impl S3Integration { fn delete_dynamodb_table(table_name: &str) -> std::io::Result { let mut child = Command::new("aws") - .args(["dynamodb", "delete-table", "--table-name", &table_name]) + .args(["dynamodb", "delete-table", "--table-name", table_name]) .stdout(Stdio::null()) .spawn() .expect("aws command is installed"); diff --git a/crates/aws/tests/repair_s3_rename_test.rs b/crates/aws/tests/repair_s3_rename_test.rs index 68d8727ebe..d9e19de7b7 100644 --- a/crates/aws/tests/repair_s3_rename_test.rs +++ b/crates/aws/tests/repair_s3_rename_test.rs @@ -9,6 +9,7 @@ use deltalake_core::storage::object_store::{ use deltalake_core::{DeltaTableBuilder, ObjectStore, Path}; use deltalake_test::utils::IntegrationContext; use futures::stream::BoxStream; +use object_store::{MultipartUpload, PutMultipartOpts, PutPayload}; use serial_test::serial; use std::ops::Range; use std::sync::{Arc, Mutex}; @@ -60,8 +61,8 @@ async fn run_repair_test_case(path: &str, pause_copy: bool) -> Result<(), Object }; let (s3_2, _) = create_s3_backend(&context, "w2", None, None); - s3_1.put(&src1, Bytes::from("test1")).await.unwrap(); - s3_2.put(&src2, Bytes::from("test2")).await.unwrap(); + s3_1.put(&src1, Bytes::from("test1").into()).await.unwrap(); + s3_2.put(&src2, Bytes::from("test2").into()).await.unwrap(); let rename1 = rename(s3_1, &src1, &dst1); // to ensure that first one is started actually first @@ -166,14 +167,14 @@ impl ObjectStore for DelayedObjectStore { self.delete(from).await } - async fn put(&self, location: &Path, bytes: Bytes) -> ObjectStoreResult { + async fn put(&self, location: &Path, bytes: PutPayload) -> ObjectStoreResult { self.inner.put(location, bytes).await } async fn put_opts( &self, location: &Path, - bytes: Bytes, + bytes: PutPayload, options: PutOptions, ) -> ObjectStoreResult { self.inner.put_opts(location, bytes, options).await @@ -227,19 +228,16 @@ impl ObjectStore for DelayedObjectStore { self.inner.rename_if_not_exists(from, to).await } - async fn put_multipart( - &self, - location: &Path, - ) -> ObjectStoreResult<(MultipartId, Box)> { + async fn put_multipart(&self, location: &Path) -> ObjectStoreResult> { self.inner.put_multipart(location).await } - async fn abort_multipart( + async fn put_multipart_opts( &self, location: &Path, - multipart_id: &MultipartId, - ) -> ObjectStoreResult<()> { - self.inner.abort_multipart(location, multipart_id).await + options: PutMultipartOpts, + ) -> ObjectStoreResult> { + self.inner.put_multipart_opts(location, options).await } } diff --git a/crates/azure/tests/integration.rs b/crates/azure/tests/integration.rs index 5230462c92..3ffaa00cc5 100644 --- a/crates/azure/tests/integration.rs +++ b/crates/azure/tests/integration.rs @@ -75,7 +75,10 @@ async fn read_write_test_onelake(context: &IntegrationContext, path: &Path) -> T let expected = Bytes::from_static(b"test world from delta-rs on friday"); - delta_store.put(path, expected.clone()).await.unwrap(); + delta_store + .put(path, expected.clone().into()) + .await + .unwrap(); let fetched = delta_store.get(path).await.unwrap().bytes().await.unwrap(); assert_eq!(expected, fetched); diff --git a/crates/benchmarks/src/bin/merge.rs b/crates/benchmarks/src/bin/merge.rs index bb178a192d..2465e23d94 100644 --- a/crates/benchmarks/src/bin/merge.rs +++ b/crates/benchmarks/src/bin/merge.rs @@ -7,9 +7,10 @@ use arrow::datatypes::Schema as ArrowSchema; use arrow_array::{RecordBatch, StringArray, UInt32Array}; use chrono::Duration; use clap::{command, Args, Parser, Subcommand}; +use datafusion::functions::expr_fn::random; use datafusion::{datasource::MemTable, prelude::DataFrame}; use datafusion_common::DataFusionError; -use datafusion_expr::{cast, col, lit, random}; +use datafusion_expr::{cast, col, lit}; use deltalake_core::protocol::SaveMode; use deltalake_core::{ arrow::{ diff --git a/crates/core/src/delta_datafusion/cdf/scan.rs b/crates/core/src/delta_datafusion/cdf/scan.rs index 1f9c9f52b3..ea34855b77 100644 --- a/crates/core/src/delta_datafusion/cdf/scan.rs +++ b/crates/core/src/delta_datafusion/cdf/scan.rs @@ -38,7 +38,7 @@ impl ExecutionPlan for DeltaCdfScan { self.plan.properties() } - fn children(&self) -> Vec> { + fn children(&self) -> Vec<&Arc> { vec![] } diff --git a/crates/core/src/delta_datafusion/cdf/scan_utils.rs b/crates/core/src/delta_datafusion/cdf/scan_utils.rs index 434afa4f74..79d7a2359e 100644 --- a/crates/core/src/delta_datafusion/cdf/scan_utils.rs +++ b/crates/core/src/delta_datafusion/cdf/scan_utils.rs @@ -83,6 +83,7 @@ pub fn create_partition_values( partition_values: new_part_values.clone(), extensions: None, range: None, + statistics: None, }; file_groups.entry(new_part_values).or_default().push(part); diff --git a/crates/core/src/delta_datafusion/expr.rs b/crates/core/src/delta_datafusion/expr.rs index 41e6a84b4f..2d48f7873e 100644 --- a/crates/core/src/delta_datafusion/expr.rs +++ b/crates/core/src/delta_datafusion/expr.rs @@ -32,7 +32,7 @@ use datafusion::execution::context::SessionState; use datafusion_common::Result as DFResult; use datafusion_common::{config::ConfigOptions, DFSchema, Result, ScalarValue, TableReference}; use datafusion_expr::{ - expr::InList, AggregateUDF, Between, BinaryExpr, Cast, Expr, GetIndexedField, Like, TableSource, + expr::InList, AggregateUDF, Between, BinaryExpr, Cast, Expr, Like, TableSource, }; use datafusion_sql::planner::{ContextProvider, SqlToRel}; use datafusion_sql::sqlparser::ast::escape_quoted_string; @@ -49,7 +49,7 @@ pub(crate) struct DeltaContextProvider<'a> { } impl<'a> ContextProvider for DeltaContextProvider<'a> { - fn get_table_provider(&self, _name: TableReference) -> DFResult> { + fn get_table_source(&self, _name: TableReference) -> DFResult> { unimplemented!() } @@ -73,19 +73,15 @@ impl<'a> ContextProvider for DeltaContextProvider<'a> { self.state.window_functions().get(name).cloned() } - fn get_table_source(&self, _name: TableReference) -> DFResult> { - unimplemented!() - } - - fn udfs_names(&self) -> Vec { + fn udf_names(&self) -> Vec { unimplemented!() } - fn udafs_names(&self) -> Vec { + fn udaf_names(&self) -> Vec { unimplemented!() } - fn udwfs_names(&self) -> Vec { + fn udwf_names(&self) -> Vec { unimplemented!() } } @@ -198,7 +194,7 @@ impl<'a> Display for SqlFormat<'a> { Expr::IsNotFalse(expr) => write!(f, "{} IS NOT FALSE", SqlFormat { expr }), Expr::IsNotUnknown(expr) => write!(f, "{} IS NOT UNKNOWN", SqlFormat { expr }), Expr::BinaryExpr(expr) => write!(f, "{}", BinaryExprFormat { expr }), - Expr::ScalarFunction(func) => fmt_function(f, func.func_def.name(), false, &func.args), + Expr::ScalarFunction(func) => fmt_function(f, func.func.name(), false, &func.args), Expr::Cast(Cast { expr, data_type }) => { write!(f, "arrow_cast({}, '{}')", SqlFormat { expr }, data_type) } @@ -276,33 +272,6 @@ impl<'a> Display for SqlFormat<'a> { write!(f, "{expr} IN ({})", expr_vec_fmt!(list)) } } - Expr::GetIndexedField(GetIndexedField { expr, field }) => match field { - datafusion_expr::GetFieldAccess::NamedStructField { name } => { - write!( - f, - "{}[{}]", - SqlFormat { expr }, - ScalarValueFormat { scalar: name } - ) - } - datafusion_expr::GetFieldAccess::ListIndex { key } => { - write!(f, "{}[{}]", SqlFormat { expr }, SqlFormat { expr: key }) - } - datafusion_expr::GetFieldAccess::ListRange { - start, - stop, - stride, - } => { - write!( - f, - "{expr}[{start}:{stop}:{stride}]", - expr = SqlFormat { expr }, - start = SqlFormat { expr: start }, - stop = SqlFormat { expr: stop }, - stride = SqlFormat { expr: stride } - ) - } - }, _ => Err(fmt::Error), } } @@ -428,11 +397,12 @@ mod test { use datafusion::prelude::SessionContext; use datafusion_common::{Column, ScalarValue, ToDFSchema}; use datafusion_expr::expr::ScalarFunction; - use datafusion_expr::{ - col, lit, substring, BinaryExpr, Cast, Expr, ExprSchemable, ScalarFunctionDefinition, - }; + use datafusion_expr::{col, lit, BinaryExpr, Cast, Expr, ExprSchemable}; use datafusion_functions::core::arrow_cast; + use datafusion_functions::core::expr_ext::FieldAccessor; use datafusion_functions::encoding::expr_fn::decode; + use datafusion_functions::expr_fn::substring; + use datafusion_functions_array::expr_ext::{IndexAccessor, SliceAccessor}; use datafusion_functions_array::expr_fn::cardinality; use crate::delta_datafusion::{DataFusionMixins, DeltaSessionContext}; @@ -564,7 +534,7 @@ mod test { override_expected_expr: Some( datafusion_expr::Expr::ScalarFunction( ScalarFunction { - func_def: ScalarFunctionDefinition::UDF(arrow_cast()), + func: arrow_cast(), args: vec![ lit(ScalarValue::Int64(Some(1))), lit(ScalarValue::Utf8(Some("Int32".into()))) @@ -671,7 +641,7 @@ mod test { datafusion_expr::Expr::BinaryExpr(BinaryExpr { left: Box::new(datafusion_expr::Expr::ScalarFunction( ScalarFunction { - func_def: ScalarFunctionDefinition::UDF(arrow_cast()), + func: arrow_cast(), args: vec![ col("value"), lit(ScalarValue::Utf8(Some("Utf8".into()))) @@ -685,19 +655,19 @@ mod test { }, simple!( col("_struct").field("a").eq(lit(20_i64)), - "_struct['a'] = 20".to_string() + "get_field(_struct, 'a') = 20".to_string() ), simple!( col("_struct").field("nested").field("b").eq(lit(20_i64)), - "_struct['nested']['b'] = 20".to_string() + "get_field(get_field(_struct, 'nested'), 'b') = 20".to_string() ), simple!( col("_list").index(lit(1_i64)).eq(lit(20_i64)), - "_list[1] = 20".to_string() + "array_element(_list, 1) = 20".to_string() ), simple!( cardinality(col("_list").range(col("value"), lit(10_i64))), - "cardinality(_list[value:10:1])".to_string() + "cardinality(array_slice(_list, value, 10))".to_string() ), ParseTest { expr: col("_timestamp_ntz").gt(lit(ScalarValue::TimestampMicrosecond(Some(1262304000000000), None))), @@ -705,7 +675,7 @@ mod test { override_expected_expr: Some(col("_timestamp_ntz").gt( datafusion_expr::Expr::ScalarFunction( ScalarFunction { - func_def: ScalarFunctionDefinition::UDF(arrow_cast()), + func: arrow_cast(), args: vec![ lit(ScalarValue::Utf8(Some("2010-01-01T00:00:00.000000".into()))), lit(ScalarValue::Utf8(Some("Timestamp(Microsecond, None)".into()))) @@ -723,7 +693,7 @@ mod test { override_expected_expr: Some(col("_timestamp").gt( datafusion_expr::Expr::ScalarFunction( ScalarFunction { - func_def: ScalarFunctionDefinition::UDF(arrow_cast()), + func: arrow_cast(), args: vec![ lit(ScalarValue::Utf8(Some("2010-01-01T00:00:00.000000".into()))), lit(ScalarValue::Utf8(Some("Timestamp(Microsecond, Some(\"UTC\"))".into()))) diff --git a/crates/core/src/delta_datafusion/find_files/logical.rs b/crates/core/src/delta_datafusion/find_files/logical.rs index 6234cbe5c2..4dd4a3b5da 100644 --- a/crates/core/src/delta_datafusion/find_files/logical.rs +++ b/crates/core/src/delta_datafusion/find_files/logical.rs @@ -92,7 +92,16 @@ impl UserDefinedLogicalNodeCore for FindFilesNode { ) } - fn from_template(&self, _exprs: &[Expr], _inputs: &[LogicalPlan]) -> Self { - self.clone() + fn from_template(&self, exprs: &[Expr], inputs: &[LogicalPlan]) -> Self { + self.with_exprs_and_inputs(exprs.to_vec(), inputs.to_vec()) + .unwrap() + } + + fn with_exprs_and_inputs( + &self, + _exprs: Vec, + _inputs: Vec, + ) -> datafusion_common::Result { + Ok(self.clone()) } } diff --git a/crates/core/src/delta_datafusion/find_files/mod.rs b/crates/core/src/delta_datafusion/find_files/mod.rs index 2e8d26dee3..347925f31f 100644 --- a/crates/core/src/delta_datafusion/find_files/mod.rs +++ b/crates/core/src/delta_datafusion/find_files/mod.rs @@ -28,8 +28,6 @@ use crate::logstore::LogStoreRef; use crate::table::state::DeltaTableState; use crate::DeltaTableError; -use super::create_physical_expr_fix; - pub mod logical; pub mod physical; @@ -161,11 +159,8 @@ async fn scan_table_by_files( let input_schema = scan.logical_schema.as_ref().to_owned(); let input_dfschema = input_schema.clone().try_into()?; - let predicate_expr = create_physical_expr_fix( - Expr::IsTrue(Box::new(expression.clone())), - &input_dfschema, - state.execution_props(), - )?; + let predicate_expr = + state.create_physical_expr(Expr::IsTrue(Box::new(expression.clone())), &input_dfschema)?; let filter: Arc = Arc::new(FilterExec::try_new(predicate_expr, scan.clone())?); diff --git a/crates/core/src/delta_datafusion/find_files/physical.rs b/crates/core/src/delta_datafusion/find_files/physical.rs index eb09d2d94b..eb295912a2 100644 --- a/crates/core/src/delta_datafusion/find_files/physical.rs +++ b/crates/core/src/delta_datafusion/find_files/physical.rs @@ -97,7 +97,7 @@ impl ExecutionPlan for FindFilesExec { &self.plan_properties } - fn children(&self) -> Vec> { + fn children(&self) -> Vec<&Arc> { vec![] } diff --git a/crates/core/src/delta_datafusion/logical.rs b/crates/core/src/delta_datafusion/logical.rs index 52ee1194f4..2ce435b5b6 100644 --- a/crates/core/src/delta_datafusion/logical.rs +++ b/crates/core/src/delta_datafusion/logical.rs @@ -52,13 +52,22 @@ impl UserDefinedLogicalNodeCore for MetricObserver { fn from_template( &self, - _exprs: &[datafusion_expr::Expr], + exprs: &[datafusion_expr::Expr], inputs: &[datafusion_expr::LogicalPlan], ) -> Self { - MetricObserver { + self.with_exprs_and_inputs(exprs.to_vec(), inputs.to_vec()) + .unwrap() + } + + fn with_exprs_and_inputs( + &self, + _exprs: Vec, + inputs: Vec, + ) -> datafusion_common::Result { + Ok(MetricObserver { id: self.id.clone(), input: inputs[0].clone(), enable_pushdown: self.enable_pushdown, - } + }) } } diff --git a/crates/core/src/delta_datafusion/mod.rs b/crates/core/src/delta_datafusion/mod.rs index 9c87411973..97f42497ea 100644 --- a/crates/core/src/delta_datafusion/mod.rs +++ b/crates/core/src/delta_datafusion/mod.rs @@ -56,21 +56,14 @@ use datafusion::physical_plan::{ Statistics, }; use datafusion_common::scalar::ScalarValue; -use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion, TreeNodeVisitor}; +use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor}; use datafusion_common::{ config::ConfigOptions, Column, DFSchema, DataFusionError, Result as DataFusionResult, ToDFSchema, }; -use datafusion_expr::expr::ScalarFunction; use datafusion_expr::logical_plan::CreateExternalTable; use datafusion_expr::utils::conjunction; -use datafusion_expr::{ - col, Expr, Extension, GetFieldAccess, GetIndexedField, LogicalPlan, - TableProviderFilterPushDown, Volatility, -}; -use datafusion_functions::expr_fn::get_field; -use datafusion_functions_array::extract::{array_element, array_slice}; -use datafusion_physical_expr::execution_props::ExecutionProps; +use datafusion_expr::{col, Expr, Extension, LogicalPlan, TableProviderFilterPushDown, Volatility}; use datafusion_physical_expr::PhysicalExpr; use datafusion_proto::logical_plan::LogicalExtensionCodec; use datafusion_proto::physical_plan::PhysicalExtensionCodec; @@ -249,7 +242,9 @@ pub(crate) fn files_matching_predicate<'a>( if let Some(Some(predicate)) = (!filters.is_empty()).then_some(conjunction(filters.iter().cloned())) { - let expr = logical_expr_to_physical_expr(predicate, snapshot.arrow_schema()?.as_ref()); + //let expr = logical_expr_to_physical_expr(predicate, snapshot.arrow_schema()?.as_ref()); + let expr = SessionContext::new() + .create_physical_expr(predicate, &snapshot.arrow_schema()?.to_dfschema()?)?; let pruning_predicate = PruningPredicate::try_new(expr, snapshot.arrow_schema()?)?; Ok(Either::Left( snapshot @@ -533,9 +528,11 @@ impl<'a> DeltaScanBuilder<'a> { logical_schema }; + let context = SessionContext::new(); + let df_schema = logical_schema.clone().to_dfschema()?; let logical_filter = self .filter - .map(|expr| logical_expr_to_physical_expr(expr, &logical_schema)); + .map(|expr| context.create_physical_expr(expr, &df_schema).unwrap()); // Perform Pruning of files to scan let files = match self.files { @@ -699,11 +696,11 @@ impl TableProvider for DeltaTable { Ok(Arc::new(scan)) } - fn supports_filter_pushdown( + fn supports_filters_pushdown( &self, - _filter: &Expr, - ) -> DataFusionResult { - Ok(TableProviderFilterPushDown::Inexact) + _filter: &[&Expr], + ) -> DataFusionResult> { + Ok(vec![TableProviderFilterPushDown::Inexact]) } fn statistics(&self) -> Option { @@ -778,11 +775,11 @@ impl TableProvider for DeltaTableProvider { Ok(Arc::new(scan)) } - fn supports_filter_pushdown( + fn supports_filters_pushdown( &self, - _filter: &Expr, - ) -> DataFusionResult { - Ok(TableProviderFilterPushDown::Inexact) + _filter: &[&Expr], + ) -> DataFusionResult> { + Ok(vec![TableProviderFilterPushDown::Inexact]) } fn statistics(&self) -> Option { @@ -830,8 +827,8 @@ impl ExecutionPlan for DeltaScan { self.parquet_scan.properties() } - fn children(&self) -> Vec> { - vec![self.parquet_scan.clone()] + fn children(&self) -> Vec<&Arc> { + vec![&self.parquet_scan] } fn with_new_children( @@ -992,6 +989,7 @@ pub(crate) fn partitioned_file_from_action( partition_values, range: None, extensions: None, + statistics: None, } } @@ -1067,59 +1065,6 @@ pub(crate) fn to_correct_scalar_value( } } -pub(crate) fn logical_expr_to_physical_expr( - expr: Expr, - schema: &ArrowSchema, -) -> Arc { - let df_schema = schema.clone().to_dfschema().unwrap(); - let execution_props = ExecutionProps::new(); - create_physical_expr_fix(expr, &df_schema, &execution_props).unwrap() -} - -// TODO This should be removed after datafusion v38 -pub(crate) fn create_physical_expr_fix( - expr: Expr, - input_dfschema: &DFSchema, - execution_props: &ExecutionProps, -) -> Result, DataFusionError> { - // Support Expr::struct by rewriting expressions. - let expr = expr - .transform_up(&|expr| { - // see https://github.com/apache/datafusion/issues/10181 - // This is part of the function rewriter code in DataFusion inlined here temporarily - Ok(match expr { - Expr::GetIndexedField(GetIndexedField { - expr, - field: GetFieldAccess::NamedStructField { name }, - }) => { - let name = Expr::Literal(name); - Transformed::yes(get_field(*expr, name)) - } - // expr[idx] ==> array_element(expr, idx) - Expr::GetIndexedField(GetIndexedField { - expr, - field: GetFieldAccess::ListIndex { key }, - }) => Transformed::yes(array_element(*expr, *key)), - - // expr[start, stop, stride] ==> array_slice(expr, start, stop, stride) - Expr::GetIndexedField(GetIndexedField { - expr, - field: - GetFieldAccess::ListRange { - start, - stop, - stride, - }, - }) => Transformed::yes(array_slice(*expr, *start, *stop, *stride)), - - _ => Transformed::no(expr), - }) - })? - .data; - - datafusion_physical_expr::create_physical_expr(&expr, input_dfschema, execution_props) -} - pub(crate) async fn execute_plan_to_batch( state: &SessionState, plan: Arc, @@ -1388,7 +1333,7 @@ pub(crate) struct FindFilesExprProperties { /// Ensure only expressions that make sense are accepted, check for /// non-deterministic functions, and determine if the expression only contains /// partition columns -impl TreeNodeVisitor for FindFilesExprProperties { +impl TreeNodeVisitor<'_> for FindFilesExprProperties { type Node = Expr; fn f_down(&mut self, expr: &Self::Node) -> datafusion_common::Result { @@ -1419,28 +1364,20 @@ impl TreeNodeVisitor for FindFilesExprProperties { | Expr::IsNotUnknown(_) | Expr::Negative(_) | Expr::InList { .. } - | Expr::GetIndexedField(_) | Expr::Between(_) | Expr::Case(_) | Expr::Cast(_) | Expr::TryCast(_) => (), - Expr::ScalarFunction(ScalarFunction { func_def, .. }) => { - let v = match func_def { - datafusion_expr::ScalarFunctionDefinition::BuiltIn(f) => f.volatility(), - datafusion_expr::ScalarFunctionDefinition::UDF(u) => u.signature().volatility, - datafusion_expr::ScalarFunctionDefinition::Name(n) => { + Expr::ScalarFunction(scalar_function) => { + match scalar_function.func.signature().volatility { + Volatility::Immutable => (), + _ => { self.result = Err(DeltaTableError::Generic(format!( - "Cannot determine volatility of find files predicate function {n}", + "Find files predicate contains nondeterministic function {}", + scalar_function.func.name() ))); return Ok(TreeNodeRecursion::Stop); } - }; - if v > Volatility::Immutable { - self.result = Err(DeltaTableError::Generic(format!( - "Find files predicate contains nondeterministic function {}", - func_def.name() - ))); - return Ok(TreeNodeRecursion::Stop); } } _ => { @@ -1551,11 +1488,8 @@ pub(crate) async fn find_files_scan<'a>( let input_schema = scan.logical_schema.as_ref().to_owned(); let input_dfschema = input_schema.clone().try_into()?; - let predicate_expr = create_physical_expr_fix( - Expr::IsTrue(Box::new(expression.clone())), - &input_dfschema, - state.execution_props(), - )?; + let predicate_expr = + state.create_physical_expr(Expr::IsTrue(Box::new(expression.clone())), &input_dfschema)?; let filter: Arc = Arc::new(FilterExec::try_new(predicate_expr, scan.clone())?); @@ -1902,6 +1836,7 @@ mod tests { partition_values: [ScalarValue::Int64(Some(2015)), ScalarValue::Int64(Some(1))].to_vec(), range: None, extensions: None, + statistics: None, }; assert_eq!(file.partition_values, ref_file.partition_values) } diff --git a/crates/core/src/delta_datafusion/physical.rs b/crates/core/src/delta_datafusion/physical.rs index 0251836fa8..bfc220cf86 100644 --- a/crates/core/src/delta_datafusion/physical.rs +++ b/crates/core/src/delta_datafusion/physical.rs @@ -86,8 +86,8 @@ impl ExecutionPlan for MetricObserverExec { self.parent.properties() } - fn children(&self) -> Vec> { - vec![self.parent.clone()] + fn children(&self) -> Vec<&Arc> { + vec![&self.parent] } fn execute( diff --git a/crates/core/src/kernel/snapshot/log_segment.rs b/crates/core/src/kernel/snapshot/log_segment.rs index 2f76ac18d4..cfdbf2cb8a 100644 --- a/crates/core/src/kernel/snapshot/log_segment.rs +++ b/crates/core/src/kernel/snapshot/log_segment.rs @@ -655,13 +655,11 @@ pub(super) mod tests { mod slow_store { use std::sync::Arc; - use bytes::Bytes; use futures::stream::BoxStream; use object_store::{ - path::Path, GetOptions, GetResult, ListResult, MultipartId, ObjectMeta, ObjectStore, - PutOptions, PutResult, Result, + path::Path, GetOptions, GetResult, ListResult, MultipartUpload, ObjectMeta, + ObjectStore, PutMultipartOpts, PutOptions, PutPayload, PutResult, Result, }; - use tokio::io::AsyncWrite; #[derive(Debug)] pub(super) struct SlowListStore { @@ -679,24 +677,21 @@ pub(super) mod tests { async fn put_opts( &self, location: &Path, - bytes: Bytes, + bytes: PutPayload, opts: PutOptions, ) -> Result { self.store.put_opts(location, bytes, opts).await } - async fn put_multipart( - &self, - location: &Path, - ) -> Result<(MultipartId, Box)> { + async fn put_multipart(&self, location: &Path) -> Result> { self.store.put_multipart(location).await } - async fn abort_multipart( + async fn put_multipart_opts( &self, location: &Path, - multipart_id: &MultipartId, - ) -> Result<()> { - self.store.abort_multipart(location, multipart_id).await + opts: PutMultipartOpts, + ) -> Result> { + self.store.put_multipart_opts(location, opts).await } async fn get_opts(&self, location: &Path, options: GetOptions) -> Result { diff --git a/crates/core/src/operations/cdc.rs b/crates/core/src/operations/cdc.rs index 8338bfa52b..cc8bff2359 100644 --- a/crates/core/src/operations/cdc.rs +++ b/crates/core/src/operations/cdc.rs @@ -208,8 +208,8 @@ impl ExecutionPlan for CDCObserver { self.parent.properties() } - fn children(&self) -> Vec> { - vec![self.parent.clone()] + fn children(&self) -> Vec<&Arc> { + vec![&self.parent] } fn execute( diff --git a/crates/core/src/operations/delete.rs b/crates/core/src/operations/delete.rs index 4653920965..56aa9ef98b 100644 --- a/crates/core/src/operations/delete.rs +++ b/crates/core/src/operations/delete.rs @@ -37,8 +37,7 @@ use super::transaction::{CommitBuilder, CommitProperties, PROTOCOL}; use super::write::WriterStatsConfig; use crate::delta_datafusion::expr::fmt_expr_to_sql; use crate::delta_datafusion::{ - create_physical_expr_fix, find_files, register_store, DataFusionMixins, DeltaScanBuilder, - DeltaSessionContext, + find_files, register_store, DataFusionMixins, DeltaScanBuilder, DeltaSessionContext, }; use crate::errors::DeltaResult; use crate::kernel::{Action, Add, Remove}; @@ -149,8 +148,7 @@ async fn excute_non_empty_expr( // Apply the negation of the filter and rewrite files let negated_expression = Expr::Not(Box::new(Expr::IsTrue(Box::new(expression.clone())))); - let predicate_expr = - create_physical_expr_fix(negated_expression, &input_dfschema, state.execution_props())?; + let predicate_expr = state.create_physical_expr(negated_expression, &input_dfschema)?; let filter: Arc = Arc::new(FilterExec::try_new(predicate_expr, scan.clone())?); diff --git a/crates/core/src/operations/merge/barrier.rs b/crates/core/src/operations/merge/barrier.rs index 7d18843af7..04cde87a19 100644 --- a/crates/core/src/operations/merge/barrier.rs +++ b/crates/core/src/operations/merge/barrier.rs @@ -83,14 +83,14 @@ impl ExecutionPlan for MergeBarrierExec { vec![Distribution::HashPartitioned(vec![self.expr.clone()]); 1] } - fn children(&self) -> Vec> { - vec![self.input.clone()] + fn children(&self) -> Vec<&Arc> { + vec![&self.input] } fn with_new_children( - self: std::sync::Arc, - children: Vec>, - ) -> datafusion_common::Result> { + self: Arc, + children: Vec>, + ) -> datafusion_common::Result> { if children.len() != 1 { return Err(DataFusionError::Plan( "MergeBarrierExec wrong number of children".to_string(), @@ -106,7 +106,7 @@ impl ExecutionPlan for MergeBarrierExec { fn execute( &self, partition: usize, - context: std::sync::Arc, + context: Arc, ) -> datafusion_common::Result { let input = self.input.execute(partition, context)?; Ok(Box::pin(MergeBarrierStream::new( @@ -422,11 +422,20 @@ impl UserDefinedLogicalNodeCore for MergeBarrier { exprs: &[datafusion_expr::Expr], inputs: &[datafusion_expr::LogicalPlan], ) -> Self { - MergeBarrier { + self.with_exprs_and_inputs(exprs.to_vec(), inputs.to_vec()) + .unwrap() + } + + fn with_exprs_and_inputs( + &self, + exprs: Vec, + inputs: Vec, + ) -> DataFusionResult { + Ok(MergeBarrier { input: inputs[0].clone(), file_column: self.file_column.clone(), expr: exprs[0].clone(), - } + }) } } diff --git a/crates/core/src/operations/merge/mod.rs b/crates/core/src/operations/merge/mod.rs index 6c783bc9b4..efc54c1869 100644 --- a/crates/core/src/operations/merge/mod.rs +++ b/crates/core/src/operations/merge/mod.rs @@ -502,7 +502,7 @@ impl MergeOperation { relation: Some(TableReference::Bare { table }), name, } => { - if table.eq(alias) { + if table.as_ref() == alias { Column { relation: Some(r), name, @@ -863,8 +863,8 @@ async fn try_construct_early_filter( table_snapshot: &DeltaTableState, session_state: &SessionState, source: &LogicalPlan, - source_name: &TableReference<'_>, - target_name: &TableReference<'_>, + source_name: &TableReference, + target_name: &TableReference, ) -> DeltaResult> { let table_metadata = table_snapshot.metadata(); let partition_columns = &table_metadata.partition_columns; @@ -1324,9 +1324,9 @@ async fn execute( let plan = projection.into_unoptimized_plan(); let mut fields: Vec = plan .schema() - .fields() + .columns() .iter() - .map(|f| col(f.qualified_column())) + .map(|f| col(f.clone())) .collect(); fields.extend(new_columns.into_iter().map(|(name, ex)| ex.alias(name))); diff --git a/crates/core/src/operations/transaction/mod.rs b/crates/core/src/operations/transaction/mod.rs index 31cbc3a33b..4783a4c03b 100644 --- a/crates/core/src/operations/transaction/mod.rs +++ b/crates/core/src/operations/transaction/mod.rs @@ -477,7 +477,10 @@ impl<'a> PreCommit<'a> { let log_entry = this.data.get_bytes()?; let token = uuid::Uuid::new_v4().to_string(); let path = Path::from_iter([DELTA_LOG_FOLDER, &format!("_commit_{token}.json.tmp")]); - this.log_store.object_store().put(&path, log_entry).await?; + this.log_store + .object_store() + .put(&path, log_entry.into()) + .await?; Ok(PreparedCommit { path, @@ -699,7 +702,7 @@ mod tests { logstore::{default_logstore::DefaultLogStore, LogStore}, storage::commit_uri_from_version, }; - use object_store::memory::InMemory; + use object_store::{memory::InMemory, PutPayload}; use url::Url; #[test] @@ -723,8 +726,8 @@ mod tests { ); let tmp_path = Path::from("_delta_log/tmp"); let version_path = Path::from("_delta_log/00000000000000000000.json"); - store.put(&tmp_path, bytes::Bytes::new()).await.unwrap(); - store.put(&version_path, bytes::Bytes::new()).await.unwrap(); + store.put(&tmp_path, PutPayload::new()).await.unwrap(); + store.put(&version_path, PutPayload::new()).await.unwrap(); let res = log_store.write_commit_entry(0, &tmp_path).await; // fails if file version already exists diff --git a/crates/core/src/operations/transaction/state.rs b/crates/core/src/operations/transaction/state.rs index e979cda363..d705a616b1 100644 --- a/crates/core/src/operations/transaction/state.rs +++ b/crates/core/src/operations/transaction/state.rs @@ -5,19 +5,17 @@ use arrow::array::{ArrayRef, BooleanArray}; use arrow::datatypes::{ DataType, Field as ArrowField, Schema as ArrowSchema, SchemaRef as ArrowSchemaRef, }; +use datafusion::execution::context::SessionContext; use datafusion::physical_optimizer::pruning::{PruningPredicate, PruningStatistics}; use datafusion_common::scalar::ScalarValue; -use datafusion_common::Column; +use datafusion_common::{Column, ToDFSchema}; use datafusion_expr::Expr; use itertools::Itertools; use object_store::ObjectStore; use parquet::arrow::arrow_reader::ArrowReaderOptions; use parquet::arrow::async_reader::{ParquetObjectReader, ParquetRecordBatchStreamBuilder}; -use crate::delta_datafusion::{ - get_null_of_arrow_type, logical_expr_to_physical_expr, to_correct_scalar_value, - DataFusionMixins, -}; +use crate::delta_datafusion::{get_null_of_arrow_type, to_correct_scalar_value, DataFusionMixins}; use crate::errors::DeltaResult; use crate::kernel::{Add, EagerSnapshot}; use crate::table::state::DeltaTableState; @@ -153,7 +151,9 @@ impl<'a> AddContainer<'a> { /// so evaluating expressions is inexact. However, excluded files are guaranteed (for a correct log) /// to not contain matches by the predicate expression. pub fn predicate_matches(&self, predicate: Expr) -> DeltaResult> { - let expr = logical_expr_to_physical_expr(predicate, &self.schema); + //let expr = logical_expr_to_physical_expr(predicate, &self.schema); + let expr = SessionContext::new() + .create_physical_expr(predicate, &self.schema.clone().to_dfschema()?)?; let pruning_predicate = PruningPredicate::try_new(expr, self.schema.clone())?; Ok(self .inner diff --git a/crates/core/src/operations/update.rs b/crates/core/src/operations/update.rs index 31946d104e..2f30f4aa8a 100644 --- a/crates/core/src/operations/update.rs +++ b/crates/core/src/operations/update.rs @@ -19,7 +19,8 @@ //! ```` use std::{ - collections::{HashMap, HashSet}, + collections::HashMap, + iter, sync::Arc, time::{Instant, SystemTime, UNIX_EPOCH}, }; @@ -50,8 +51,8 @@ use super::{ }; use super::{transaction::PROTOCOL, write::WriterStatsConfig}; use crate::delta_datafusion::{ - create_physical_expr_fix, expr::fmt_expr_to_sql, physical::MetricObserverExec, - DataFusionMixins, DeltaColumn, DeltaSessionContext, + expr::fmt_expr_to_sql, physical::MetricObserverExec, DataFusionMixins, DeltaColumn, + DeltaSessionContext, }; use crate::delta_datafusion::{find_files, register_store, DeltaScanBuilder}; use crate::kernel::{Action, AddCDCFile, Remove}; @@ -206,15 +207,15 @@ async fn execute( None => None, }; - let updates: HashMap = updates + let updates = updates .into_iter() .map(|(key, expr)| match expr { - Expression::DataFusion(e) => Ok((key, e)), + Expression::DataFusion(e) => Ok((key.name, e)), Expression::String(s) => snapshot .parse_predicate_expression(s, &state) - .map(|e| (key, e)), + .map(|e| (key.name, e)), }) - .collect::, _>>()?; + .collect::, _>>()?; let current_metadata = snapshot.metadata(); let table_partition_cols = current_metadata.partition_columns.clone(); @@ -233,7 +234,6 @@ async fn execute( let input_schema = snapshot.input_schema()?; let tracker = CDCTracker::new(input_schema.clone()); - let execution_props = state.execution_props(); // For each rewrite evaluate the predicate and then modify each expression // to either compute the new value or obtain the old one then write these batches let scan = DeltaScanBuilder::new(&snapshot, log_store.clone(), &state) @@ -266,25 +266,30 @@ async fn execute( let input_schema = Arc::new(ArrowSchema::new(fields)); let input_dfschema: DFSchema = input_schema.as_ref().clone().try_into()?; - let mut expressions: Vec<(Arc, String)> = Vec::new(); - let scan_schema = scan.schema(); - for (i, field) in scan_schema.fields().into_iter().enumerate() { - expressions.push(( - Arc::new(expressions::Column::new(field.name(), i)), - field.name().to_owned(), - )); - } - // Take advantage of how null counts are tracked in arrow arrays use the // null count to track how many records do NOT statisfy the predicate. The // count is then exposed through the metrics through the `UpdateCountExec` // execution plan - let predicate_null = when(predicate.clone(), lit(true)).otherwise(lit(ScalarValue::Boolean(None)))?; - let predicate_expr = - create_physical_expr_fix(predicate_null, &input_dfschema, execution_props)?; - expressions.push((predicate_expr, UPDATE_PREDICATE_COLNAME.to_string())); + let update_predicate_expr = state.create_physical_expr(predicate_null, &input_dfschema)?; + + let expressions: Vec<(Arc, String)> = scan + .schema() + .fields() + .into_iter() + .enumerate() + .map(|(idx, field)| -> (Arc, String) { + ( + Arc::new(expressions::Column::new(field.name(), idx)), + field.name().to_owned(), + ) + }) + .chain(iter::once(( + update_predicate_expr, + UPDATE_PREDICATE_COLNAME.to_string(), + ))) + .collect(); let projection_predicate: Arc = Arc::new(ProjectionExec::try_new(expressions, scan.clone())?); @@ -307,66 +312,28 @@ async fn execute( }, )); - // Perform another projection but instead calculate updated values based on - // the predicate value. If the predicate is true then evalute the user - // provided expression otherwise return the original column value - // - // For each update column a new column with a name of __delta_rs_ + `original name` is created - let mut expressions: Vec<(Arc, String)> = Vec::new(); - let scan_schema = count_plan.schema(); - for (i, field) in scan_schema.fields().into_iter().enumerate() { - expressions.push(( - Arc::new(expressions::Column::new(field.name(), i)), - field.name().to_owned(), - )); - } - - // Maintain a map from the original column name to its temporary column index - let mut map = HashMap::::new(); - let mut control_columns = HashSet::::new(); - control_columns.insert(UPDATE_PREDICATE_COLNAME.to_string()); - - for (column, expr) in updates { - let expr = case(col(UPDATE_PREDICATE_COLNAME)) - .when(lit(true), expr.to_owned()) - .otherwise(col(column.to_owned()))?; - let predicate_expr = create_physical_expr_fix(expr, &input_dfschema, execution_props)?; - map.insert(column.name.clone(), expressions.len()); - let c = "__delta_rs_".to_string() + &column.name; - expressions.push((predicate_expr, c.clone())); - control_columns.insert(c); - } - - let projection_update: Arc = - Arc::new(ProjectionExec::try_new(expressions, count_plan.clone())?); - - // Project again to remove __delta_rs columns and rename update columns to their original name - let mut expressions: Vec<(Arc, String)> = Vec::new(); - let scan_schema = projection_update.schema(); - - for (i, field) in scan_schema.fields().into_iter().enumerate() { - if !control_columns.contains(field.name()) { - match map.get(field.name()) { - Some(value) => { - expressions.push(( - Arc::new(expressions::Column::new(field.name(), *value)), - field.name().to_owned(), - )); - } - None => { - expressions.push(( - Arc::new(expressions::Column::new(field.name(), i)), - field.name().to_owned(), - )); + let expressions: DeltaResult, String)>> = count_plan + .schema() + .fields() + .into_iter() + .enumerate() + .map(|(idx, field)| { + let field_name = field.name(); + let expr = match updates.get(field_name) { + Some(expr) => { + let expr = case(col(UPDATE_PREDICATE_COLNAME)) + .when(lit(true), expr.to_owned()) + .otherwise(col(Column::from_qualified_name_ignore_case(field_name)))?; + state.create_physical_expr(expr, &input_dfschema)? } - } - } - } + None => Arc::new(expressions::Column::new(field_name, idx)), + }; + Ok((expr, field_name.to_owned())) + }) + .collect(); - let projection: Arc = Arc::new(ProjectionExec::try_new( - expressions, - projection_update.clone(), - )?); + let projection: Arc = + Arc::new(ProjectionExec::try_new(expressions?, count_plan.clone())?); let writer_stats_config = WriterStatsConfig::new( snapshot.table_config().num_indexed_cols(), diff --git a/crates/core/src/operations/write.rs b/crates/core/src/operations/write.rs index 1cdf2780bd..f4c6f36cf3 100644 --- a/crates/core/src/operations/write.rs +++ b/crates/core/src/operations/write.rs @@ -49,9 +49,7 @@ use super::writer::{DeltaWriter, WriterConfig}; use super::CreateBuilder; use crate::delta_datafusion::expr::fmt_expr_to_sql; use crate::delta_datafusion::expr::parse_predicate_expression; -use crate::delta_datafusion::{ - create_physical_expr_fix, find_files, register_store, DeltaScanBuilder, -}; +use crate::delta_datafusion::{find_files, register_store, DeltaScanBuilder}; use crate::delta_datafusion::{DataFusionMixins, DeltaDataChecker}; use crate::errors::{DeltaResult, DeltaTableError}; use crate::kernel::{Action, Add, Metadata, PartitionsExt, Remove, StructType}; @@ -522,8 +520,8 @@ async fn execute_non_empty_expr( // Apply the negation of the filter and rewrite files let negated_expression = Expr::Not(Box::new(Expr::IsTrue(Box::new(expression.clone())))); - let predicate_expr = - create_physical_expr_fix(negated_expression, &input_dfschema, state.execution_props())?; + let predicate_expr = state.create_physical_expr(negated_expression, &input_dfschema)?; + let filter: Arc = Arc::new(FilterExec::try_new(predicate_expr, scan.clone())?); diff --git a/crates/core/src/operations/writer.rs b/crates/core/src/operations/writer.rs index e5e6901608..5128611ffd 100644 --- a/crates/core/src/operations/writer.rs +++ b/crates/core/src/operations/writer.rs @@ -369,7 +369,7 @@ impl PartitionWriter { let file_size = buffer.len() as i64; // write file to object store - self.object_store.put(&path, buffer).await?; + self.object_store.put(&path, buffer.into()).await?; self.files_written.push( create_add( &self.config.partition_values, diff --git a/crates/core/src/protocol/checkpoints.rs b/crates/core/src/protocol/checkpoints.rs index 6bf19a81f5..f2625e49cf 100644 --- a/crates/core/src/protocol/checkpoints.rs +++ b/crates/core/src/protocol/checkpoints.rs @@ -170,14 +170,16 @@ pub async fn create_checkpoint_for( let object_store = log_store.object_store(); debug!("Writing checkpoint to {:?}.", checkpoint_path); - object_store.put(&checkpoint_path, parquet_bytes).await?; + object_store + .put(&checkpoint_path, parquet_bytes.into()) + .await?; let last_checkpoint_content: Value = serde_json::to_value(checkpoint)?; let last_checkpoint_content = bytes::Bytes::from(serde_json::to_vec(&last_checkpoint_content)?); debug!("Writing _last_checkpoint to {:?}.", last_checkpoint_path); object_store - .put(&last_checkpoint_path, last_checkpoint_content) + .put(&last_checkpoint_path, last_checkpoint_content.into()) .await?; Ok(()) diff --git a/crates/core/src/storage/file.rs b/crates/core/src/storage/file.rs index c63a00dae6..f7fa168127 100644 --- a/crates/core/src/storage/file.rs +++ b/crates/core/src/storage/file.rs @@ -6,12 +6,12 @@ use bytes::Bytes; use futures::stream::BoxStream; use object_store::{ local::LocalFileSystem, path::Path as ObjectStorePath, Error as ObjectStoreError, GetOptions, - GetResult, ListResult, MultipartId, ObjectMeta, ObjectStore, PutOptions, PutResult, + GetResult, ListResult, ObjectMeta, ObjectStore, PutOptions, PutResult, Result as ObjectStoreResult, }; +use object_store::{MultipartUpload, PutMultipartOpts, PutPayload}; use std::ops::Range; use std::sync::Arc; -use tokio::io::AsyncWrite; use url::Url; const STORE_NAME: &str = "DeltaLocalObjectStore"; @@ -166,14 +166,18 @@ impl std::fmt::Display for FileStorageBackend { #[async_trait::async_trait] impl ObjectStore for FileStorageBackend { - async fn put(&self, location: &ObjectStorePath, bytes: Bytes) -> ObjectStoreResult { + async fn put( + &self, + location: &ObjectStorePath, + bytes: PutPayload, + ) -> ObjectStoreResult { self.inner.put(location, bytes).await } async fn put_opts( &self, location: &ObjectStorePath, - bytes: Bytes, + bytes: PutPayload, options: PutOptions, ) -> ObjectStoreResult { self.inner.put_opts(location, bytes, options).await @@ -254,16 +258,16 @@ impl ObjectStore for FileStorageBackend { async fn put_multipart( &self, location: &ObjectStorePath, - ) -> ObjectStoreResult<(MultipartId, Box)> { + ) -> ObjectStoreResult> { self.inner.put_multipart(location).await } - async fn abort_multipart( + async fn put_multipart_opts( &self, location: &ObjectStorePath, - multipart_id: &MultipartId, - ) -> ObjectStoreResult<()> { - self.inner.abort_multipart(location, multipart_id).await + options: PutMultipartOpts, + ) -> ObjectStoreResult> { + self.inner.put_multipart_opts(location, options).await } } diff --git a/crates/core/src/storage/retry_ext.rs b/crates/core/src/storage/retry_ext.rs index 81a52f3ba3..b63c29a8ae 100644 --- a/crates/core/src/storage/retry_ext.rs +++ b/crates/core/src/storage/retry_ext.rs @@ -1,7 +1,6 @@ //! Retry extension for [`ObjectStore`] -use bytes::Bytes; -use object_store::{path::Path, Error, ObjectStore, PutResult, Result}; +use object_store::{path::Path, Error, ObjectStore, PutPayload, PutResult, Result}; use tracing::log::*; /// Retry extension for [`ObjectStore`] @@ -29,7 +28,7 @@ pub trait ObjectStoreRetryExt: ObjectStore { async fn put_with_retries( &self, location: &Path, - bytes: Bytes, + bytes: PutPayload, max_retries: usize, ) -> Result { let mut attempt_number = 1; diff --git a/crates/core/src/writer/json.rs b/crates/core/src/writer/json.rs index ab1ccac5f2..2cf7f6a950 100644 --- a/crates/core/src/writer/json.rs +++ b/crates/core/src/writer/json.rs @@ -363,7 +363,9 @@ impl DeltaWriter> for JsonWriter { let path = next_data_path(&prefix, 0, &uuid, &writer.writer_properties); let obj_bytes = Bytes::from(writer.buffer.to_vec()); let file_size = obj_bytes.len() as i64; - self.storage.put_with_retries(&path, obj_bytes, 15).await?; + self.storage + .put_with_retries(&path, obj_bytes.into(), 15) + .await?; actions.push(create_add( &writer.partition_values, diff --git a/crates/core/src/writer/record_batch.rs b/crates/core/src/writer/record_batch.rs index 9cdc6a4322..d99673c8cb 100644 --- a/crates/core/src/writer/record_batch.rs +++ b/crates/core/src/writer/record_batch.rs @@ -225,7 +225,9 @@ impl DeltaWriter for RecordBatchWriter { let path = next_data_path(&prefix, 0, &uuid, &writer.writer_properties); let obj_bytes = Bytes::from(writer.buffer.to_vec()); let file_size = obj_bytes.len() as i64; - self.storage.put_with_retries(&path, obj_bytes, 15).await?; + self.storage + .put_with_retries(&path, obj_bytes.into(), 15) + .await?; actions.push(create_add( &writer.partition_values, diff --git a/crates/core/tests/fs_common/mod.rs b/crates/core/tests/fs_common/mod.rs index e3d9e722e4..13683b408a 100644 --- a/crates/core/tests/fs_common/mod.rs +++ b/crates/core/tests/fs_common/mod.rs @@ -8,7 +8,9 @@ use deltalake_core::protocol::{DeltaOperation, SaveMode}; use deltalake_core::storage::{GetResult, ObjectStoreResult}; use deltalake_core::DeltaTable; use object_store::path::Path as StorePath; -use object_store::{ObjectStore, PutOptions, PutResult}; +use object_store::{ + MultipartUpload, ObjectStore, PutMultipartOpts, PutOptions, PutPayload, PutResult, +}; use serde_json::Value; use std::collections::HashMap; use std::fs; @@ -158,14 +160,14 @@ impl SlowStore { #[async_trait::async_trait] impl ObjectStore for SlowStore { /// Save the provided bytes to the specified location. - async fn put(&self, location: &StorePath, bytes: bytes::Bytes) -> ObjectStoreResult { + async fn put(&self, location: &StorePath, bytes: PutPayload) -> ObjectStoreResult { self.inner.put(location, bytes).await } async fn put_opts( &self, location: &StorePath, - bytes: bytes::Bytes, + bytes: PutPayload, options: PutOptions, ) -> ObjectStoreResult { self.inner.put_opts(location, bytes, options).await @@ -272,18 +274,15 @@ impl ObjectStore for SlowStore { async fn put_multipart( &self, location: &StorePath, - ) -> ObjectStoreResult<( - object_store::MultipartId, - Box, - )> { + ) -> ObjectStoreResult> { self.inner.put_multipart(location).await } - async fn abort_multipart( + async fn put_multipart_opts( &self, location: &StorePath, - multipart_id: &object_store::MultipartId, - ) -> ObjectStoreResult<()> { - self.inner.abort_multipart(location, multipart_id).await + options: PutMultipartOpts, + ) -> ObjectStoreResult> { + self.inner.put_multipart_opts(location, options).await } } diff --git a/crates/gcp/src/storage.rs b/crates/gcp/src/storage.rs index 9b938b737e..db02d33687 100644 --- a/crates/gcp/src/storage.rs +++ b/crates/gcp/src/storage.rs @@ -4,11 +4,11 @@ use bytes::Bytes; use deltalake_core::storage::ObjectStoreRef; use deltalake_core::Path; use futures::stream::BoxStream; +use object_store::{MultipartUpload, PutMultipartOpts, PutPayload}; use std::ops::Range; -use tokio::io::AsyncWrite; use deltalake_core::storage::object_store::{ - GetOptions, GetResult, ListResult, MultipartId, ObjectMeta, ObjectStore, PutOptions, PutResult, + GetOptions, GetResult, ListResult, ObjectMeta, ObjectStore, PutOptions, PutResult, Result as ObjectStoreResult, }; @@ -36,14 +36,14 @@ impl std::fmt::Display for GcsStorageBackend { #[async_trait::async_trait] impl ObjectStore for GcsStorageBackend { - async fn put(&self, location: &Path, bytes: Bytes) -> ObjectStoreResult { + async fn put(&self, location: &Path, bytes: PutPayload) -> ObjectStoreResult { self.inner.put(location, bytes).await } async fn put_opts( &self, location: &Path, - bytes: Bytes, + bytes: PutPayload, options: PutOptions, ) -> ObjectStoreResult { self.inner.put_opts(location, bytes, options).await @@ -120,18 +120,15 @@ impl ObjectStore for GcsStorageBackend { } } - async fn put_multipart( - &self, - location: &Path, - ) -> ObjectStoreResult<(MultipartId, Box)> { + async fn put_multipart(&self, location: &Path) -> ObjectStoreResult> { self.inner.put_multipart(location).await } - async fn abort_multipart( + async fn put_multipart_opts( &self, location: &Path, - multipart_id: &MultipartId, - ) -> ObjectStoreResult<()> { - self.inner.abort_multipart(location, multipart_id).await + options: PutMultipartOpts, + ) -> ObjectStoreResult> { + self.inner.put_multipart_opts(location, options).await } } diff --git a/crates/gcp/tests/context.rs b/crates/gcp/tests/context.rs index b96bd1f41b..5419075f68 100644 --- a/crates/gcp/tests/context.rs +++ b/crates/gcp/tests/context.rs @@ -39,7 +39,7 @@ pub async fn sync_stores( while let Some(file) = meta_stream.next().await { if let Ok(meta) = file { let bytes = from_store.get(&meta.location).await?.bytes().await?; - to_store.put(&meta.location, bytes).await?; + to_store.put(&meta.location, bytes.into()).await?; } } Ok(()) diff --git a/crates/mount/src/file.rs b/crates/mount/src/file.rs index 0169d1c8ce..29285a4a96 100644 --- a/crates/mount/src/file.rs +++ b/crates/mount/src/file.rs @@ -6,12 +6,12 @@ use bytes::Bytes; use futures::stream::BoxStream; use object_store::{ local::LocalFileSystem, path::Path as ObjectStorePath, Error as ObjectStoreError, GetOptions, - GetResult, ListResult, MultipartId, ObjectMeta, ObjectStore, PutOptions, PutResult, + GetResult, ListResult, ObjectMeta, ObjectStore, PutOptions, PutResult, Result as ObjectStoreResult, }; +use object_store::{MultipartUpload, PutMultipartOpts, PutPayload}; use std::ops::Range; use std::sync::Arc; -use tokio::io::AsyncWrite; use url::Url; pub(crate) const STORE_NAME: &str = "MountObjectStore"; @@ -156,14 +156,18 @@ impl std::fmt::Display for MountFileStorageBackend { #[async_trait::async_trait] impl ObjectStore for MountFileStorageBackend { - async fn put(&self, location: &ObjectStorePath, bytes: Bytes) -> ObjectStoreResult { + async fn put( + &self, + location: &ObjectStorePath, + bytes: PutPayload, + ) -> ObjectStoreResult { self.inner.put(location, bytes).await } async fn put_opts( &self, location: &ObjectStorePath, - bytes: Bytes, + bytes: PutPayload, options: PutOptions, ) -> ObjectStoreResult { self.inner.put_opts(location, bytes, options).await @@ -244,16 +248,16 @@ impl ObjectStore for MountFileStorageBackend { async fn put_multipart( &self, location: &ObjectStorePath, - ) -> ObjectStoreResult<(MultipartId, Box)> { + ) -> ObjectStoreResult> { self.inner.put_multipart(location).await } - async fn abort_multipart( + async fn put_multipart_opts( &self, location: &ObjectStorePath, - multipart_id: &MultipartId, - ) -> ObjectStoreResult<()> { - self.inner.abort_multipart(location, multipart_id).await + options: PutMultipartOpts, + ) -> ObjectStoreResult> { + self.inner.put_multipart_opts(location, options).await } } diff --git a/crates/sql/src/logical_plan.rs b/crates/sql/src/logical_plan.rs index 164462a90c..6e3c7d5dbc 100644 --- a/crates/sql/src/logical_plan.rs +++ b/crates/sql/src/logical_plan.rs @@ -1,7 +1,7 @@ use std::fmt::{self, Debug, Display}; use std::sync::Arc; -use datafusion_common::{DFSchema, DFSchemaRef, OwnedTableReference}; +use datafusion_common::{DFSchema, DFSchemaRef, DataFusionError, TableReference}; use datafusion_expr::logical_plan::LogicalPlan; use datafusion_expr::{Expr, UserDefinedLogicalNodeCore}; @@ -90,13 +90,31 @@ impl UserDefinedLogicalNodeCore for DeltaStatement { } fn from_template(&self, exprs: &[Expr], inputs: &[LogicalPlan]) -> Self { + self.with_exprs_and_inputs(exprs.to_vec(), inputs.to_vec()) + .unwrap() + } + + fn with_exprs_and_inputs( + &self, + exprs: Vec, + inputs: Vec, + ) -> datafusion_common::Result { match self { Self::Vacuum(_) | Self::DescribeHistory(_) => { - assert_eq!(inputs.len(), 0, "input size inconsistent"); - assert_eq!(exprs.len(), 0, "expression size inconsistent"); - self.clone() + if !inputs.is_empty() { + return Err(DataFusionError::External("Input size inconsistent".into())); + } + if !exprs.is_empty() { + return Err(DataFusionError::External( + "Expression size inconsistent".into(), + )); + } + Ok(self.clone()) } - _ => todo!(), + _ => Err(DataFusionError::NotImplemented(format!( + "with_exprs_and_inputs not implemented for {:?}", + self + ))), } } } @@ -107,7 +125,7 @@ impl UserDefinedLogicalNodeCore for DeltaStatement { #[derive(Clone, PartialEq, Eq, Hash)] pub struct Vacuum { /// A reference to the table being vacuumed - pub table: OwnedTableReference, + pub table: TableReference, /// The retention threshold. pub retention_hours: Option, /// Return a list of up to 1000 files to be deleted. @@ -117,7 +135,7 @@ pub struct Vacuum { } impl Vacuum { - pub fn new(table: OwnedTableReference, retention_hours: Option, dry_run: bool) -> Self { + pub fn new(table: TableReference, retention_hours: Option, dry_run: bool) -> Self { Self { table, retention_hours, @@ -133,13 +151,13 @@ impl Vacuum { #[derive(Clone, PartialEq, Eq, Hash)] pub struct DescribeHistory { /// A reference to the table - pub table: OwnedTableReference, + pub table: TableReference, /// Schema for commit provenence information pub schema: DFSchemaRef, } impl DescribeHistory { - pub fn new(table: OwnedTableReference) -> Self { + pub fn new(table: TableReference) -> Self { Self { table, // TODO: add proper schema @@ -153,13 +171,13 @@ impl DescribeHistory { #[derive(Clone, PartialEq, Eq, Hash)] pub struct DescribeDetails { /// A reference to the table - pub table: OwnedTableReference, + pub table: TableReference, /// Schema for commit provenence information pub schema: DFSchemaRef, } impl DescribeDetails { - pub fn new(table: OwnedTableReference) -> Self { + pub fn new(table: TableReference) -> Self { Self { table, // TODO: add proper schema @@ -172,13 +190,13 @@ impl DescribeDetails { #[derive(Clone, PartialEq, Eq, Hash)] pub struct DescribeFiles { /// A reference to the table - pub table: OwnedTableReference, + pub table: TableReference, /// Schema for commit provenence information pub schema: DFSchemaRef, } impl DescribeFiles { - pub fn new(table: OwnedTableReference) -> Self { + pub fn new(table: TableReference) -> Self { Self { table, // TODO: add proper schema diff --git a/crates/sql/src/planner.rs b/crates/sql/src/planner.rs index 0be14d59b0..e2c76e68fd 100644 --- a/crates/sql/src/planner.rs +++ b/crates/sql/src/planner.rs @@ -1,6 +1,6 @@ use std::sync::Arc; -use datafusion_common::{OwnedTableReference, Result as DFResult}; +use datafusion_common::{Result as DFResult, TableReference}; use datafusion_expr::logical_plan::{Extension, LogicalPlan}; use datafusion_sql::planner::{ object_name_to_table_reference, ContextProvider, IdentNormalizer, ParserOptions, SqlToRel, @@ -54,7 +54,7 @@ impl<'a, S: ContextProvider> DeltaSqlToRel<'a, S> { fn vacuum_to_plan(&self, vacuum: VacuumStatement) -> DFResult { let table_ref = self.object_name_to_table_reference(vacuum.table)?; let plan = DeltaStatement::Vacuum(Vacuum::new( - table_ref.to_owned_reference(), + table_ref.clone(), vacuum.retention_hours, vacuum.dry_run, )); @@ -65,8 +65,7 @@ impl<'a, S: ContextProvider> DeltaSqlToRel<'a, S> { fn describe_to_plan(&self, describe: DescribeStatement) -> DFResult { let table_ref = self.object_name_to_table_reference(describe.table)?; - let plan = - DeltaStatement::DescribeFiles(DescribeFiles::new(table_ref.to_owned_reference())); + let plan = DeltaStatement::DescribeFiles(DescribeFiles::new(table_ref.clone())); Ok(LogicalPlan::Extension(Extension { node: Arc::new(plan), })) @@ -75,7 +74,7 @@ impl<'a, S: ContextProvider> DeltaSqlToRel<'a, S> { pub(crate) fn object_name_to_table_reference( &self, object_name: ObjectName, - ) -> DFResult { + ) -> DFResult { object_name_to_table_reference(object_name, self.options.enable_ident_normalization) } } @@ -122,10 +121,6 @@ mod tests { } impl ContextProvider for TestSchemaProvider { - fn get_table_provider(&self, name: TableReference) -> DFResult> { - self.get_table_source(name) - } - fn get_table_source(&self, name: TableReference) -> DFResult> { match self.tables.get(name.table()) { Some(table) => Ok(table.clone()), @@ -156,15 +151,15 @@ mod tests { None } - fn udfs_names(&self) -> Vec { + fn udf_names(&self) -> Vec { Vec::new() } - fn udafs_names(&self) -> Vec { + fn udaf_names(&self) -> Vec { Vec::new() } - fn udwfs_names(&self) -> Vec { + fn udwf_names(&self) -> Vec { Vec::new() } } diff --git a/crates/test/src/lib.rs b/crates/test/src/lib.rs index 0a1ca39539..c53d34b1d3 100644 --- a/crates/test/src/lib.rs +++ b/crates/test/src/lib.rs @@ -119,7 +119,7 @@ pub async fn add_file( commit_to_log: bool, ) { let backend = table.object_store(); - backend.put(path, data.clone()).await.unwrap(); + backend.put(path, data.clone().into()).await.unwrap(); if commit_to_log { let mut part_values = HashMap::new(); diff --git a/python/Cargo.toml b/python/Cargo.toml index 3938bd0aa9..672ba4ee50 100644 --- a/python/Cargo.toml +++ b/python/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "deltalake-python" -version = "0.18.0" +version = "0.18.1" authors = ["Qingping Hou ", "Will Jones "] homepage = "https://github.com/delta-io/delta-rs" license = "Apache-2.0" @@ -43,8 +43,8 @@ reqwest = { version = "*", features = ["native-tls-vendored"] } deltalake-mount = { path = "../crates/mount" } [dependencies.pyo3] -version = "0.20" -features = ["extension-module", "abi3", "abi3-py38"] +version = "0.21.1" +features = ["extension-module", "abi3", "abi3-py38", "gil-refs"] [dependencies.deltalake] path = "../crates/deltalake" diff --git a/python/src/filesystem.rs b/python/src/filesystem.rs index 2825bf9092..af8410af72 100644 --- a/python/src/filesystem.rs +++ b/python/src/filesystem.rs @@ -1,7 +1,8 @@ use crate::error::PythonError; use crate::utils::{delete_dir, rt, walk_tree}; use crate::RawDeltaTable; -use deltalake::storage::{DynObjectStore, ListResult, MultipartId, ObjectStoreError, Path}; +use deltalake::storage::object_store::{MultipartUpload, PutPayloadMut}; +use deltalake::storage::{DynObjectStore, ListResult, ObjectStoreError, Path}; use deltalake::DeltaTableBuilder; use pyo3::exceptions::{PyIOError, PyNotImplementedError, PyValueError}; use pyo3::prelude::*; @@ -9,9 +10,8 @@ use pyo3::types::{IntoPyDict, PyBytes, PyType}; use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::sync::Arc; -use tokio::io::{AsyncWrite, AsyncWriteExt}; -const DEFAULT_MAX_BUFFER_SIZE: i64 = 4 * 1024 * 1024; +const DEFAULT_MAX_BUFFER_SIZE: usize = 4 * 1024 * 1024; #[derive(Debug, Clone, Serialize, Deserialize)] pub(crate) struct FsConfig { @@ -292,7 +292,7 @@ impl DeltaFileSystemHandler { .options .get("max_buffer_size") .map_or(DEFAULT_MAX_BUFFER_SIZE, |v| { - v.parse::().unwrap_or(DEFAULT_MAX_BUFFER_SIZE) + v.parse::().unwrap_or(DEFAULT_MAX_BUFFER_SIZE) }); let file = rt() .block_on(ObjectOutputStream::try_new( @@ -489,39 +489,32 @@ impl ObjectInputFile { } // TODO the C++ implementation track an internal lock on all random access files, DO we need this here? -// TODO add buffer to store data ... #[pyclass(weakref, module = "deltalake._internal")] pub struct ObjectOutputStream { - store: Arc, - path: Path, - writer: Box, - multipart_id: MultipartId, + upload: Box, pos: i64, #[pyo3(get)] closed: bool, #[pyo3(get)] mode: String, - max_buffer_size: i64, - buffer_size: i64, + max_buffer_size: usize, + buffer: PutPayloadMut, } impl ObjectOutputStream { pub async fn try_new( store: Arc, path: Path, - max_buffer_size: i64, + max_buffer_size: usize, ) -> Result { - let (multipart_id, writer) = store.put_multipart(&path).await?; + let upload = store.put_multipart(&path).await?; Ok(Self { - store, - path, - writer, - multipart_id, + upload, pos: 0, closed: false, mode: "wb".into(), + buffer: PutPayloadMut::default(), max_buffer_size, - buffer_size: 0, }) } @@ -538,13 +531,12 @@ impl ObjectOutputStream { impl ObjectOutputStream { fn close(&mut self, py: Python<'_>) -> PyResult<()> { self.closed = true; - py.allow_threads(|| match rt().block_on(self.writer.shutdown()) { + if !self.buffer.is_empty() { + self.flush(py)?; + } + py.allow_threads(|| match rt().block_on(self.upload.complete()) { Ok(_) => Ok(()), - Err(err) => { - rt().block_on(self.store.abort_multipart(&self.path, &self.multipart_id)) - .map_err(PythonError::from)?; - Err(PyIOError::new_err(err.to_string())) - } + Err(err) => Err(PyIOError::new_err(err.to_string())), }) } @@ -590,31 +582,23 @@ impl ObjectOutputStream { fn write(&mut self, data: &PyBytes) -> PyResult { self.check_closed()?; - let len = data.as_bytes().len() as i64; let py = data.py(); - let data = data.as_bytes(); - let res = py.allow_threads(|| match rt().block_on(self.writer.write_all(data)) { - Ok(_) => Ok(len), - Err(err) => { - rt().block_on(self.store.abort_multipart(&self.path, &self.multipart_id)) - .map_err(PythonError::from)?; - Err(PyIOError::new_err(err.to_string())) - } - })?; - self.buffer_size += len; - if self.buffer_size >= self.max_buffer_size { - let _ = self.flush(py); - self.buffer_size = 0; + let bytes = data.as_bytes(); + let len = bytes.len(); + py.allow_threads(|| self.buffer.extend_from_slice(bytes)); + if self.buffer.content_length() >= self.max_buffer_size { + self.flush(py)?; } - Ok(res) + Ok(len as i64) } fn flush(&mut self, py: Python<'_>) -> PyResult<()> { - py.allow_threads(|| match rt().block_on(self.writer.flush()) { + let payload = std::mem::take(&mut self.buffer).freeze(); + py.allow_threads(|| match rt().block_on(self.upload.put_part(payload)) { Ok(_) => Ok(()), Err(err) => { - rt().block_on(self.store.abort_multipart(&self.path, &self.multipart_id)) - .map_err(PythonError::from)?; + rt().block_on(self.upload.abort()) + .map_err(|err| PythonError::from(err))?; Err(PyIOError::new_err(err.to_string())) } }) diff --git a/python/src/lib.rs b/python/src/lib.rs index 14cbf6f916..9d766a8dfb 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -1,5 +1,3 @@ -#![deny(warnings)] - mod error; mod filesystem; mod schema; diff --git a/python/tests/test_update.py b/python/tests/test_update.py index 74ae130224..85e3fe38ec 100644 --- a/python/tests/test_update.py +++ b/python/tests/test_update.py @@ -119,7 +119,7 @@ def test_update_wrong_types_cast(tmp_path: pathlib.Path, sample_table: pa.Table) assert ( str(excinfo.value) - == "Cast error: Cannot cast value 'hello_world' to value of Boolean type" + == "Generic DeltaTable error: Error during planning: Failed to coerce then ([Utf8]) and else (Some(Boolean)) to common types in CASE WHEN expression" ) diff --git a/python/tests/test_writer.py b/python/tests/test_writer.py index fb41d55a09..169db15e80 100644 --- a/python/tests/test_writer.py +++ b/python/tests/test_writer.py @@ -1463,6 +1463,19 @@ def test_invalid_decimals(tmp_path: pathlib.Path, engine): write_deltalake(table_or_uri=tmp_path, mode="append", data=data, engine=engine) +@pytest.mark.parametrize("engine", ["pyarrow", "rust"]) +def test_write_large_decimal(tmp_path: pathlib.Path, engine): + data = pa.table( + { + "decimal_column": pa.array( + [Decimal(11111111111111111), Decimal(22222), Decimal("333333333333.33")] + ) + } + ) + + write_deltalake(tmp_path, data, engine=engine) + + def test_float_values(tmp_path: pathlib.Path): data = pa.table( {