From cd11c76a38e9254c65e660cc50a0f8592d679bd6 Mon Sep 17 00:00:00 2001 From: Adrian Ehrsam Date: Wed, 28 Feb 2024 16:39:10 +0100 Subject: [PATCH 01/40] compiles ;) --- crates/core/src/operations/delete.rs | 11 ++- crates/core/src/operations/merge/mod.rs | 6 +- crates/core/src/operations/update.rs | 6 +- crates/core/src/operations/write.rs | 114 ++++++++++++++++++------ 4 files changed, 102 insertions(+), 35 deletions(-) diff --git a/crates/core/src/operations/delete.rs b/crates/core/src/operations/delete.rs index 2e3e99bde2..7544a0159c 100644 --- a/crates/core/src/operations/delete.rs +++ b/crates/core/src/operations/delete.rs @@ -17,6 +17,7 @@ //! .await?; //! ```` +use core::panic; use std::collections::HashMap; use std::sync::Arc; use std::time::{Instant, SystemTime, UNIX_EPOCH}; @@ -36,6 +37,7 @@ use serde_json::Value; use super::datafusion_utils::Expression; use super::transaction::PROTOCOL; +use super::write::SchemaWriteMode; use crate::delta_datafusion::expr::fmt_expr_to_sql; use crate::delta_datafusion::{find_files, register_store, DeltaScanBuilder, DeltaSessionContext}; use crate::errors::DeltaResult; @@ -167,9 +169,14 @@ async fn excute_non_empty_expr( None, writer_properties, false, - false, + SchemaWriteMode::None, ) - .await?; + .await? + .into_iter().map(|a| match a { + Action::Add(a) => a, + _ => panic!("Expected Add action"), + + }).collect::>(); let read_records = scan.parquet_scan.metrics().and_then(|m| m.output_rows()); let filter_records = filter.metrics().and_then(|m| m.output_rows()); diff --git a/crates/core/src/operations/merge/mod.rs b/crates/core/src/operations/merge/mod.rs index 6190e8f724..6339ed4b8e 100644 --- a/crates/core/src/operations/merge/mod.rs +++ b/crates/core/src/operations/merge/mod.rs @@ -74,7 +74,7 @@ use crate::delta_datafusion::{ use crate::kernel::Action; use crate::logstore::LogStoreRef; use crate::operations::merge::barrier::find_barrier_node; -use crate::operations::write::write_execution_plan; +use crate::operations::write::{write_execution_plan, SchemaWriteMode}; use crate::protocol::{DeltaOperation, MergePredicate}; use crate::table::state::DeltaTableState; use crate::{DeltaResult, DeltaTable, DeltaTableError}; @@ -1379,13 +1379,13 @@ async fn execute( None, writer_properties, safe_cast, - false, + SchemaWriteMode::None, ) .await?; metrics.rewrite_time_ms = Instant::now().duration_since(rewrite_start).as_millis() as u64; - let mut actions: Vec = add_actions.into_iter().map(Action::Add).collect(); + let mut actions: Vec = add_actions.clone(); metrics.num_target_files_added = actions.len(); let survivors = barrier diff --git a/crates/core/src/operations/update.rs b/crates/core/src/operations/update.rs index d07f3f9fc0..957ca7e1b1 100644 --- a/crates/core/src/operations/update.rs +++ b/crates/core/src/operations/update.rs @@ -43,7 +43,7 @@ use parquet::file::properties::WriterProperties; use serde::Serialize; use serde_json::Value; -use super::datafusion_utils::Expression; +use super::{datafusion_utils::Expression, write::SchemaWriteMode}; use super::transaction::{commit, PROTOCOL}; use super::write::write_execution_plan; use crate::delta_datafusion::{ @@ -357,7 +357,7 @@ async fn execute( None, writer_properties, safe_cast, - false, + SchemaWriteMode::None, ) .await?; @@ -377,7 +377,7 @@ async fn execute( .duration_since(UNIX_EPOCH) .unwrap() .as_millis() as i64; - let mut actions: Vec = add_actions.into_iter().map(Action::Add).collect(); + let mut actions: Vec = add_actions.clone(); metrics.num_added_files = actions.len(); 
metrics.num_removed_files = candidates.candidates.len(); diff --git a/crates/core/src/operations/write.rs b/crates/core/src/operations/write.rs index 73c1599a7e..5947aafb0c 100644 --- a/crates/core/src/operations/write.rs +++ b/crates/core/src/operations/write.rs @@ -25,12 +25,13 @@ //! ```` use std::collections::HashMap; +use std::str::FromStr; use std::sync::Arc; use std::time::{SystemTime, UNIX_EPOCH}; use arrow_array::RecordBatch; use arrow_cast::can_cast_types; -use arrow_schema::{DataType, Fields, SchemaRef as ArrowSchemaRef}; +use arrow_schema::{DataType, Fields, SchemaRef as ArrowSchemaRef, Schema as ArrowSchema}; use datafusion::execution::context::{SessionContext, SessionState, TaskContext}; use datafusion::physical_expr::create_physical_expr; use datafusion::physical_plan::filter::FilterExec; @@ -50,7 +51,7 @@ use crate::delta_datafusion::expr::parse_predicate_expression; use crate::delta_datafusion::DeltaDataChecker; use crate::delta_datafusion::{find_files, register_store, DeltaScanBuilder}; use crate::errors::{DeltaResult, DeltaTableError}; -use crate::kernel::{Action, Add, PartitionsExt, Remove, StructType}; +use crate::kernel::{Action, Add, Metadata, PartitionsExt, Remove, StructType}; use crate::logstore::LogStoreRef; use crate::protocol::{DeltaOperation, SaveMode}; use crate::storage::ObjectStoreRef; @@ -87,6 +88,35 @@ impl From for DeltaTableError { } } +///Specifies how to handle schema drifts +#[derive(PartialEq)] +pub enum SchemaWriteMode { + /// Use existing schema and fail if it does not match the new schema + None, + /// Overwrite the schema with the new schema + Overwrite, + /// Append the new schema to the existing schema + Merge, +} + + +impl FromStr for SchemaWriteMode { + type Err = DeltaTableError; + + fn from_str(s: &str) -> DeltaResult { + match s.to_ascii_lowercase().as_str() { + "none" => Ok(SchemaWriteMode::None), + "overwrite" => Ok(SchemaWriteMode::Overwrite), + "merge" => Ok(SchemaWriteMode::Merge), + _ => Err(DeltaTableError::Generic(format!( + "Invalid schema write mode provided: {}, only these are supported: ['none', 'overwrite', 'merge']", + s + ))), + } + } +} + + /// Write data into a DeltaTable pub struct WriteBuilder { /// A snapshot of the to-be-loaded table's state @@ -109,8 +139,8 @@ pub struct WriteBuilder { write_batch_size: Option, /// RecordBatches to be written into the table batches: Option>, - /// whether to overwrite the schema - overwrite_schema: bool, + /// whether to overwrite the schema or to merge it + schema_write_mode: SchemaWriteMode, /// how to handle cast failures, either return NULL (safe=true) or return ERR (safe=false) safe_cast: bool, /// Parquet writer properties @@ -140,7 +170,7 @@ impl WriteBuilder { write_batch_size: None, batches: None, safe_cast: false, - overwrite_schema: false, + schema_write_mode: SchemaWriteMode::None, writer_properties: None, app_metadata: None, name: None, @@ -155,9 +185,9 @@ impl WriteBuilder { self } - /// Add overwrite_schema - pub fn with_overwrite_schema(mut self, overwrite_schema: bool) -> Self { - self.overwrite_schema = overwrite_schema; + /// Add Schema Write Mode + pub fn with_schema_write_mode(mut self, schema_write_mode: SchemaWriteMode) -> Self { + self.schema_write_mode = schema_write_mode; self } @@ -311,12 +341,34 @@ async fn write_execution_plan_with_predicate( write_batch_size: Option, writer_properties: Option, safe_cast: bool, - overwrite_schema: bool, -) -> DeltaResult> { + schema_write_mode: SchemaWriteMode, +) -> DeltaResult> { + let mut schema_action: Option = None; 
// Use input schema to prevent wrapping partitions columns into a dictionary. - let schema: ArrowSchemaRef = if overwrite_schema { + let schema: ArrowSchemaRef = if schema_write_mode == SchemaWriteMode::Overwrite { plan.schema() - } else { + } + else if schema_write_mode == SchemaWriteMode::Merge { + let original_schema = snapshot + .and_then(|s| s.input_schema().ok()) + .unwrap_or(plan.schema()); + if original_schema == plan.schema() { + original_schema + } + else { + let new_schema= Arc::new(arrow_schema::Schema::try_merge(vec![original_schema.as_ref().clone(), plan.schema().as_ref().clone()])?); + let schema_struct: StructType = new_schema.clone().try_into()?; + schema_action = Some(Action::Metadata(Metadata::try_new(schema_struct, match snapshot { + Some(sn) => sn.metadata().partition_columns.clone(), + None => vec![], + }, match snapshot { + Some(sn) => sn.metadata().configuration.clone(), + None => HashMap::new(), + })?)); + new_schema.into() + } + } + else { snapshot .and_then(|s| s.input_schema().ok()) .unwrap_or(plan.schema()) @@ -352,7 +404,7 @@ async fn write_execution_plan_with_predicate( let mut writer = DeltaWriter::new(object_store.clone(), config); let checker_stream = checker.clone(); let mut stream = inner_plan.execute(i, task_ctx)?; - let handle: tokio::task::JoinHandle>> = + let handle: tokio::task::JoinHandle>> = tokio::task::spawn(async move { while let Some(maybe_batch) = stream.next().await { let batch = maybe_batch?; @@ -361,14 +413,16 @@ async fn write_execution_plan_with_predicate( super::cast::cast_record_batch(&batch, inner_schema.clone(), safe_cast)?; writer.write(&arr).await?; } - writer.close().await + let add_actions = writer.close().await; + match add_actions { + Ok(actions) => Ok(actions.into_iter().map(Action::Add).collect::>()), + Err(err) => Err(err.into()), + } }); tasks.push(handle); } - - // Collect add actions to add to commit - Ok(futures::future::join_all(tasks) + let mut actions = futures::future::join_all(tasks) .await .into_iter() .collect::, _>>() @@ -377,9 +431,15 @@ async fn write_execution_plan_with_predicate( .collect::, _>>()? .concat() .into_iter() - .collect::>()) + .collect::>(); + if schema_action.is_some() { + actions.push(schema_action.unwrap()); + } + // Collect add actions to add to commit + Ok(actions) } + #[allow(clippy::too_many_arguments)] pub(crate) async fn write_execution_plan( snapshot: Option<&DeltaTableState>, @@ -391,8 +451,8 @@ pub(crate) async fn write_execution_plan( write_batch_size: Option, writer_properties: Option, safe_cast: bool, - overwrite_schema: bool, -) -> DeltaResult> { + schema_write_mode: SchemaWriteMode, +) -> DeltaResult> { write_execution_plan_with_predicate( None, snapshot, @@ -404,7 +464,7 @@ pub(crate) async fn write_execution_plan( write_batch_size, writer_properties, safe_cast, - overwrite_schema, + schema_write_mode, ) .await } @@ -417,7 +477,7 @@ async fn execute_non_empty_expr( expression: &Expr, rewrite: &[Add], writer_properties: Option, -) -> DeltaResult> { +) -> DeltaResult> { // For each identified file perform a parquet scan + filter + limit (1) + count. // If returned count is not zero then append the file to be rewritten and removed from the log. Otherwise do nothing to the file. 
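The hunks above introduce the SchemaWriteMode enum and thread it through write_execution_plan_with_predicate, replacing the old overwrite_schema boolean on WriteBuilder. Below is a minimal sketch of how a caller might opt into schema merging through the new builder option; the deltalake import paths, the table handle and the input batches are illustrative assumptions, not part of this patch:

    use arrow_array::RecordBatch;
    use deltalake::operations::write::SchemaWriteMode;
    use deltalake::protocol::SaveMode;
    use deltalake::{DeltaOps, DeltaResult};

    async fn append_with_schema_merge(
        ops: DeltaOps,
        batches: Vec<RecordBatch>,
    ) -> DeltaResult<()> {
        // The mode can also be parsed from a string ("none" | "overwrite" | "merge"),
        // which is what the Python bindings rely on further down in this series.
        let mode: SchemaWriteMode = "merge".parse()?;
        ops.write(batches)
            .with_save_mode(SaveMode::Append)
            .with_schema_write_mode(mode)
            .await?;
        Ok(())
    }

Later commits in this series rename the enum to SchemaMode and the builder method to with_schema_mode, so the names above only match the tree as of this first patch.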
@@ -452,7 +512,7 @@ async fn execute_non_empty_expr( None, writer_properties, false, - false, + SchemaWriteMode::None, ) .await?; @@ -488,7 +548,7 @@ async fn prepare_predicate_actions( }; let remove = candidates.candidates; - let mut actions: Vec = add.into_iter().map(Action::Add).collect(); + let mut actions: Vec = add.into_iter().collect(); for action in remove { actions.push(Action::Remove(Remove { @@ -563,7 +623,7 @@ impl std::future::IntoFuture for WriteBuilder { .unwrap_or(schema.clone()); if !can_cast_batch(schema.fields(), table_schema.fields()) - && !(this.overwrite_schema && matches!(this.mode, SaveMode::Overwrite)) + && (this.schema_write_mode == SchemaWriteMode::None && !matches!(this.mode, SaveMode::Overwrite)) { return Err(DeltaTableError::Generic( "Schema of data does not match table schema".to_string(), @@ -641,10 +701,10 @@ impl std::future::IntoFuture for WriteBuilder { this.write_batch_size, this.writer_properties.clone(), this.safe_cast, - this.overwrite_schema, + this.schema_write_mode, ) .await?; - actions.extend(add_actions.into_iter().map(Action::Add)); + actions.extend(add_actions); // Collect remove actions if we are overwriting the table if let Some(snapshot) = &this.snapshot { From 3ce861a5abaf05308b594b97c5c849bf3fd8ac7a Mon Sep 17 00:00:00 2001 From: Adrian Ehrsam Date: Wed, 28 Feb 2024 16:47:23 +0100 Subject: [PATCH 02/40] python compiles --- python/deltalake/_internal.pyi | 2 +- python/deltalake/writer.py | 20 +++++++++++++++----- python/src/lib.rs | 7 ++++--- 3 files changed, 20 insertions(+), 9 deletions(-) diff --git a/python/deltalake/_internal.pyi b/python/deltalake/_internal.pyi index e8994983f1..0489b397f3 100644 --- a/python/deltalake/_internal.pyi +++ b/python/deltalake/_internal.pyi @@ -174,7 +174,7 @@ def write_to_deltalake( partition_by: Optional[List[str]], mode: str, max_rows_per_group: int, - overwrite_schema: bool, + schema_write_mode: Optional[str], predicate: Optional[str], name: Optional[str], description: Optional[str], diff --git a/python/deltalake/writer.py b/python/deltalake/writer.py index df76ded806..761f1cbf11 100644 --- a/python/deltalake/writer.py +++ b/python/deltalake/writer.py @@ -49,7 +49,7 @@ convert_pyarrow_table, ) from .table import MAX_SUPPORTED_WRITER_VERSION, DeltaTable, WriterProperties - +import warnings try: import pandas as pd # noqa: F811 except ModuleNotFoundError: @@ -185,6 +185,7 @@ def write_deltalake( description: Optional[str] = None, configuration: Optional[Mapping[str, Optional[str]]] = None, overwrite_schema: bool = False, + schema_write_mode: Literal["none", "merge", "overwrite"] = "none", storage_options: Optional[Dict[str, str]] = None, partition_filters: Optional[List[Tuple[str, str, Any]]] = None, predicate: Optional[str] = None, @@ -238,7 +239,8 @@ def write_deltalake( name: User-provided identifier for this table. description: User-provided description for this table. configuration: A map containing configuration options for the metadata action. - overwrite_schema: If True, allows updating the schema of the table. + overwrite_schema: Deprecated, use schema_write_mode instead. + schema_write_mode: If set to "overwrite", allows replacing the schema of the table. Set to "merge" to merge with existing schema. storage_options: options passed to the native delta filesystem. predicate: When using `Overwrite` mode, replace data that matches a predicate. Only used in rust engine. partition_filters: the partition filters that will be used for partition overwrite. Only used in pyarrow engine. 
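The schema_write_mode argument documented above is the user-facing replacement for the overwrite_schema boolean. Below is a small, hypothetical usage sketch; the table path and DataFrames are placeholders, and at this point in the series only the rust engine honours "merge" (the pyarrow path further down still only checks for "overwrite"):

    import pandas as pd
    from deltalake import write_deltalake

    # First write fixes the initial table schema.
    write_deltalake("./tmp/events", pd.DataFrame({"id": [1, 2]}), engine="rust")

    # A later append can carry an extra column; "merge" evolves the table
    # schema instead of raising a schema-mismatch error.
    write_deltalake(
        "./tmp/events",
        pd.DataFrame({"id": [3], "source": ["api"]}),
        mode="append",
        schema_write_mode="merge",
        engine="rust",
    )

A later commit in this series renames this keyword to schema_mode, so the call above reflects only the intermediate state introduced by this patch.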
@@ -256,7 +258,15 @@ def write_deltalake( table.update_incremental() __enforce_append_only(table=table, configuration=configuration, mode=mode) - + if overwrite_schema: + assert schema_write_mode in ["none", "overwrite"] # none is default, overwrite would at least match + schema_write_mode = "overwrite" + + warnings.warn( + "overwrite_schema is deprecated, use schema_write_mode instead. ", + category=DeprecationWarning, + stacklevel=2, + ) if isinstance(partition_by, str): partition_by = [partition_by] @@ -302,7 +312,7 @@ def write_deltalake( partition_by=partition_by, mode=mode, max_rows_per_group=max_rows_per_group, - overwrite_schema=overwrite_schema, + schema_write_mode=schema_write_mode, predicate=predicate, name=name, description=description, @@ -327,7 +337,7 @@ def sort_arrow_schema(schema: pa.schema) -> pa.schema: if table: # already exists if sort_arrow_schema(schema) != sort_arrow_schema( table.schema().to_pyarrow(as_large_types=large_dtypes) - ) and not (mode == "overwrite" and overwrite_schema): + ) and not (mode == "overwrite" and schema_write_mode == "overwrite"): raise ValueError( "Schema of data does not match table schema\n" f"Data schema:\n{schema}\nTable Schema:\n{table.schema().to_pyarrow(as_large_types=large_dtypes)}" diff --git a/python/src/lib.rs b/python/src/lib.rs index 1992bae642..3b43e06876 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -1367,7 +1367,7 @@ fn write_to_deltalake( data: PyArrowType, mode: String, max_rows_per_group: i64, - overwrite_schema: bool, + schema_write_mode: Option, partition_by: Option>, predicate: Option, name: Option, @@ -1390,9 +1390,10 @@ fn write_to_deltalake( let mut builder = table .write(batches) .with_save_mode(save_mode) - .with_overwrite_schema(overwrite_schema) .with_write_batch_size(max_rows_per_group as usize); - + if let Some(schema_write_mode) = schema_write_mode { + builder = builder.with_schema_write_mode(schema_write_mode.parse().map_err(PythonError::from)?); + } if let Some(partition_columns) = partition_by { builder = builder.with_partition_columns(partition_columns); } From 36d3463380a9625e0255c6a69604207df059210d Mon Sep 17 00:00:00 2001 From: Adrian Ehrsam Date: Wed, 28 Feb 2024 16:55:47 +0100 Subject: [PATCH 03/40] fmt & clippy --- crates/core/src/operations/delete.rs | 7 ++-- crates/core/src/operations/update.rs | 2 +- crates/core/src/operations/write.rs | 48 +++++++++++++++------------- python/src/lib.rs | 3 +- 4 files changed, 32 insertions(+), 28 deletions(-) diff --git a/crates/core/src/operations/delete.rs b/crates/core/src/operations/delete.rs index 7544a0159c..f924da8266 100644 --- a/crates/core/src/operations/delete.rs +++ b/crates/core/src/operations/delete.rs @@ -172,11 +172,12 @@ async fn excute_non_empty_expr( SchemaWriteMode::None, ) .await? 
- .into_iter().map(|a| match a { + .into_iter() + .map(|a| match a { Action::Add(a) => a, _ => panic!("Expected Add action"), - - }).collect::>(); + }) + .collect::>(); let read_records = scan.parquet_scan.metrics().and_then(|m| m.output_rows()); let filter_records = filter.metrics().and_then(|m| m.output_rows()); diff --git a/crates/core/src/operations/update.rs b/crates/core/src/operations/update.rs index 957ca7e1b1..50a42dca13 100644 --- a/crates/core/src/operations/update.rs +++ b/crates/core/src/operations/update.rs @@ -43,9 +43,9 @@ use parquet::file::properties::WriterProperties; use serde::Serialize; use serde_json::Value; -use super::{datafusion_utils::Expression, write::SchemaWriteMode}; use super::transaction::{commit, PROTOCOL}; use super::write::write_execution_plan; +use super::{datafusion_utils::Expression, write::SchemaWriteMode}; use crate::delta_datafusion::{ expr::fmt_expr_to_sql, physical::MetricObserverExec, DeltaColumn, DeltaSessionContext, }; diff --git a/crates/core/src/operations/write.rs b/crates/core/src/operations/write.rs index 5947aafb0c..85046acfb4 100644 --- a/crates/core/src/operations/write.rs +++ b/crates/core/src/operations/write.rs @@ -31,7 +31,7 @@ use std::time::{SystemTime, UNIX_EPOCH}; use arrow_array::RecordBatch; use arrow_cast::can_cast_types; -use arrow_schema::{DataType, Fields, SchemaRef as ArrowSchemaRef, Schema as ArrowSchema}; +use arrow_schema::{DataType, Fields, SchemaRef as ArrowSchemaRef}; use datafusion::execution::context::{SessionContext, SessionState, TaskContext}; use datafusion::physical_expr::create_physical_expr; use datafusion::physical_plan::filter::FilterExec; @@ -99,7 +99,6 @@ pub enum SchemaWriteMode { Merge, } - impl FromStr for SchemaWriteMode { type Err = DeltaTableError; @@ -116,7 +115,6 @@ impl FromStr for SchemaWriteMode { } } - /// Write data into a DeltaTable pub struct WriteBuilder { /// A snapshot of the to-be-loaded table's state @@ -347,28 +345,32 @@ async fn write_execution_plan_with_predicate( // Use input schema to prevent wrapping partitions columns into a dictionary. 
let schema: ArrowSchemaRef = if schema_write_mode == SchemaWriteMode::Overwrite { plan.schema() - } - else if schema_write_mode == SchemaWriteMode::Merge { + } else if schema_write_mode == SchemaWriteMode::Merge { let original_schema = snapshot .and_then(|s| s.input_schema().ok()) .unwrap_or(plan.schema()); if original_schema == plan.schema() { original_schema - } - else { - let new_schema= Arc::new(arrow_schema::Schema::try_merge(vec![original_schema.as_ref().clone(), plan.schema().as_ref().clone()])?); + } else { + let new_schema = Arc::new(arrow_schema::Schema::try_merge(vec![ + original_schema.as_ref().clone(), + plan.schema().as_ref().clone(), + ])?); let schema_struct: StructType = new_schema.clone().try_into()?; - schema_action = Some(Action::Metadata(Metadata::try_new(schema_struct, match snapshot { - Some(sn) => sn.metadata().partition_columns.clone(), - None => vec![], - }, match snapshot { - Some(sn) => sn.metadata().configuration.clone(), - None => HashMap::new(), - })?)); - new_schema.into() + schema_action = Some(Action::Metadata(Metadata::try_new( + schema_struct, + match snapshot { + Some(sn) => sn.metadata().partition_columns.clone(), + None => vec![], + }, + match snapshot { + Some(sn) => sn.metadata().configuration.clone(), + None => HashMap::new(), + }, + )?)); + new_schema } - } - else { + } else { snapshot .and_then(|s| s.input_schema().ok()) .unwrap_or(plan.schema()) @@ -416,7 +418,7 @@ async fn write_execution_plan_with_predicate( let add_actions = writer.close().await; match add_actions { Ok(actions) => Ok(actions.into_iter().map(Action::Add).collect::>()), - Err(err) => Err(err.into()), + Err(err) => Err(err), } }); @@ -432,14 +434,13 @@ async fn write_execution_plan_with_predicate( .concat() .into_iter() .collect::>(); - if schema_action.is_some() { - actions.push(schema_action.unwrap()); + if let Some(schema_action) = schema_action { + actions.push(schema_action); } // Collect add actions to add to commit Ok(actions) } - #[allow(clippy::too_many_arguments)] pub(crate) async fn write_execution_plan( snapshot: Option<&DeltaTableState>, @@ -623,7 +624,8 @@ impl std::future::IntoFuture for WriteBuilder { .unwrap_or(schema.clone()); if !can_cast_batch(schema.fields(), table_schema.fields()) - && (this.schema_write_mode == SchemaWriteMode::None && !matches!(this.mode, SaveMode::Overwrite)) + && (this.schema_write_mode == SchemaWriteMode::None + && !matches!(this.mode, SaveMode::Overwrite)) { return Err(DeltaTableError::Generic( "Schema of data does not match table schema".to_string(), diff --git a/python/src/lib.rs b/python/src/lib.rs index 3b43e06876..4e171d70ff 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -1392,7 +1392,8 @@ fn write_to_deltalake( .with_save_mode(save_mode) .with_write_batch_size(max_rows_per_group as usize); if let Some(schema_write_mode) = schema_write_mode { - builder = builder.with_schema_write_mode(schema_write_mode.parse().map_err(PythonError::from)?); + builder = + builder.with_schema_write_mode(schema_write_mode.parse().map_err(PythonError::from)?); } if let Some(partition_columns) = partition_by { builder = builder.with_partition_columns(partition_columns); From 3c21940d5b869ddfda4d167469bf57bf16b89506 Mon Sep 17 00:00:00 2001 From: Adrian Ehrsam Date: Thu, 29 Feb 2024 13:59:39 +0100 Subject: [PATCH 04/40] renamings --- crates/core/src/operations/delete.rs | 4 +-- crates/core/src/operations/merge/mod.rs | 4 +-- crates/core/src/operations/update.rs | 4 +-- crates/core/src/operations/write.rs | 37 ++++++++++++------------- 
python/deltalake/_internal.pyi | 2 +- python/deltalake/writer.py | 18 ++++++------ python/src/lib.rs | 7 ++--- 7 files changed, 37 insertions(+), 39 deletions(-) diff --git a/crates/core/src/operations/delete.rs b/crates/core/src/operations/delete.rs index f924da8266..5cf65a848a 100644 --- a/crates/core/src/operations/delete.rs +++ b/crates/core/src/operations/delete.rs @@ -37,7 +37,7 @@ use serde_json::Value; use super::datafusion_utils::Expression; use super::transaction::PROTOCOL; -use super::write::SchemaWriteMode; +use super::write::SchemaMode; use crate::delta_datafusion::expr::fmt_expr_to_sql; use crate::delta_datafusion::{find_files, register_store, DeltaScanBuilder, DeltaSessionContext}; use crate::errors::DeltaResult; @@ -169,7 +169,7 @@ async fn excute_non_empty_expr( None, writer_properties, false, - SchemaWriteMode::None, + None, ) .await? .into_iter() diff --git a/crates/core/src/operations/merge/mod.rs b/crates/core/src/operations/merge/mod.rs index 6339ed4b8e..7bec9669d2 100644 --- a/crates/core/src/operations/merge/mod.rs +++ b/crates/core/src/operations/merge/mod.rs @@ -74,7 +74,7 @@ use crate::delta_datafusion::{ use crate::kernel::Action; use crate::logstore::LogStoreRef; use crate::operations::merge::barrier::find_barrier_node; -use crate::operations::write::{write_execution_plan, SchemaWriteMode}; +use crate::operations::write::{write_execution_plan, SchemaMode}; use crate::protocol::{DeltaOperation, MergePredicate}; use crate::table::state::DeltaTableState; use crate::{DeltaResult, DeltaTable, DeltaTableError}; @@ -1379,7 +1379,7 @@ async fn execute( None, writer_properties, safe_cast, - SchemaWriteMode::None, + None, ) .await?; diff --git a/crates/core/src/operations/update.rs b/crates/core/src/operations/update.rs index 50a42dca13..3bf3f90206 100644 --- a/crates/core/src/operations/update.rs +++ b/crates/core/src/operations/update.rs @@ -45,7 +45,7 @@ use serde_json::Value; use super::transaction::{commit, PROTOCOL}; use super::write::write_execution_plan; -use super::{datafusion_utils::Expression, write::SchemaWriteMode}; +use super::{datafusion_utils::Expression, write::SchemaMode}; use crate::delta_datafusion::{ expr::fmt_expr_to_sql, physical::MetricObserverExec, DeltaColumn, DeltaSessionContext, }; @@ -357,7 +357,7 @@ async fn execute( None, writer_properties, safe_cast, - SchemaWriteMode::None, + None, ) .await?; diff --git a/crates/core/src/operations/write.rs b/crates/core/src/operations/write.rs index 85046acfb4..0a8c132052 100644 --- a/crates/core/src/operations/write.rs +++ b/crates/core/src/operations/write.rs @@ -90,23 +90,20 @@ impl From for DeltaTableError { ///Specifies how to handle schema drifts #[derive(PartialEq)] -pub enum SchemaWriteMode { - /// Use existing schema and fail if it does not match the new schema - None, +pub enum SchemaMode { /// Overwrite the schema with the new schema Overwrite, /// Append the new schema to the existing schema Merge, } -impl FromStr for SchemaWriteMode { +impl FromStr for SchemaMode { type Err = DeltaTableError; fn from_str(s: &str) -> DeltaResult { match s.to_ascii_lowercase().as_str() { - "none" => Ok(SchemaWriteMode::None), - "overwrite" => Ok(SchemaWriteMode::Overwrite), - "merge" => Ok(SchemaWriteMode::Merge), + "overwrite" => Ok(SchemaMode::Overwrite), + "merge" => Ok(SchemaMode::Merge), _ => Err(DeltaTableError::Generic(format!( "Invalid schema write mode provided: {}, only these are supported: ['none', 'overwrite', 'merge']", s @@ -137,8 +134,8 @@ pub struct WriteBuilder { write_batch_size: Option, /// 
RecordBatches to be written into the table batches: Option>, - /// whether to overwrite the schema or to merge it - schema_write_mode: SchemaWriteMode, + /// whether to overwrite the schema or to merge it. None means to fail on schmema drift + schema_mode: Option, /// how to handle cast failures, either return NULL (safe=true) or return ERR (safe=false) safe_cast: bool, /// Parquet writer properties @@ -168,7 +165,7 @@ impl WriteBuilder { write_batch_size: None, batches: None, safe_cast: false, - schema_write_mode: SchemaWriteMode::None, + schema_mode: None, writer_properties: None, app_metadata: None, name: None, @@ -184,8 +181,8 @@ impl WriteBuilder { } /// Add Schema Write Mode - pub fn with_schema_write_mode(mut self, schema_write_mode: SchemaWriteMode) -> Self { - self.schema_write_mode = schema_write_mode; + pub fn with_schema_mode(mut self, schema_mode: SchemaMode) -> Self { + self.schema_mode = Some(schema_mode); self } @@ -339,13 +336,13 @@ async fn write_execution_plan_with_predicate( write_batch_size: Option, writer_properties: Option, safe_cast: bool, - schema_write_mode: SchemaWriteMode, + schema_mode: Option, ) -> DeltaResult> { let mut schema_action: Option = None; // Use input schema to prevent wrapping partitions columns into a dictionary. - let schema: ArrowSchemaRef = if schema_write_mode == SchemaWriteMode::Overwrite { + let schema: ArrowSchemaRef = if schema_mode == Some(SchemaMode::Overwrite) { plan.schema() - } else if schema_write_mode == SchemaWriteMode::Merge { + } else if schema_mode == Some(SchemaMode::Merge) { let original_schema = snapshot .and_then(|s| s.input_schema().ok()) .unwrap_or(plan.schema()); @@ -452,7 +449,7 @@ pub(crate) async fn write_execution_plan( write_batch_size: Option, writer_properties: Option, safe_cast: bool, - schema_write_mode: SchemaWriteMode, + schema_mode: Option, ) -> DeltaResult> { write_execution_plan_with_predicate( None, @@ -465,7 +462,7 @@ pub(crate) async fn write_execution_plan( write_batch_size, writer_properties, safe_cast, - schema_write_mode, + schema_mode, ) .await } @@ -513,7 +510,7 @@ async fn execute_non_empty_expr( None, writer_properties, false, - SchemaWriteMode::None, + None, ) .await?; @@ -624,7 +621,7 @@ impl std::future::IntoFuture for WriteBuilder { .unwrap_or(schema.clone()); if !can_cast_batch(schema.fields(), table_schema.fields()) - && (this.schema_write_mode == SchemaWriteMode::None + && (this.schema_mode == None && !matches!(this.mode, SaveMode::Overwrite)) { return Err(DeltaTableError::Generic( @@ -703,7 +700,7 @@ impl std::future::IntoFuture for WriteBuilder { this.write_batch_size, this.writer_properties.clone(), this.safe_cast, - this.schema_write_mode, + this.schema_mode, ) .await?; actions.extend(add_actions); diff --git a/python/deltalake/_internal.pyi b/python/deltalake/_internal.pyi index 0489b397f3..4300be52de 100644 --- a/python/deltalake/_internal.pyi +++ b/python/deltalake/_internal.pyi @@ -174,7 +174,7 @@ def write_to_deltalake( partition_by: Optional[List[str]], mode: str, max_rows_per_group: int, - schema_write_mode: Optional[str], + schema_mode: Optional[str], predicate: Optional[str], name: Optional[str], description: Optional[str], diff --git a/python/deltalake/writer.py b/python/deltalake/writer.py index 761f1cbf11..563d44a2f7 100644 --- a/python/deltalake/writer.py +++ b/python/deltalake/writer.py @@ -95,6 +95,7 @@ def write_deltalake( description: Optional[str] = ..., configuration: Optional[Mapping[str, Optional[str]]] = ..., overwrite_schema: bool = ..., + schema_mode: 
Optional[Literal["overwrite"]] = ..., storage_options: Optional[Dict[str, str]] = ..., partition_filters: Optional[List[Tuple[str, str, Any]]] = ..., large_dtypes: bool = ..., @@ -123,6 +124,7 @@ def write_deltalake( description: Optional[str] = ..., configuration: Optional[Mapping[str, Optional[str]]] = ..., overwrite_schema: bool = ..., + schema_mode: Optional[Literal[ "merge", "overwrite"]] = ..., storage_options: Optional[Dict[str, str]] = ..., large_dtypes: bool = ..., engine: Literal["rust"], @@ -151,6 +153,7 @@ def write_deltalake( description: Optional[str] = ..., configuration: Optional[Mapping[str, Optional[str]]] = ..., overwrite_schema: bool = ..., + schema_mode: Optional[Literal[ "merge", "overwrite"]] = ..., storage_options: Optional[Dict[str, str]] = ..., predicate: Optional[str] = ..., large_dtypes: bool = ..., @@ -185,7 +188,7 @@ def write_deltalake( description: Optional[str] = None, configuration: Optional[Mapping[str, Optional[str]]] = None, overwrite_schema: bool = False, - schema_write_mode: Literal["none", "merge", "overwrite"] = "none", + schema_mode: Optional[Literal[ "merge", "overwrite"]] = None, storage_options: Optional[Dict[str, str]] = None, partition_filters: Optional[List[Tuple[str, str, Any]]] = None, predicate: Optional[str] = None, @@ -239,8 +242,8 @@ def write_deltalake( name: User-provided identifier for this table. description: User-provided description for this table. configuration: A map containing configuration options for the metadata action. - overwrite_schema: Deprecated, use schema_write_mode instead. - schema_write_mode: If set to "overwrite", allows replacing the schema of the table. Set to "merge" to merge with existing schema. + overwrite_schema: Deprecated, use schema_mode instead. + schema_mode: If set to "overwrite", allows replacing the schema of the table. Set to "merge" to merge with existing schema. storage_options: options passed to the native delta filesystem. predicate: When using `Overwrite` mode, replace data that matches a predicate. Only used in rust engine. partition_filters: the partition filters that will be used for partition overwrite. Only used in pyarrow engine. @@ -259,11 +262,10 @@ def write_deltalake( __enforce_append_only(table=table, configuration=configuration, mode=mode) if overwrite_schema: - assert schema_write_mode in ["none", "overwrite"] # none is default, overwrite would at least match - schema_write_mode = "overwrite" + schema_mode = "overwrite" warnings.warn( - "overwrite_schema is deprecated, use schema_write_mode instead. ", + "overwrite_schema is deprecated, use schema_mode instead. 
", category=DeprecationWarning, stacklevel=2, ) @@ -312,7 +314,7 @@ def write_deltalake( partition_by=partition_by, mode=mode, max_rows_per_group=max_rows_per_group, - schema_write_mode=schema_write_mode, + schema_mode=schema_mode, predicate=predicate, name=name, description=description, @@ -337,7 +339,7 @@ def sort_arrow_schema(schema: pa.schema) -> pa.schema: if table: # already exists if sort_arrow_schema(schema) != sort_arrow_schema( table.schema().to_pyarrow(as_large_types=large_dtypes) - ) and not (mode == "overwrite" and schema_write_mode == "overwrite"): + ) and not (mode == "overwrite" and schema_mode == "overwrite"): raise ValueError( "Schema of data does not match table schema\n" f"Data schema:\n{schema}\nTable Schema:\n{table.schema().to_pyarrow(as_large_types=large_dtypes)}" diff --git a/python/src/lib.rs b/python/src/lib.rs index 4e171d70ff..65d0ba5944 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -1367,7 +1367,7 @@ fn write_to_deltalake( data: PyArrowType, mode: String, max_rows_per_group: i64, - schema_write_mode: Option, + schema_mode: Option, partition_by: Option>, predicate: Option, name: Option, @@ -1391,9 +1391,8 @@ fn write_to_deltalake( .write(batches) .with_save_mode(save_mode) .with_write_batch_size(max_rows_per_group as usize); - if let Some(schema_write_mode) = schema_write_mode { - builder = - builder.with_schema_write_mode(schema_write_mode.parse().map_err(PythonError::from)?); + if let Some(schema_mode) = schema_mode { + builder = builder.with_schema_mode(schema_mode.parse().map_err(PythonError::from)?); } if let Some(partition_columns) = partition_by { builder = builder.with_partition_columns(partition_columns); From 3c9ff118a9e32012607bf6e6a3ac1391e55f0a89 Mon Sep 17 00:00:00 2001 From: Adrian Ehrsam Date: Thu, 29 Feb 2024 14:03:59 +0100 Subject: [PATCH 05/40] clippy's feedback --- crates/core/src/operations/delete.rs | 1 - crates/core/src/operations/merge/mod.rs | 2 +- crates/core/src/operations/update.rs | 2 +- crates/core/src/operations/write.rs | 2 +- 4 files changed, 3 insertions(+), 4 deletions(-) diff --git a/crates/core/src/operations/delete.rs b/crates/core/src/operations/delete.rs index 5cf65a848a..1ab55310ea 100644 --- a/crates/core/src/operations/delete.rs +++ b/crates/core/src/operations/delete.rs @@ -37,7 +37,6 @@ use serde_json::Value; use super::datafusion_utils::Expression; use super::transaction::PROTOCOL; -use super::write::SchemaMode; use crate::delta_datafusion::expr::fmt_expr_to_sql; use crate::delta_datafusion::{find_files, register_store, DeltaScanBuilder, DeltaSessionContext}; use crate::errors::DeltaResult; diff --git a/crates/core/src/operations/merge/mod.rs b/crates/core/src/operations/merge/mod.rs index 7bec9669d2..e12f1e3dbb 100644 --- a/crates/core/src/operations/merge/mod.rs +++ b/crates/core/src/operations/merge/mod.rs @@ -74,7 +74,7 @@ use crate::delta_datafusion::{ use crate::kernel::Action; use crate::logstore::LogStoreRef; use crate::operations::merge::barrier::find_barrier_node; -use crate::operations::write::{write_execution_plan, SchemaMode}; +use crate::operations::write::write_execution_plan; use crate::protocol::{DeltaOperation, MergePredicate}; use crate::table::state::DeltaTableState; use crate::{DeltaResult, DeltaTable, DeltaTableError}; diff --git a/crates/core/src/operations/update.rs b/crates/core/src/operations/update.rs index 3bf3f90206..8663aeeac4 100644 --- a/crates/core/src/operations/update.rs +++ b/crates/core/src/operations/update.rs @@ -45,7 +45,7 @@ use serde_json::Value; use 
super::transaction::{commit, PROTOCOL}; use super::write::write_execution_plan; -use super::{datafusion_utils::Expression, write::SchemaMode}; +use super::datafusion_utils::Expression; use crate::delta_datafusion::{ expr::fmt_expr_to_sql, physical::MetricObserverExec, DeltaColumn, DeltaSessionContext, }; diff --git a/crates/core/src/operations/write.rs b/crates/core/src/operations/write.rs index 0a8c132052..e18b2e4eaa 100644 --- a/crates/core/src/operations/write.rs +++ b/crates/core/src/operations/write.rs @@ -621,7 +621,7 @@ impl std::future::IntoFuture for WriteBuilder { .unwrap_or(schema.clone()); if !can_cast_batch(schema.fields(), table_schema.fields()) - && (this.schema_mode == None + && (this.schema_mode.is_none() && !matches!(this.mode, SaveMode::Overwrite)) { return Err(DeltaTableError::Generic( From 4d99d99acbbc6b88f3ea10bc810c086a336c03ab Mon Sep 17 00:00:00 2001 From: Adrian Ehrsam Date: Fri, 1 Mar 2024 13:46:46 +0100 Subject: [PATCH 06/40] test seems ok --- crates/core/src/operations/cast.rs | 53 +++++++++++++-------- crates/core/src/operations/optimize.rs | 2 +- crates/core/src/operations/write.rs | 65 +++++++++++++++++++++++++- 3 files changed, 98 insertions(+), 22 deletions(-) diff --git a/crates/core/src/operations/cast.rs b/crates/core/src/operations/cast.rs index 6e77552286..e87eea1590 100644 --- a/crates/core/src/operations/cast.rs +++ b/crates/core/src/operations/cast.rs @@ -1,6 +1,6 @@ //! Provide common cast functionality for callers //! -use arrow_array::{Array, ArrayRef, RecordBatch, StructArray}; +use arrow_array::{new_null_array, Array, ArrayRef, RecordBatch, StructArray}; use arrow_cast::{cast_with_options, CastOptions}; use arrow_schema::{DataType, Fields, SchemaRef as ArrowSchemaRef}; @@ -12,25 +12,40 @@ fn cast_struct( struct_array: &StructArray, fields: &Fields, cast_options: &CastOptions, + add_missing: bool, ) -> Result>, arrow_schema::ArrowError> { fields .iter() .map(|field| { - let col = struct_array.column_by_name(field.name()).unwrap(); - if let (DataType::Struct(_), DataType::Struct(child_fields)) = - (col.data_type(), field.data_type()) - { - let child_struct = StructArray::from(col.into_data()); - let s = cast_struct(&child_struct, child_fields, cast_options)?; - Ok(Arc::new(StructArray::new( - child_fields.clone(), - s, - child_struct.nulls().map(ToOwned::to_owned), - )) as ArrayRef) - } else if is_cast_required(col.data_type(), field.data_type()) { - cast_with_options(col, field.data_type(), cast_options) - } else { - Ok(col.clone()) + let col_or_not = struct_array.column_by_name(field.name()); + if col_or_not.is_none() { + if !add_missing { + return Err(arrow_schema::ArrowError::SchemaError(format!( + "Could not find column {0}", + field.name() + ))); + } + } + match col_or_not { + Some(col) => { + if let (DataType::Struct(_), DataType::Struct(child_fields)) = + (col.data_type(), field.data_type()) + { + let child_struct = StructArray::from(col.into_data()); + let s = + cast_struct(&child_struct, child_fields, cast_options, add_missing)?; + Ok(Arc::new(StructArray::new( + child_fields.clone(), + s, + child_struct.nulls().map(ToOwned::to_owned), + )) as ArrayRef) + } else if is_cast_required(col.data_type(), field.data_type()) { + cast_with_options(col, field.data_type(), cast_options) + } else { + Ok(col.clone()) + } + } + None => Ok(new_null_array(field.data_type(), struct_array.len())), } }) .collect::, _>>() @@ -51,6 +66,7 @@ pub fn cast_record_batch( batch: &RecordBatch, target_schema: ArrowSchemaRef, safe: bool, + add_missing: bool, ) -> 
DeltaResult { let cast_options = CastOptions { safe, @@ -62,8 +78,7 @@ pub fn cast_record_batch( batch.columns().to_owned(), None, ); - - let columns = cast_struct(&s, target_schema.fields(), &cast_options)?; + let columns = cast_struct(&s, target_schema.fields(), &cast_options, add_missing)?; Ok(RecordBatch::try_new(target_schema, columns)?) } @@ -93,7 +108,7 @@ mod tests { )]); let target_schema = Arc::new(Schema::new(fields)) as SchemaRef; - let result = cast_record_batch(&record_batch, target_schema, false); + let result = cast_record_batch(&record_batch, target_schema, false, false); let schema = result.unwrap().schema(); let field = schema.column_with_name("list_column").unwrap().1; diff --git a/crates/core/src/operations/optimize.rs b/crates/core/src/operations/optimize.rs index 990997399e..02c4ada546 100644 --- a/crates/core/src/operations/optimize.rs +++ b/crates/core/src/operations/optimize.rs @@ -458,7 +458,7 @@ impl MergePlan { let mut batch = maybe_batch?; batch = - super::cast::cast_record_batch(&batch, task_parameters.file_schema.clone(), false)?; + super::cast::cast_record_batch(&batch, task_parameters.file_schema.clone(), false, false)?; partial_metrics.num_batches += 1; writer.write(&batch).await.map_err(DeltaTableError::from)?; } diff --git a/crates/core/src/operations/write.rs b/crates/core/src/operations/write.rs index e18b2e4eaa..8107827718 100644 --- a/crates/core/src/operations/write.rs +++ b/crates/core/src/operations/write.rs @@ -89,7 +89,7 @@ impl From for DeltaTableError { } ///Specifies how to handle schema drifts -#[derive(PartialEq)] +#[derive(PartialEq, Clone, Copy)] pub enum SchemaMode { /// Overwrite the schema with the new schema Overwrite, @@ -409,7 +409,7 @@ async fn write_execution_plan_with_predicate( let batch = maybe_batch?; checker_stream.check_batch(&batch).await?; let arr = - super::cast::cast_record_batch(&batch, inner_schema.clone(), safe_cast)?; + super::cast::cast_record_batch(&batch, inner_schema.clone(), safe_cast, schema_mode == Some(SchemaMode::Merge))?; writer.write(&arr).await?; } let add_actions = writer.close().await; @@ -1056,6 +1056,67 @@ mod tests { assert_eq!(table.get_files_count(), 4) } + #[tokio::test] + async fn test_merge_schema() { + let batch = get_record_batch(None, false); + let table = DeltaOps::new_in_memory() + .write(vec![batch.clone()]) + .with_save_mode(SaveMode::ErrorIfExists) + .await + .unwrap(); + assert_eq!(table.version(), 0); + + let mut new_schema_builder = arrow_schema::SchemaBuilder::new(); + for field in batch.schema().fields() { + if field.name() != "modified" { + new_schema_builder.push(field.clone()); + } + } + new_schema_builder.push(Field::new( + "inserted_by", + DataType::Utf8, + true, + )); + let new_schema = new_schema_builder.finish(); + let new_fields = new_schema.fields(); + let new_names = new_fields.iter().map(|f| f.name()).collect::>(); + assert_eq!(new_names, vec!["id", "value", "inserted_by"]); + let inserted_by = StringArray::from(vec![ + Some("A1"), + Some("B1"), + None, + Some("B2"), + Some("A3"), + Some("A4"), + None, + None, + Some("B4"), + Some("A5"), + Some("A7"), + ]); + let new_batch = RecordBatch::try_new( + Arc::new(new_schema), + vec![ + Arc::new(batch.column_by_name("id").unwrap().clone()), + Arc::new(batch.column_by_name("value").unwrap().clone()), + Arc::new(inserted_by), + ], + ).unwrap(); + + let table = DeltaOps(table) + .write(vec![new_batch]) + .with_save_mode(SaveMode::Append) + .with_schema_mode(SchemaMode::Merge) + .await + .unwrap(); + + assert_eq!(table.version(), 
1); + let new_schema = table.metadata().unwrap().schema().unwrap(); + let fields = new_schema.fields(); + let names = fields.iter().map(|f| f.name()).collect::>(); + assert_eq!(names, vec!["id", "value", "modified", "inserted_by"]); + } + #[tokio::test] async fn test_check_invariants() { let batch = get_record_batch(None, false); From d1590207c215c70d5b0b3be67287938ff4e8de7c Mon Sep 17 00:00:00 2001 From: Adrian Ehrsam Date: Fri, 1 Mar 2024 13:47:04 +0100 Subject: [PATCH 07/40] fmt --- crates/core/src/operations/optimize.rs | 8 ++++++-- crates/core/src/operations/update.rs | 2 +- crates/core/src/operations/write.rs | 19 ++++++++++--------- 3 files changed, 17 insertions(+), 12 deletions(-) diff --git a/crates/core/src/operations/optimize.rs b/crates/core/src/operations/optimize.rs index 02c4ada546..90334e6de1 100644 --- a/crates/core/src/operations/optimize.rs +++ b/crates/core/src/operations/optimize.rs @@ -457,8 +457,12 @@ impl MergePlan { while let Some(maybe_batch) = read_stream.next().await { let mut batch = maybe_batch?; - batch = - super::cast::cast_record_batch(&batch, task_parameters.file_schema.clone(), false, false)?; + batch = super::cast::cast_record_batch( + &batch, + task_parameters.file_schema.clone(), + false, + false, + )?; partial_metrics.num_batches += 1; writer.write(&batch).await.map_err(DeltaTableError::from)?; } diff --git a/crates/core/src/operations/update.rs b/crates/core/src/operations/update.rs index 8663aeeac4..803b1d0312 100644 --- a/crates/core/src/operations/update.rs +++ b/crates/core/src/operations/update.rs @@ -43,9 +43,9 @@ use parquet::file::properties::WriterProperties; use serde::Serialize; use serde_json::Value; +use super::datafusion_utils::Expression; use super::transaction::{commit, PROTOCOL}; use super::write::write_execution_plan; -use super::datafusion_utils::Expression; use crate::delta_datafusion::{ expr::fmt_expr_to_sql, physical::MetricObserverExec, DeltaColumn, DeltaSessionContext, }; diff --git a/crates/core/src/operations/write.rs b/crates/core/src/operations/write.rs index 8107827718..a7b0140b84 100644 --- a/crates/core/src/operations/write.rs +++ b/crates/core/src/operations/write.rs @@ -408,8 +408,12 @@ async fn write_execution_plan_with_predicate( while let Some(maybe_batch) = stream.next().await { let batch = maybe_batch?; checker_stream.check_batch(&batch).await?; - let arr = - super::cast::cast_record_batch(&batch, inner_schema.clone(), safe_cast, schema_mode == Some(SchemaMode::Merge))?; + let arr = super::cast::cast_record_batch( + &batch, + inner_schema.clone(), + safe_cast, + schema_mode == Some(SchemaMode::Merge), + )?; writer.write(&arr).await?; } let add_actions = writer.close().await; @@ -1072,11 +1076,7 @@ mod tests { new_schema_builder.push(field.clone()); } } - new_schema_builder.push(Field::new( - "inserted_by", - DataType::Utf8, - true, - )); + new_schema_builder.push(Field::new("inserted_by", DataType::Utf8, true)); let new_schema = new_schema_builder.finish(); let new_fields = new_schema.fields(); let new_names = new_fields.iter().map(|f| f.name()).collect::>(); @@ -1101,8 +1101,9 @@ mod tests { Arc::new(batch.column_by_name("value").unwrap().clone()), Arc::new(inserted_by), ], - ).unwrap(); - + ) + .unwrap(); + let table = DeltaOps(table) .write(vec![new_batch]) .with_save_mode(SaveMode::Append) From fd457d8482faf010dec05498f03f57422ba0cb40 Mon Sep 17 00:00:00 2001 From: Adrian Ehrsam Date: Fri, 1 Mar 2024 14:28:39 +0100 Subject: [PATCH 08/40] clippy feedback --- crates/core/src/operations/cast.rs | 19 
+++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/crates/core/src/operations/cast.rs b/crates/core/src/operations/cast.rs index e87eea1590..e4c6679718 100644 --- a/crates/core/src/operations/cast.rs +++ b/crates/core/src/operations/cast.rs @@ -18,15 +18,15 @@ fn cast_struct( .iter() .map(|field| { let col_or_not = struct_array.column_by_name(field.name()); - if col_or_not.is_none() { - if !add_missing { - return Err(arrow_schema::ArrowError::SchemaError(format!( - "Could not find column {0}", - field.name() - ))); - } - } - match col_or_not { + match col_or_not { + None => + match add_missing { + true => Ok(new_null_array(field.data_type(), struct_array.len())), + false => Err(arrow_schema::ArrowError::SchemaError(format!( + "Could not find column {0}", + field.name() + ))), + } Some(col) => { if let (DataType::Struct(_), DataType::Struct(child_fields)) = (col.data_type(), field.data_type()) @@ -45,7 +45,6 @@ fn cast_struct( Ok(col.clone()) } } - None => Ok(new_null_array(field.data_type(), struct_array.len())), } }) .collect::, _>>() From a8a711c451a47d846b4a6169898e1eee7414dfbd Mon Sep 17 00:00:00 2001 From: Adrian Ehrsam Date: Fri, 1 Mar 2024 16:12:29 +0100 Subject: [PATCH 09/40] compiles again after refactoring --- crates/core/src/operations/write.rs | 119 +++++++++++++++++++--------- 1 file changed, 80 insertions(+), 39 deletions(-) diff --git a/crates/core/src/operations/write.rs b/crates/core/src/operations/write.rs index a7b0140b84..b1ee521346 100644 --- a/crates/core/src/operations/write.rs +++ b/crates/core/src/operations/write.rs @@ -53,6 +53,7 @@ use crate::delta_datafusion::{find_files, register_store, DeltaScanBuilder}; use crate::errors::{DeltaResult, DeltaTableError}; use crate::kernel::{Action, Add, Metadata, PartitionsExt, Remove, StructType}; use crate::logstore::LogStoreRef; +use crate::operations::cast::cast_record_batch; use crate::protocol::{DeltaOperation, SaveMode}; use crate::storage::ObjectStoreRef; use crate::table::state::DeltaTableState; @@ -336,37 +337,10 @@ async fn write_execution_plan_with_predicate( write_batch_size: Option, writer_properties: Option, safe_cast: bool, - schema_mode: Option, + schema_mode: Option ) -> DeltaResult> { - let mut schema_action: Option = None; - // Use input schema to prevent wrapping partitions columns into a dictionary. 
let schema: ArrowSchemaRef = if schema_mode == Some(SchemaMode::Overwrite) { plan.schema() - } else if schema_mode == Some(SchemaMode::Merge) { - let original_schema = snapshot - .and_then(|s| s.input_schema().ok()) - .unwrap_or(plan.schema()); - if original_schema == plan.schema() { - original_schema - } else { - let new_schema = Arc::new(arrow_schema::Schema::try_merge(vec![ - original_schema.as_ref().clone(), - plan.schema().as_ref().clone(), - ])?); - let schema_struct: StructType = new_schema.clone().try_into()?; - schema_action = Some(Action::Metadata(Metadata::try_new( - schema_struct, - match snapshot { - Some(sn) => sn.metadata().partition_columns.clone(), - None => vec![], - }, - match snapshot { - Some(sn) => sn.metadata().configuration.clone(), - None => HashMap::new(), - }, - )?)); - new_schema - } } else { snapshot .and_then(|s| s.input_schema().ok()) @@ -435,9 +409,6 @@ async fn write_execution_plan_with_predicate( .concat() .into_iter() .collect::>(); - if let Some(schema_action) = schema_action { - actions.push(schema_action); - } // Collect add actions to add to commit Ok(actions) } @@ -453,7 +424,7 @@ pub(crate) async fn write_execution_plan( write_batch_size: Option, writer_properties: Option, safe_cast: bool, - schema_mode: Option, + schema_mode: Option ) -> DeltaResult> { write_execution_plan_with_predicate( None, @@ -466,7 +437,7 @@ pub(crate) async fn write_execution_plan( write_batch_size, writer_properties, safe_cast, - schema_mode, + schema_mode ) .await } @@ -608,22 +579,34 @@ impl std::future::IntoFuture for WriteBuilder { } else { Ok(this.partition_columns.unwrap_or_default()) }?; - + + let plan = if let Some(plan) = this.input { + if this.schema_mode == Some(SchemaMode::Merge) { + return Err(DeltaTableError::Generic( + "Schema merge not supported yet for Datafusion".to_string(), + )); + } Ok(plan) } else if let Some(batches) = this.batches { if batches.is_empty() { Err(WriteError::MissingData) } else { let schema = batches[0].schema(); - + + let mut new_schema = None; if let Some(snapshot) = &this.snapshot { let table_schema = snapshot .physical_arrow_schema(this.log_store.object_store().clone()) .await .or_else(|_| snapshot.arrow_schema()) .unwrap_or(schema.clone()); - + if this.schema_mode == Some(SchemaMode::Merge) { + new_schema = Some(Arc::new(arrow_schema::Schema::try_merge(vec![ + table_schema.as_ref().clone(), + schema.as_ref().clone(), + ])?)); + } if !can_cast_batch(schema.fields(), table_schema.fields()) && (this.schema_mode.is_none() && !matches!(this.mode, SaveMode::Overwrite)) @@ -638,10 +621,18 @@ impl std::future::IntoFuture for WriteBuilder { // TODO partitioning should probably happen in its own plan ... let mut partitions: HashMap> = HashMap::new(); for batch in batches { + let real_batch = match new_schema { + Some(ref new_schema) => { + cast_record_batch(&batch, new_schema.clone(), false, true)? 
+ } + None => batch, + }; + + println!("before divide_by_partition_values {:?}", schema); let divided = divide_by_partition_values( schema.clone(), partition_columns.clone(), - &batch, + &real_batch, )?; for part in divided { let key = part.partition_values.hive_partition_path(); @@ -667,7 +658,26 @@ impl std::future::IntoFuture for WriteBuilder { Err(WriteError::MissingData) }?; let schema = plan.schema(); - + if this.schema_mode == Some(SchemaMode::Merge) || (this.schema_mode == Some(SchemaMode::Overwrite) && this.mode != SaveMode::Overwrite) + { + if let Some(snapshot) = &this.snapshot { + let table_schema = snapshot + .physical_arrow_schema(this.log_store.object_store().clone()) + .await + .or_else(|_| snapshot.arrow_schema()) + .unwrap_or(schema.clone()); + if !can_cast_batch(schema.fields(), table_schema.fields()) { + let schema_struct: StructType = schema.clone().try_into()?; + let schema_action = Action::Metadata(Metadata::try_new( + schema_struct, + partition_columns.clone(), + snapshot.metadata().configuration.clone(), + )?); + actions.push(schema_action); + } + } + } + println!("{:?}", schema); let state = match this.state { Some(state) => state, None => { @@ -704,7 +714,7 @@ impl std::future::IntoFuture for WriteBuilder { this.write_batch_size, this.writer_properties.clone(), this.safe_cast, - this.schema_mode, + this.schema_mode ) .await?; actions.extend(add_actions); @@ -1118,6 +1128,37 @@ mod tests { assert_eq!(names, vec!["id", "value", "modified", "inserted_by"]); } + #[tokio::test] + async fn test_merge_schema_with_missing_partitions() { + let batch = get_record_batch(None, false); + let table = DeltaOps::new_in_memory() + .write(vec![batch.clone()]) + .with_partition_columns(vec!["id", "value"]) + .with_save_mode(SaveMode::ErrorIfExists) + .await + .unwrap(); + assert_eq!(table.version(), 0); + let mut new_batch = batch.clone(); + new_batch.remove_column(0); + let new_schema = new_batch.schema().clone(); + let new_fields = new_schema.fields(); + let new_names = new_fields + .iter() + .map(|f| f.name().to_owned()) + .collect::>(); + assert_eq!(new_names, vec!["value", "modified"]); + println!("Merge now with {:?}", new_names); + let mut table = DeltaOps(table) + .write(vec![new_batch]) + .with_save_mode(SaveMode::Append) + .with_schema_mode(SchemaMode::Merge) + .await + .unwrap(); + table.load().await.unwrap(); + let part_cols = table.metadata().unwrap().partition_columns.clone(); + assert_eq!(part_cols, vec!["id", "value"]); // we want to preserve partitions even if null + } + #[tokio::test] async fn test_check_invariants() { let batch = get_record_batch(None, false); From f515f31d1111eb4ec90e1be2eb31a4f22d538c52 Mon Sep 17 00:00:00 2001 From: Adrian Ehrsam Date: Fri, 1 Mar 2024 16:12:38 +0100 Subject: [PATCH 10/40] fmt --- crates/core/src/operations/cast.rs | 17 ++++++++--------- crates/core/src/operations/write.rs | 23 ++++++++++++----------- 2 files changed, 20 insertions(+), 20 deletions(-) diff --git a/crates/core/src/operations/cast.rs b/crates/core/src/operations/cast.rs index e4c6679718..87a7104be7 100644 --- a/crates/core/src/operations/cast.rs +++ b/crates/core/src/operations/cast.rs @@ -18,15 +18,14 @@ fn cast_struct( .iter() .map(|field| { let col_or_not = struct_array.column_by_name(field.name()); - match col_or_not { - None => - match add_missing { - true => Ok(new_null_array(field.data_type(), struct_array.len())), - false => Err(arrow_schema::ArrowError::SchemaError(format!( - "Could not find column {0}", - field.name() - ))), - } + match col_or_not { + 
None => match add_missing { + true => Ok(new_null_array(field.data_type(), struct_array.len())), + false => Err(arrow_schema::ArrowError::SchemaError(format!( + "Could not find column {0}", + field.name() + ))), + }, Some(col) => { if let (DataType::Struct(_), DataType::Struct(child_fields)) = (col.data_type(), field.data_type()) diff --git a/crates/core/src/operations/write.rs b/crates/core/src/operations/write.rs index b1ee521346..1576b17c33 100644 --- a/crates/core/src/operations/write.rs +++ b/crates/core/src/operations/write.rs @@ -337,7 +337,7 @@ async fn write_execution_plan_with_predicate( write_batch_size: Option, writer_properties: Option, safe_cast: bool, - schema_mode: Option + schema_mode: Option, ) -> DeltaResult> { let schema: ArrowSchemaRef = if schema_mode == Some(SchemaMode::Overwrite) { plan.schema() @@ -424,7 +424,7 @@ pub(crate) async fn write_execution_plan( write_batch_size: Option, writer_properties: Option, safe_cast: bool, - schema_mode: Option + schema_mode: Option, ) -> DeltaResult> { write_execution_plan_with_predicate( None, @@ -437,7 +437,7 @@ pub(crate) async fn write_execution_plan( write_batch_size, writer_properties, safe_cast, - schema_mode + schema_mode, ) .await } @@ -579,10 +579,9 @@ impl std::future::IntoFuture for WriteBuilder { } else { Ok(this.partition_columns.unwrap_or_default()) }?; - - + let plan = if let Some(plan) = this.input { - if this.schema_mode == Some(SchemaMode::Merge) { + if this.schema_mode == Some(SchemaMode::Merge) { return Err(DeltaTableError::Generic( "Schema merge not supported yet for Datafusion".to_string(), )); @@ -593,7 +592,7 @@ impl std::future::IntoFuture for WriteBuilder { Err(WriteError::MissingData) } else { let schema = batches[0].schema(); - + let mut new_schema = None; if let Some(snapshot) = &this.snapshot { let table_schema = snapshot @@ -601,7 +600,7 @@ impl std::future::IntoFuture for WriteBuilder { .await .or_else(|_| snapshot.arrow_schema()) .unwrap_or(schema.clone()); - if this.schema_mode == Some(SchemaMode::Merge) { + if this.schema_mode == Some(SchemaMode::Merge) { new_schema = Some(Arc::new(arrow_schema::Schema::try_merge(vec![ table_schema.as_ref().clone(), schema.as_ref().clone(), @@ -627,7 +626,7 @@ impl std::future::IntoFuture for WriteBuilder { } None => batch, }; - + println!("before divide_by_partition_values {:?}", schema); let divided = divide_by_partition_values( schema.clone(), @@ -658,7 +657,9 @@ impl std::future::IntoFuture for WriteBuilder { Err(WriteError::MissingData) }?; let schema = plan.schema(); - if this.schema_mode == Some(SchemaMode::Merge) || (this.schema_mode == Some(SchemaMode::Overwrite) && this.mode != SaveMode::Overwrite) + if this.schema_mode == Some(SchemaMode::Merge) + || (this.schema_mode == Some(SchemaMode::Overwrite) + && this.mode != SaveMode::Overwrite) { if let Some(snapshot) = &this.snapshot { let table_schema = snapshot @@ -714,7 +715,7 @@ impl std::future::IntoFuture for WriteBuilder { this.write_batch_size, this.writer_properties.clone(), this.safe_cast, - this.schema_mode + this.schema_mode, ) .await?; actions.extend(add_actions); From 6182cff46cf06b77f559fc28fe7d6fbbfad38dee Mon Sep 17 00:00:00 2001 From: Adrian Ehrsam Date: Fri, 1 Mar 2024 16:13:17 +0100 Subject: [PATCH 11/40] clippy --- crates/core/src/operations/write.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/core/src/operations/write.rs b/crates/core/src/operations/write.rs index 1576b17c33..2895d97421 100644 --- a/crates/core/src/operations/write.rs +++ 
b/crates/core/src/operations/write.rs @@ -399,7 +399,7 @@ async fn write_execution_plan_with_predicate( tasks.push(handle); } - let mut actions = futures::future::join_all(tasks) + let actions = futures::future::join_all(tasks) .await .into_iter() .collect::, _>>() From ca761a254d934d685a0976e00a59f9b2f74c058e Mon Sep 17 00:00:00 2001 From: Adrian Ehrsam Date: Fri, 1 Mar 2024 17:21:52 +0100 Subject: [PATCH 12/40] wip on new merge method --- crates/core/src/operations/cast.rs | 20 ++++- crates/core/src/operations/write.rs | 117 +++++++++++++++++-------- crates/core/src/writer/record_batch.rs | 4 +- 3 files changed, 103 insertions(+), 38 deletions(-) diff --git a/crates/core/src/operations/cast.rs b/crates/core/src/operations/cast.rs index 87a7104be7..7e6c7c3b14 100644 --- a/crates/core/src/operations/cast.rs +++ b/crates/core/src/operations/cast.rs @@ -2,12 +2,30 @@ //! use arrow_array::{new_null_array, Array, ArrayRef, RecordBatch, StructArray}; use arrow_cast::{cast_with_options, CastOptions}; -use arrow_schema::{DataType, Fields, SchemaRef as ArrowSchemaRef}; +use arrow_schema::{DataType, Fields, SchemaRef as ArrowSchemaRef, Schema as ArrowSchema}; use std::sync::Arc; use crate::DeltaResult; +fn merge_schema( + left: ArrowSchemaRef, + right: ArrowSchemaRef, +) -> DeltaResult { + let fields = left + .fields() + .iter() + .map(|field| { + let right_field = right.field_with_name(field.name()); + match right_field { + Some(right_field) => field.try_merge(right_field)?, + None => Ok(field.clone()) + } + }) + .collect(); + Ok(ArrowSchemaRef::new(ArrowSchema::new(fields))) +} + fn cast_struct( struct_array: &StructArray, fields: &Fields, diff --git a/crates/core/src/operations/write.rs b/crates/core/src/operations/write.rs index 2895d97421..dd8ae3e1aa 100644 --- a/crates/core/src/operations/write.rs +++ b/crates/core/src/operations/write.rs @@ -53,7 +53,7 @@ use crate::delta_datafusion::{find_files, register_store, DeltaScanBuilder}; use crate::errors::{DeltaResult, DeltaTableError}; use crate::kernel::{Action, Add, Metadata, PartitionsExt, Remove, StructType}; use crate::logstore::LogStoreRef; -use crate::operations::cast::cast_record_batch; +use crate::operations::cast::{cast_record}_batch, merge_schema}; use crate::protocol::{DeltaOperation, SaveMode}; use crate::storage::ObjectStoreRef; use crate::table::state::DeltaTableState; @@ -600,36 +600,40 @@ impl std::future::IntoFuture for WriteBuilder { .await .or_else(|_| snapshot.arrow_schema()) .unwrap_or(schema.clone()); - if this.schema_mode == Some(SchemaMode::Merge) { - new_schema = Some(Arc::new(arrow_schema::Schema::try_merge(vec![ - table_schema.as_ref().clone(), - schema.as_ref().clone(), - ])?)); + + if !can_cast_batch(schema.fields(), table_schema.fields()) { + if this.mode == SaveMode::Overwrite { + new_schema = None // we overwrite anyway, so no need to cast + } else if this.schema_mode == Some(SchemaMode::Overwrite) { + new_schema = None // we overwrite anyway, so no need to cast + } else if this.schema_mode == Some(SchemaMode::Merge) { + println!("table. 
{:?} \r\n batch: {:?}", table_schema, schema); + new_schema = + Some(merge_schema( + table_schema.as_ref().clone(), + schema.as_ref().clone(), + )?)); + } else { + return Err(DeltaTableError::Generic( + "Schema of data does not match table schema".to_string(), + )); + } } - if !can_cast_batch(schema.fields(), table_schema.fields()) - && (this.schema_mode.is_none() - && !matches!(this.mode, SaveMode::Overwrite)) - { - return Err(DeltaTableError::Generic( - "Schema of data does not match table schema".to_string(), - )); - }; } let data = if !partition_columns.is_empty() { // TODO partitioning should probably happen in its own plan ... let mut partitions: HashMap> = HashMap::new(); for batch in batches { - let real_batch = match new_schema { - Some(ref new_schema) => { - cast_record_batch(&batch, new_schema.clone(), false, true)? + let real_batch = match new_schema.clone() { + Some(new_schema) => { + cast_record_batch(&batch, new_schema, false, true)? } None => batch, }; - println!("before divide_by_partition_values {:?}", schema); let divided = divide_by_partition_values( - schema.clone(), + new_schema.clone().unwrap_or(schema.clone()), partition_columns.clone(), &real_batch, )?; @@ -647,10 +651,24 @@ impl std::future::IntoFuture for WriteBuilder { } partitions.into_values().collect::>() } else { - vec![batches] + match new_schema { + Some(ref new_schema) => { + let mut new_batches = vec![]; + for batch in batches { + new_batches.push(cast_record_batch( + &batch, + new_schema.clone(), + false, + true, + )?); + } + vec![new_batches] + } + None => vec![batches], + } }; - Ok(Arc::new(MemoryExec::try_new(&data, schema.clone(), None)?) + Ok(Arc::new(MemoryExec::try_new(&data, new_schema.unwrap_or(schema).clone(), None)?) as Arc) } } else { @@ -678,7 +696,6 @@ impl std::future::IntoFuture for WriteBuilder { } } } - println!("{:?}", schema); let state = match this.state { Some(state) => state, None => { @@ -1130,7 +1147,7 @@ mod tests { } #[tokio::test] - async fn test_merge_schema_with_missing_partitions() { + async fn test_merge_schema_with_partitions() { let batch = get_record_batch(None, false); let table = DeltaOps::new_in_memory() .write(vec![batch.clone()]) @@ -1139,25 +1156,55 @@ mod tests { .await .unwrap(); assert_eq!(table.version(), 0); - let mut new_batch = batch.clone(); - new_batch.remove_column(0); - let new_schema = new_batch.schema().clone(); + + let mut new_schema_builder = arrow_schema::SchemaBuilder::new(); + for field in batch.schema().fields() { + if field.name() != "modified" { + new_schema_builder.push(field.clone()); + } + } + new_schema_builder.push(Field::new("inserted_by", DataType::Utf8, true)); + let new_schema = new_schema_builder.finish(); let new_fields = new_schema.fields(); - let new_names = new_fields - .iter() - .map(|f| f.name().to_owned()) - .collect::>(); - assert_eq!(new_names, vec!["value", "modified"]); - println!("Merge now with {:?}", new_names); - let mut table = DeltaOps(table) + let new_names = new_fields.iter().map(|f| f.name()).collect::>(); + assert_eq!(new_names, vec!["id", "value", "inserted_by"]); + let inserted_by = StringArray::from(vec![ + Some("A1"), + Some("B1"), + None, + Some("B2"), + Some("A3"), + Some("A4"), + None, + None, + Some("B4"), + Some("A5"), + Some("A7"), + ]); + let new_batch = RecordBatch::try_new( + Arc::new(new_schema), + vec![ + Arc::new(batch.column_by_name("id").unwrap().clone()), + Arc::new(batch.column_by_name("value").unwrap().clone()), + Arc::new(inserted_by), + ], + ) + .unwrap(); + println!("new_batch: {:?}", 
new_batch.schema()); + let table = DeltaOps(table) .write(vec![new_batch]) .with_save_mode(SaveMode::Append) .with_schema_mode(SchemaMode::Merge) .await .unwrap(); - table.load().await.unwrap(); + + assert_eq!(table.version(), 1); + let new_schema = table.metadata().unwrap().schema().unwrap(); + let fields = new_schema.fields(); + let names = fields.iter().map(|f| f.name()).collect::>(); + assert_eq!(names, vec!["id", "value", "modified", "inserted_by"]); let part_cols = table.metadata().unwrap().partition_columns.clone(); - assert_eq!(part_cols, vec!["id", "value"]); // we want to preserve partitions even if null + assert_eq!(part_cols, vec!["id", "value"]); // we want to preserve partitions } #[tokio::test] diff --git a/crates/core/src/writer/record_batch.rs b/crates/core/src/writer/record_batch.rs index c62fc9b560..73a7472ef8 100644 --- a/crates/core/src/writer/record_batch.rs +++ b/crates/core/src/writer/record_batch.rs @@ -304,10 +304,10 @@ impl PartitionWriter { WriteMode::MergeSchema => { debug!("The writer and record batch schemas do not match, merging"); - let merged = ArrowSchema::try_merge(vec![ + let merged = merge_schema( self.arrow_schema.as_ref().clone(), record_batch.schema().as_ref().clone(), - ])?; + )?; self.arrow_schema = Arc::new(merged); let mut cols = vec![]; From 35027edddafe15c68af62113f235c20bebedb68e Mon Sep 17 00:00:00 2001 From: Adrian Ehrsam Date: Fri, 1 Mar 2024 17:23:09 +0100 Subject: [PATCH 13/40] fmt --- crates/core/src/operations/cast.rs | 9 +++------ crates/core/src/operations/write.rs | 22 ++++++++++++---------- 2 files changed, 15 insertions(+), 16 deletions(-) diff --git a/crates/core/src/operations/cast.rs b/crates/core/src/operations/cast.rs index 7e6c7c3b14..ada99b3b5c 100644 --- a/crates/core/src/operations/cast.rs +++ b/crates/core/src/operations/cast.rs @@ -2,16 +2,13 @@ //! 
use arrow_array::{new_null_array, Array, ArrayRef, RecordBatch, StructArray}; use arrow_cast::{cast_with_options, CastOptions}; -use arrow_schema::{DataType, Fields, SchemaRef as ArrowSchemaRef, Schema as ArrowSchema}; +use arrow_schema::{DataType, Fields, Schema as ArrowSchema, SchemaRef as ArrowSchemaRef}; use std::sync::Arc; use crate::DeltaResult; -fn merge_schema( - left: ArrowSchemaRef, - right: ArrowSchemaRef, -) -> DeltaResult { +fn merge_schema(left: ArrowSchemaRef, right: ArrowSchemaRef) -> DeltaResult { let fields = left .fields() .iter() @@ -19,7 +16,7 @@ fn merge_schema( let right_field = right.field_with_name(field.name()); match right_field { Some(right_field) => field.try_merge(right_field)?, - None => Ok(field.clone()) + None => Ok(field.clone()), } }) .collect(); diff --git a/crates/core/src/operations/write.rs b/crates/core/src/operations/write.rs index dd8ae3e1aa..06e1fecace 100644 --- a/crates/core/src/operations/write.rs +++ b/crates/core/src/operations/write.rs @@ -53,7 +53,7 @@ use crate::delta_datafusion::{find_files, register_store, DeltaScanBuilder}; use crate::errors::{DeltaResult, DeltaTableError}; use crate::kernel::{Action, Add, Metadata, PartitionsExt, Remove, StructType}; use crate::logstore::LogStoreRef; -use crate::operations::cast::{cast_record}_batch, merge_schema}; +use crate::operations::cast::{cast_record_batch, merge_schema}; use crate::protocol::{DeltaOperation, SaveMode}; use crate::storage::ObjectStoreRef; use crate::table::state::DeltaTableState; @@ -608,11 +608,10 @@ impl std::future::IntoFuture for WriteBuilder { new_schema = None // we overwrite anyway, so no need to cast } else if this.schema_mode == Some(SchemaMode::Merge) { println!("table. {:?} \r\n batch: {:?}", table_schema, schema); - new_schema = - Some(merge_schema( - table_schema.as_ref().clone(), - schema.as_ref().clone(), - )?)); + new_schema = Some(merge_schema( + table_schema.as_ref().clone(), + schema.as_ref().clone(), + )?); } else { return Err(DeltaTableError::Generic( "Schema of data does not match table schema".to_string(), @@ -668,8 +667,11 @@ impl std::future::IntoFuture for WriteBuilder { } }; - Ok(Arc::new(MemoryExec::try_new(&data, new_schema.unwrap_or(schema).clone(), None)?) - as Arc) + Ok(Arc::new(MemoryExec::try_new( + &data, + new_schema.unwrap_or(schema).clone(), + None, + )?) 
as Arc) } } else { Err(WriteError::MissingData) @@ -1156,7 +1158,7 @@ mod tests { .await .unwrap(); assert_eq!(table.version(), 0); - + let mut new_schema_builder = arrow_schema::SchemaBuilder::new(); for field in batch.schema().fields() { if field.name() != "modified" { @@ -1204,7 +1206,7 @@ mod tests { let names = fields.iter().map(|f| f.name()).collect::>(); assert_eq!(names, vec!["id", "value", "modified", "inserted_by"]); let part_cols = table.metadata().unwrap().partition_columns.clone(); - assert_eq!(part_cols, vec!["id", "value"]); // we want to preserve partitions + assert_eq!(part_cols, vec!["id", "value"]); // we want to preserve partitions } #[tokio::test] From d95889a458d51dccd8cee3d2c24bae4b2d2c4857 Mon Sep 17 00:00:00 2001 From: Adrian Ehrsam Date: Fri, 1 Mar 2024 17:23:53 +0100 Subject: [PATCH 14/40] next fix --- crates/core/src/operations/cast.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/core/src/operations/cast.rs b/crates/core/src/operations/cast.rs index ada99b3b5c..4f4a9cc941 100644 --- a/crates/core/src/operations/cast.rs +++ b/crates/core/src/operations/cast.rs @@ -8,7 +8,7 @@ use std::sync::Arc; use crate::DeltaResult; -fn merge_schema(left: ArrowSchemaRef, right: ArrowSchemaRef) -> DeltaResult { +pub (crate) fn merge_schema(left: ArrowSchemaRef, right: ArrowSchemaRef) -> DeltaResult { let fields = left .fields() .iter() From 36fa567c1bb3f12f9118c2237f71e36f5ec6a0b5 Mon Sep 17 00:00:00 2001 From: Adrian Ehrsam Date: Fri, 1 Mar 2024 17:28:31 +0100 Subject: [PATCH 15/40] WIP --- crates/core/src/operations/cast.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/crates/core/src/operations/cast.rs b/crates/core/src/operations/cast.rs index 4f4a9cc941..c86dcf89d7 100644 --- a/crates/core/src/operations/cast.rs +++ b/crates/core/src/operations/cast.rs @@ -2,21 +2,21 @@ //! use arrow_array::{new_null_array, Array, ArrayRef, RecordBatch, StructArray}; use arrow_cast::{cast_with_options, CastOptions}; -use arrow_schema::{DataType, Fields, Schema as ArrowSchema, SchemaRef as ArrowSchemaRef}; +use arrow_schema::{ArrowError, DataType, Fields, Schema as ArrowSchema, SchemaRef as ArrowSchemaRef}; use std::sync::Arc; use crate::DeltaResult; -pub (crate) fn merge_schema(left: ArrowSchemaRef, right: ArrowSchemaRef) -> DeltaResult { +pub (crate) fn merge_schema(left: ArrowSchemaRef, right: ArrowSchemaRef) -> Result { let fields = left .fields() .iter() .map(|field| { let right_field = right.field_with_name(field.name()); match right_field { - Some(right_field) => field.try_merge(right_field)?, - None => Ok(field.clone()), + Ok(right_field) => field.try_merge(right_field)?, + _ => Ok(field.clone()), } }) .collect(); From 16023330e8637bb7e0599678c8d0ad3c638a5b19 Mon Sep 17 00:00:00 2001 From: Adrian Ehrsam Date: Fri, 1 Mar 2024 20:26:50 +0100 Subject: [PATCH 16/40] compiles again --- crates/core/src/operations/cast.rs | 26 +++++++++++++++++++------- crates/core/src/operations/write.rs | 4 ++-- crates/core/src/writer/record_batch.rs | 1 + 3 files changed, 22 insertions(+), 9 deletions(-) diff --git a/crates/core/src/operations/cast.rs b/crates/core/src/operations/cast.rs index c86dcf89d7..e8f55f0fc2 100644 --- a/crates/core/src/operations/cast.rs +++ b/crates/core/src/operations/cast.rs @@ -2,25 +2,37 @@ //! 
use arrow_array::{new_null_array, Array, ArrayRef, RecordBatch, StructArray}; use arrow_cast::{cast_with_options, CastOptions}; -use arrow_schema::{ArrowError, DataType, Fields, Schema as ArrowSchema, SchemaRef as ArrowSchemaRef}; +use arrow_schema::{ArrowError, Field as ArrowField, DataType, Fields, Schema as ArrowSchema, SchemaRef as ArrowSchemaRef}; use std::sync::Arc; use crate::DeltaResult; -pub (crate) fn merge_schema(left: ArrowSchemaRef, right: ArrowSchemaRef) -> Result { - let fields = left +pub (crate) fn merge_schema(left: ArrowSchema, right: ArrowSchema) -> Result { + let left_fields: Result, ArrowError> = left .fields() .iter() .map(|field| { let right_field = right.field_with_name(field.name()); - match right_field { - Ok(right_field) => field.try_merge(right_field)?, - _ => Ok(field.clone()), + if let Ok(right_field) = right_field { + let mut new_field = field.as_ref().clone(); + new_field.try_merge(right_field)?; + Ok(new_field) + } + else { + Ok(field.as_ref().clone()) } + }) .collect(); - Ok(ArrowSchemaRef::new(ArrowSchema::new(fields))) + let mut fields = left_fields?; + for field in right.fields() { + if !left.field_with_name(field.name()).is_ok() { + fields.push(field.as_ref().clone()); + } + } + + Ok(ArrowSchema::new(fields)) } fn cast_struct( diff --git a/crates/core/src/operations/write.rs b/crates/core/src/operations/write.rs index 06e1fecace..022dda6176 100644 --- a/crates/core/src/operations/write.rs +++ b/crates/core/src/operations/write.rs @@ -608,10 +608,10 @@ impl std::future::IntoFuture for WriteBuilder { new_schema = None // we overwrite anyway, so no need to cast } else if this.schema_mode == Some(SchemaMode::Merge) { println!("table. {:?} \r\n batch: {:?}", table_schema, schema); - new_schema = Some(merge_schema( + new_schema = Some(Arc::new(merge_schema( table_schema.as_ref().clone(), schema.as_ref().clone(), - )?); + )?)); } else { return Err(DeltaTableError::Generic( "Schema of data does not match table schema".to_string(), diff --git a/crates/core/src/writer/record_batch.rs b/crates/core/src/writer/record_batch.rs index 73a7472ef8..5f7515d9a8 100644 --- a/crates/core/src/writer/record_batch.rs +++ b/crates/core/src/writer/record_batch.rs @@ -29,6 +29,7 @@ use super::utils::{ use super::{DeltaWriter, DeltaWriterError, WriteMode}; use crate::errors::DeltaTableError; use crate::kernel::{Action, Add, PartitionsExt, Scalar, StructType}; +use crate::operations::cast::merge_schema; use crate::table::builder::DeltaTableBuilder; use crate::DeltaTable; From 563bf30dee0a1695b9fc607b179d576b0fd2b757 Mon Sep 17 00:00:00 2001 From: Adrian Ehrsam Date: Fri, 1 Mar 2024 20:27:55 +0100 Subject: [PATCH 17/40] fmt --- crates/core/src/operations/cast.rs | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/crates/core/src/operations/cast.rs b/crates/core/src/operations/cast.rs index e8f55f0fc2..06c5065dd3 100644 --- a/crates/core/src/operations/cast.rs +++ b/crates/core/src/operations/cast.rs @@ -2,13 +2,19 @@ //! 
use arrow_array::{new_null_array, Array, ArrayRef, RecordBatch, StructArray}; use arrow_cast::{cast_with_options, CastOptions}; -use arrow_schema::{ArrowError, Field as ArrowField, DataType, Fields, Schema as ArrowSchema, SchemaRef as ArrowSchemaRef}; +use arrow_schema::{ + ArrowError, DataType, Field as ArrowField, Fields, Schema as ArrowSchema, + SchemaRef as ArrowSchemaRef, +}; use std::sync::Arc; use crate::DeltaResult; -pub (crate) fn merge_schema(left: ArrowSchema, right: ArrowSchema) -> Result { +pub(crate) fn merge_schema( + left: ArrowSchema, + right: ArrowSchema, +) -> Result { let left_fields: Result, ArrowError> = left .fields() .iter() @@ -16,13 +22,11 @@ pub (crate) fn merge_schema(left: ArrowSchema, right: ArrowSchema) -> Result Date: Fri, 1 Mar 2024 20:43:27 +0100 Subject: [PATCH 18/40] might fixes test --- crates/core/src/operations/cast.rs | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/crates/core/src/operations/cast.rs b/crates/core/src/operations/cast.rs index 06c5065dd3..911c47196b 100644 --- a/crates/core/src/operations/cast.rs +++ b/crates/core/src/operations/cast.rs @@ -4,9 +4,9 @@ use arrow_array::{new_null_array, Array, ArrayRef, RecordBatch, StructArray}; use arrow_cast::{cast_with_options, CastOptions}; use arrow_schema::{ ArrowError, DataType, Field as ArrowField, Fields, Schema as ArrowSchema, - SchemaRef as ArrowSchemaRef, + SchemaRef as ArrowSchemaRef }; - +use arrow::datatypes::DataType::Dictionary; use std::sync::Arc; use crate::DeltaResult; @@ -21,8 +21,16 @@ pub(crate) fn merge_schema( .map(|field| { let right_field = right.field_with_name(field.name()); if let Ok(right_field) = right_field { + if let Dictionary(_, value_type) = right_field.data_type() { + if value_type.equals_datatype(field.data_type()) { + return Ok(field.as_ref().clone()); + } + } let mut new_field = field.as_ref().clone(); - new_field.try_merge(right_field)?; + let merge_res = new_field.try_merge(right_field); + if let Err(e) = merge_res { + return Err(e); + } Ok(new_field) } else { Ok(field.as_ref().clone()) From 0f7fba5254ce839314c04c31bed4e782015a8b5f Mon Sep 17 00:00:00 2001 From: Adrian Ehrsam Date: Fri, 1 Mar 2024 20:47:04 +0100 Subject: [PATCH 19/40] better cast --- crates/core/src/operations/cast.rs | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/crates/core/src/operations/cast.rs b/crates/core/src/operations/cast.rs index 911c47196b..6e53c55e27 100644 --- a/crates/core/src/operations/cast.rs +++ b/crates/core/src/operations/cast.rs @@ -21,11 +21,17 @@ pub(crate) fn merge_schema( .map(|field| { let right_field = right.field_with_name(field.name()); if let Ok(right_field) = right_field { + // Allow Dictionary to be merged with non-Dictionary if let Dictionary(_, value_type) = right_field.data_type() { if value_type.equals_datatype(field.data_type()) { return Ok(field.as_ref().clone()); } } + if let Dictionary(_, value_type) = field.data_type() { + if value_type.equals_datatype(right_field.data_type()) { + return Ok(right_field.as_ref().clone()); + } + } let mut new_field = field.as_ref().clone(); let merge_res = new_field.try_merge(right_field); if let Err(e) = merge_res { From 950cd23b606718dbf257c4a5dd90784676af0fc5 Mon Sep 17 00:00:00 2001 From: Adrian Ehrsam Date: Fri, 1 Mar 2024 21:09:15 +0100 Subject: [PATCH 20/40] test passes! 
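Since it is easy to lose in the diff noise, the dictionary handling these commits converge on can be summarized with a small, self-contained sketch (simplified for illustration; the actual logic is the merge code in crates/core/src/operations/cast.rs shown above):

```rust
// Sketch only: a dictionary-encoded field is treated as compatible with a plain
// field of the same value type; everything else falls back to Arrow's Field::try_merge.
use arrow_schema::{ArrowError, DataType, Field};

fn merge_field(left: &Field, right: &Field) -> Result<Field, ArrowError> {
    if let DataType::Dictionary(_, value_type) = right.data_type() {
        if value_type.equals_datatype(left.data_type()) {
            return Ok(left.clone());
        }
    }
    if let DataType::Dictionary(_, value_type) = left.data_type() {
        if value_type.equals_datatype(right.data_type()) {
            return Ok(right.clone());
        }
    }
    // Otherwise use Arrow's own rules: identical types merge, nullability widens.
    let mut merged = left.clone();
    merged.try_merge(right)?;
    Ok(merged)
}

fn main() -> Result<(), ArrowError> {
    let plain = Field::new("id", DataType::Utf8, false);
    let dict = Field::new(
        "id",
        DataType::Dictionary(Box::new(DataType::Int16), Box::new(DataType::Utf8)),
        true,
    );
    // The dictionary side is accepted because its value type matches Utf8.
    assert_eq!(merge_field(&plain, &dict)?.data_type(), &DataType::Utf8);
    Ok(())
}
```

Arrow's plain `Field::try_merge` does not accept a dictionary field merged with its value type, which is why the explicit special case was added here.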
--- crates/core/src/operations/cast.rs | 2 +- crates/core/src/operations/write.rs | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/crates/core/src/operations/cast.rs b/crates/core/src/operations/cast.rs index 6e53c55e27..79fcb861fc 100644 --- a/crates/core/src/operations/cast.rs +++ b/crates/core/src/operations/cast.rs @@ -29,7 +29,7 @@ pub(crate) fn merge_schema( } if let Dictionary(_, value_type) = field.data_type() { if value_type.equals_datatype(right_field.data_type()) { - return Ok(right_field.as_ref().clone()); + return Ok(right_field.clone()); } } let mut new_field = field.as_ref().clone(); diff --git a/crates/core/src/operations/write.rs b/crates/core/src/operations/write.rs index 022dda6176..d838cc322f 100644 --- a/crates/core/src/operations/write.rs +++ b/crates/core/src/operations/write.rs @@ -1203,8 +1203,9 @@ mod tests { assert_eq!(table.version(), 1); let new_schema = table.metadata().unwrap().schema().unwrap(); let fields = new_schema.fields(); - let names = fields.iter().map(|f| f.name()).collect::>(); - assert_eq!(names, vec!["id", "value", "modified", "inserted_by"]); + let mut names = fields.iter().map(|f| f.name()).collect::>(); + names.sort(); + assert_eq!(names, vec!["id", "inserted_by", "modified", "value"]); let part_cols = table.metadata().unwrap().partition_columns.clone(); assert_eq!(part_cols, vec!["id", "value"]); // we want to preserve partitions } From 3292de08e67716a110197e0ce1fefde697fa6000 Mon Sep 17 00:00:00 2001 From: Adrian Ehrsam Date: Mon, 4 Mar 2024 10:22:28 +0100 Subject: [PATCH 21/40] tests passing in both rust and python --- crates/core/src/operations/cast.rs | 49 ++++++++++++------- crates/core/src/operations/write.rs | 76 +++++++++++++++++++++++++++-- python/deltalake/writer.py | 2 + python/tests/test_writer.py | 58 ++++++++++++++++++++++ 4 files changed, 163 insertions(+), 22 deletions(-) diff --git a/crates/core/src/operations/cast.rs b/crates/core/src/operations/cast.rs index 79fcb861fc..57ca01ba39 100644 --- a/crates/core/src/operations/cast.rs +++ b/crates/core/src/operations/cast.rs @@ -11,6 +11,36 @@ use std::sync::Arc; use crate::DeltaResult; +pub (crate) fn merge_field(left: &ArrowField, right: &ArrowField) -> Result { + if let Dictionary(_, value_type) = right.data_type() { + if value_type.equals_datatype(left.data_type()) { + return Ok(left.clone()); + } + } + if let Dictionary(_, value_type) = left.data_type() { + if value_type.equals_datatype(right.data_type()) { + return Ok(right.clone()); + } + } + let mut new_field = left.clone(); + let merge_res = new_field.try_merge(right); + if let Err(e) = merge_res { + return Err(e); + } + Ok(new_field) +} + +pub(crate) fn is_compatible_for_merge(schema: ArrowSchema, other: ArrowSchema) -> Result<(), ArrowError> { + for f in schema.fields() { + if let Ok(other_field) = other.field_with_name(f.name()) { + if let Err(e) = merge_field(f.as_ref(), other_field) { + return Err(e); + } + } + } + Ok(()) +} + pub(crate) fn merge_schema( left: ArrowSchema, right: ArrowSchema, @@ -21,23 +51,7 @@ pub(crate) fn merge_schema( .map(|field| { let right_field = right.field_with_name(field.name()); if let Ok(right_field) = right_field { - // Allow Dictionary to be merged with non-Dictionary - if let Dictionary(_, value_type) = right_field.data_type() { - if value_type.equals_datatype(field.data_type()) { - return Ok(field.as_ref().clone()); - } - } - if let Dictionary(_, value_type) = field.data_type() { - if value_type.equals_datatype(right_field.data_type()) { - return 
Ok(right_field.clone()); - } - } - let mut new_field = field.as_ref().clone(); - let merge_res = new_field.try_merge(right_field); - if let Err(e) = merge_res { - return Err(e); - } - Ok(new_field) + merge_field(field.as_ref(), right_field) } else { Ok(field.as_ref().clone()) } @@ -53,6 +67,7 @@ pub(crate) fn merge_schema( Ok(ArrowSchema::new(fields)) } + fn cast_struct( struct_array: &StructArray, fields: &Fields, diff --git a/crates/core/src/operations/write.rs b/crates/core/src/operations/write.rs index d838cc322f..438c00a341 100644 --- a/crates/core/src/operations/write.rs +++ b/crates/core/src/operations/write.rs @@ -28,6 +28,7 @@ use std::collections::HashMap; use std::str::FromStr; use std::sync::Arc; use std::time::{SystemTime, UNIX_EPOCH}; +use std::vec; use arrow_array::RecordBatch; use arrow_cast::can_cast_types; @@ -53,7 +54,7 @@ use crate::delta_datafusion::{find_files, register_store, DeltaScanBuilder}; use crate::errors::{DeltaResult, DeltaTableError}; use crate::kernel::{Action, Add, Metadata, PartitionsExt, Remove, StructType}; use crate::logstore::LogStoreRef; -use crate::operations::cast::{cast_record_batch, merge_schema}; +use crate::operations::cast::{cast_record_batch, merge_schema, is_compatible_for_merge}; use crate::protocol::{DeltaOperation, SaveMode}; use crate::storage::ObjectStoreRef; use crate::table::state::DeltaTableState; @@ -339,7 +340,7 @@ async fn write_execution_plan_with_predicate( safe_cast: bool, schema_mode: Option, ) -> DeltaResult> { - let schema: ArrowSchemaRef = if schema_mode == Some(SchemaMode::Overwrite) { + let schema: ArrowSchemaRef = if let Some(_) = schema_mode { plan.schema() } else { snapshot @@ -605,9 +606,14 @@ impl std::future::IntoFuture for WriteBuilder { if this.mode == SaveMode::Overwrite { new_schema = None // we overwrite anyway, so no need to cast } else if this.schema_mode == Some(SchemaMode::Overwrite) { + if let Err(err) = is_compatible_for_merge( + table_schema.as_ref().clone(), + schema.as_ref().clone(), + ) { + return Err(DeltaTableError::InvalidData { violations: vec!(format!("{:?}", err)) }); + } new_schema = None // we overwrite anyway, so no need to cast } else if this.schema_mode == Some(SchemaMode::Merge) { - println!("table. 
{:?} \r\n batch: {:?}", table_schema, schema); new_schema = Some(Arc::new(merge_schema( table_schema.as_ref().clone(), schema.as_ref().clone(), @@ -1134,13 +1140,13 @@ mod tests { ) .unwrap(); - let table = DeltaOps(table) + let mut table = DeltaOps(table) .write(vec![new_batch]) .with_save_mode(SaveMode::Append) .with_schema_mode(SchemaMode::Merge) .await .unwrap(); - + table.load().await.unwrap(); assert_eq!(table.version(), 1); let new_schema = table.metadata().unwrap().schema().unwrap(); let fields = new_schema.fields(); @@ -1210,6 +1216,66 @@ mod tests { assert_eq!(part_cols, vec!["id", "value"]); // we want to preserve partitions } + + + #[tokio::test] + async fn test_overwrite_schema() { + let batch = get_record_batch(None, false); + let table = DeltaOps::new_in_memory() + .write(vec![batch.clone()]) + .with_save_mode(SaveMode::ErrorIfExists) + .await + .unwrap(); + assert_eq!(table.version(), 0); + + let mut new_schema_builder = arrow_schema::SchemaBuilder::new(); + for field in batch.schema().fields() { + if field.name() != "modified" { + new_schema_builder.push(field.clone()); + } + } + new_schema_builder.push(Field::new("inserted_by", DataType::Utf8, true)); + let new_schema = new_schema_builder.finish(); + let new_fields = new_schema.fields(); + let new_names = new_fields.iter().map(|f| f.name()).collect::>(); + assert_eq!(new_names, vec!["id", "value", "inserted_by"]); + let inserted_by = StringArray::from(vec![ + Some("A1"), + Some("B1"), + None, + Some("B2"), + Some("A3"), + Some("A4"), + None, + None, + Some("B4"), + Some("A5"), + Some("A7"), + ]); + let new_batch = RecordBatch::try_new( + Arc::new(new_schema), + vec![ + Arc::new(batch.column_by_name("id").unwrap().clone()), + Arc::new(batch.column_by_name("value").unwrap().clone()), + Arc::new(inserted_by), + ], + ) + .unwrap(); + + let table = DeltaOps(table) + .write(vec![new_batch]) + .with_save_mode(SaveMode::Append) + .with_schema_mode(SchemaMode::Overwrite) + .await + .unwrap(); + + assert_eq!(table.version(), 1); + let new_schema = table.metadata().unwrap().schema().unwrap(); + let fields = new_schema.fields(); + let names = fields.iter().map(|f| f.name()).collect::>(); + assert_eq!(names, vec!["id", "value", "inserted_by"]); + } + #[tokio::test] async fn test_check_invariants() { let batch = get_record_batch(None, false); diff --git a/python/deltalake/writer.py b/python/deltalake/writer.py index 563d44a2f7..2b9381dd37 100644 --- a/python/deltalake/writer.py +++ b/python/deltalake/writer.py @@ -329,6 +329,8 @@ def write_deltalake( table.update_incremental() elif engine == "pyarrow": + if schema_mode == "merge": + raise ValueError("schema_mode 'merge' is not supported in pyarrow engine. 
Use engine=rust") # We need to write against the latest table version filesystem = pa_fs.PyFileSystem(DeltaStorageHandler(table_uri, storage_options)) diff --git a/python/tests/test_writer.py b/python/tests/test_writer.py index 550fec71ee..c662814e14 100644 --- a/python/tests/test_writer.py +++ b/python/tests/test_writer.py @@ -145,6 +145,64 @@ def test_update_schema(existing_table: DeltaTable): assert existing_table.schema().to_pyarrow() == new_data.schema +def test_merge_schema(existing_table: DeltaTable): + print(existing_table._table.table_uri()) + old_table_data = existing_table.to_pyarrow_table() + new_data = pa.table({"new_x": pa.array([1, 2, 3], pa.int32()), "new_y": pa.array([1, 2, 3], pa.int32())}) + + write_deltalake(existing_table, new_data, mode="append", schema_mode="merge", engine="rust") + # adjust schema of old_table_data and new_data to match each other + + for i in range(old_table_data.num_columns): + col = old_table_data.schema.field(i) + new_data = new_data.add_column(i, col, pa.nulls(new_data.num_rows, col.type)) + + old_table_data=old_table_data.append_column(pa.field("new_x", pa.int32()), pa.nulls(old_table_data.num_rows, pa.int32())) + old_table_data=old_table_data.append_column(pa.field("new_y", pa.int32()), pa.nulls(old_table_data.num_rows, pa.int32())) + + # define sort order + read_data = existing_table.to_pyarrow_table().sort_by([("utf8", "ascending"), ("new_x", "ascending")]) + print(repr(read_data.to_pylist())) + concated = pa.concat_tables([old_table_data, new_data], promote_options="permissive") + print(repr(concated.to_pylist())) + assert read_data == concated + + write_deltalake(existing_table, new_data, mode="overwrite", overwrite_schema=True) + + assert existing_table.schema().to_pyarrow() == new_data.schema + + +def test_overwrite_schema(existing_table: DeltaTable): + print(existing_table._table.table_uri()) + old_table_data = existing_table.to_pyarrow_table() + new_data = pa.table({"utf8": pa.array(['bla', 'bli', 'blubb']), "new_x": pa.array([1, 2, 3], pa.int32()), "new_y": pa.array([1, 2, 3], pa.int32())}) + + write_deltalake(existing_table, new_data, mode="append", schema_mode="overwrite", engine="rust") + # adjust schema of old_table_data and new_data to match each other + old_table_data = old_table_data.select(["utf8"]) + old_table_data=old_table_data.append_column(pa.field("new_x", pa.int32()), pa.nulls(old_table_data.num_rows, pa.int32())) + old_table_data=old_table_data.append_column(pa.field("new_y", pa.int32()), pa.nulls(old_table_data.num_rows, pa.int32())) + + # define sort order + read_data = existing_table.to_pyarrow_table().sort_by([("utf8", "ascending"), ("new_x", "ascending")]) + print(repr(read_data.to_pylist())) + concated = pa.concat_tables([old_table_data, new_data], promote_options="permissive") + print(repr(concated.to_pylist())) + assert read_data == concated + + write_deltalake(existing_table, new_data, mode="overwrite", overwrite_schema=True) + + assert existing_table.schema().to_pyarrow() == new_data.schema + +def test_overwrite_schema_error(existing_table: DeltaTable): + print(existing_table._table.table_uri()) + new_data = pa.table({"utf8": pa.array([1235, 546, 5645]), "new_x": pa.array([1, 2, 3], pa.int32()), "new_y": pa.array([1, 2, 3], pa.int32())}) + + with pytest.raises(DeltaError): + write_deltalake(existing_table, new_data, mode="append", schema_mode="overwrite", engine="rust") + + + def test_update_schema_rust_writer(existing_table: DeltaTable): new_data = pa.table({"x": pa.array([1, 2, 3])}) From 
dfec2ac05ebd4c9939ec39ba7a2f20c2f55129c2 Mon Sep 17 00:00:00 2001 From: Adrian Ehrsam Date: Mon, 4 Mar 2024 10:22:54 +0100 Subject: [PATCH 22/40] fnt --- crates/core/src/operations/cast.rs | 16 +++++++++------- crates/core/src/operations/write.rs | 10 +++++----- 2 files changed, 14 insertions(+), 12 deletions(-) diff --git a/crates/core/src/operations/cast.rs b/crates/core/src/operations/cast.rs index 57ca01ba39..432e385890 100644 --- a/crates/core/src/operations/cast.rs +++ b/crates/core/src/operations/cast.rs @@ -1,23 +1,23 @@ //! Provide common cast functionality for callers //! +use arrow::datatypes::DataType::Dictionary; use arrow_array::{new_null_array, Array, ArrayRef, RecordBatch, StructArray}; use arrow_cast::{cast_with_options, CastOptions}; use arrow_schema::{ ArrowError, DataType, Field as ArrowField, Fields, Schema as ArrowSchema, - SchemaRef as ArrowSchemaRef + SchemaRef as ArrowSchemaRef, }; -use arrow::datatypes::DataType::Dictionary; use std::sync::Arc; use crate::DeltaResult; -pub (crate) fn merge_field(left: &ArrowField, right: &ArrowField) -> Result { - if let Dictionary(_, value_type) = right.data_type() { +pub(crate) fn merge_field(left: &ArrowField, right: &ArrowField) -> Result { + if let Dictionary(_, value_type) = right.data_type() { if value_type.equals_datatype(left.data_type()) { return Ok(left.clone()); } } - if let Dictionary(_, value_type) = left.data_type() { + if let Dictionary(_, value_type) = left.data_type() { if value_type.equals_datatype(right.data_type()) { return Ok(right.clone()); } @@ -30,7 +30,10 @@ pub (crate) fn merge_field(left: &ArrowField, right: &ArrowField) -> Result Result<(), ArrowError> { +pub(crate) fn is_compatible_for_merge( + schema: ArrowSchema, + other: ArrowSchema, +) -> Result<(), ArrowError> { for f in schema.fields() { if let Ok(other_field) = other.field_with_name(f.name()) { if let Err(e) = merge_field(f.as_ref(), other_field) { @@ -67,7 +70,6 @@ pub(crate) fn merge_schema( Ok(ArrowSchema::new(fields)) } - fn cast_struct( struct_array: &StructArray, fields: &Fields, diff --git a/crates/core/src/operations/write.rs b/crates/core/src/operations/write.rs index 438c00a341..05ff9cf175 100644 --- a/crates/core/src/operations/write.rs +++ b/crates/core/src/operations/write.rs @@ -54,7 +54,7 @@ use crate::delta_datafusion::{find_files, register_store, DeltaScanBuilder}; use crate::errors::{DeltaResult, DeltaTableError}; use crate::kernel::{Action, Add, Metadata, PartitionsExt, Remove, StructType}; use crate::logstore::LogStoreRef; -use crate::operations::cast::{cast_record_batch, merge_schema, is_compatible_for_merge}; +use crate::operations::cast::{cast_record_batch, is_compatible_for_merge, merge_schema}; use crate::protocol::{DeltaOperation, SaveMode}; use crate::storage::ObjectStoreRef; use crate::table::state::DeltaTableState; @@ -340,7 +340,7 @@ async fn write_execution_plan_with_predicate( safe_cast: bool, schema_mode: Option, ) -> DeltaResult> { - let schema: ArrowSchemaRef = if let Some(_) = schema_mode { + let schema: ArrowSchemaRef = if let Some(_) = schema_mode { plan.schema() } else { snapshot @@ -610,7 +610,9 @@ impl std::future::IntoFuture for WriteBuilder { table_schema.as_ref().clone(), schema.as_ref().clone(), ) { - return Err(DeltaTableError::InvalidData { violations: vec!(format!("{:?}", err)) }); + return Err(DeltaTableError::InvalidData { + violations: vec![format!("{:?}", err)], + }); } new_schema = None // we overwrite anyway, so no need to cast } else if this.schema_mode == Some(SchemaMode::Merge) { @@ 
-1216,8 +1218,6 @@ mod tests { assert_eq!(part_cols, vec!["id", "value"]); // we want to preserve partitions } - - #[tokio::test] async fn test_overwrite_schema() { let batch = get_record_batch(None, false); From 4a0992145ebaac3b8cbb3633049b227ea6a206e7 Mon Sep 17 00:00:00 2001 From: Adrian Ehrsam Date: Mon, 4 Mar 2024 10:26:37 +0100 Subject: [PATCH 23/40] format --- python/deltalake/writer.py | 37 ++++++++--------- python/tests/test_writer.py | 81 ++++++++++++++++++++++++++++--------- 2 files changed, 81 insertions(+), 37 deletions(-) diff --git a/python/deltalake/writer.py b/python/deltalake/writer.py index 2b9381dd37..89a12e2d6e 100644 --- a/python/deltalake/writer.py +++ b/python/deltalake/writer.py @@ -31,6 +31,8 @@ else: from typing_extensions import Literal +import warnings + import pyarrow as pa import pyarrow.dataset as ds import pyarrow.fs as pa_fs @@ -49,7 +51,7 @@ convert_pyarrow_table, ) from .table import MAX_SUPPORTED_WRITER_VERSION, DeltaTable, WriterProperties -import warnings + try: import pandas as pd # noqa: F811 except ModuleNotFoundError: @@ -101,8 +103,7 @@ def write_deltalake( large_dtypes: bool = ..., engine: Literal["pyarrow"] = ..., custom_metadata: Optional[Dict[str, str]] = ..., -) -> None: - ... +) -> None: ... @overload @@ -124,14 +125,13 @@ def write_deltalake( description: Optional[str] = ..., configuration: Optional[Mapping[str, Optional[str]]] = ..., overwrite_schema: bool = ..., - schema_mode: Optional[Literal[ "merge", "overwrite"]] = ..., + schema_mode: Optional[Literal["merge", "overwrite"]] = ..., storage_options: Optional[Dict[str, str]] = ..., large_dtypes: bool = ..., engine: Literal["rust"], writer_properties: WriterProperties = ..., custom_metadata: Optional[Dict[str, str]] = ..., -) -> None: - ... +) -> None: ... @overload @@ -153,15 +153,14 @@ def write_deltalake( description: Optional[str] = ..., configuration: Optional[Mapping[str, Optional[str]]] = ..., overwrite_schema: bool = ..., - schema_mode: Optional[Literal[ "merge", "overwrite"]] = ..., + schema_mode: Optional[Literal["merge", "overwrite"]] = ..., storage_options: Optional[Dict[str, str]] = ..., predicate: Optional[str] = ..., large_dtypes: bool = ..., engine: Literal["rust"], writer_properties: WriterProperties = ..., custom_metadata: Optional[Dict[str, str]] = ..., -) -> None: - ... +) -> None: ... def write_deltalake( @@ -188,7 +187,7 @@ def write_deltalake( description: Optional[str] = None, configuration: Optional[Mapping[str, Optional[str]]] = None, overwrite_schema: bool = False, - schema_mode: Optional[Literal[ "merge", "overwrite"]] = None, + schema_mode: Optional[Literal["merge", "overwrite"]] = None, storage_options: Optional[Dict[str, str]] = None, partition_filters: Optional[List[Tuple[str, str, Any]]] = None, predicate: Optional[str] = None, @@ -263,7 +262,7 @@ def write_deltalake( __enforce_append_only(table=table, configuration=configuration, mode=mode) if overwrite_schema: schema_mode = "overwrite" - + warnings.warn( "overwrite_schema is deprecated, use schema_mode instead. ", category=DeprecationWarning, @@ -330,7 +329,9 @@ def write_deltalake( elif engine == "pyarrow": if schema_mode == "merge": - raise ValueError("schema_mode 'merge' is not supported in pyarrow engine. Use engine=rust") + raise ValueError( + "schema_mode 'merge' is not supported in pyarrow engine. 
Use engine=rust" + ) # We need to write against the latest table version filesystem = pa_fs.PyFileSystem(DeltaStorageHandler(table_uri, storage_options)) @@ -435,12 +436,12 @@ def check_data_is_aligned_with_partition_filtering( ) -> None: if table is None: return - existed_partitions: FrozenSet[ - FrozenSet[Tuple[str, Optional[str]]] - ] = table._table.get_active_partitions() - allowed_partitions: FrozenSet[ - FrozenSet[Tuple[str, Optional[str]]] - ] = table._table.get_active_partitions(partition_filters) + existed_partitions: FrozenSet[FrozenSet[Tuple[str, Optional[str]]]] = ( + table._table.get_active_partitions() + ) + allowed_partitions: FrozenSet[FrozenSet[Tuple[str, Optional[str]]]] = ( + table._table.get_active_partitions(partition_filters) + ) partition_values = pa.RecordBatch.from_arrays( [ batch.column(column_name) diff --git a/python/tests/test_writer.py b/python/tests/test_writer.py index c662814e14..e60219dacd 100644 --- a/python/tests/test_writer.py +++ b/python/tests/test_writer.py @@ -148,22 +148,37 @@ def test_update_schema(existing_table: DeltaTable): def test_merge_schema(existing_table: DeltaTable): print(existing_table._table.table_uri()) old_table_data = existing_table.to_pyarrow_table() - new_data = pa.table({"new_x": pa.array([1, 2, 3], pa.int32()), "new_y": pa.array([1, 2, 3], pa.int32())}) + new_data = pa.table( + { + "new_x": pa.array([1, 2, 3], pa.int32()), + "new_y": pa.array([1, 2, 3], pa.int32()), + } + ) - write_deltalake(existing_table, new_data, mode="append", schema_mode="merge", engine="rust") + write_deltalake( + existing_table, new_data, mode="append", schema_mode="merge", engine="rust" + ) # adjust schema of old_table_data and new_data to match each other - + for i in range(old_table_data.num_columns): col = old_table_data.schema.field(i) new_data = new_data.add_column(i, col, pa.nulls(new_data.num_rows, col.type)) - old_table_data=old_table_data.append_column(pa.field("new_x", pa.int32()), pa.nulls(old_table_data.num_rows, pa.int32())) - old_table_data=old_table_data.append_column(pa.field("new_y", pa.int32()), pa.nulls(old_table_data.num_rows, pa.int32())) - + old_table_data = old_table_data.append_column( + pa.field("new_x", pa.int32()), pa.nulls(old_table_data.num_rows, pa.int32()) + ) + old_table_data = old_table_data.append_column( + pa.field("new_y", pa.int32()), pa.nulls(old_table_data.num_rows, pa.int32()) + ) + # define sort order - read_data = existing_table.to_pyarrow_table().sort_by([("utf8", "ascending"), ("new_x", "ascending")]) + read_data = existing_table.to_pyarrow_table().sort_by( + [("utf8", "ascending"), ("new_x", "ascending")] + ) print(repr(read_data.to_pylist())) - concated = pa.concat_tables([old_table_data, new_data], promote_options="permissive") + concated = pa.concat_tables( + [old_table_data, new_data], promote_options="permissive" + ) print(repr(concated.to_pylist())) assert read_data == concated @@ -175,18 +190,34 @@ def test_merge_schema(existing_table: DeltaTable): def test_overwrite_schema(existing_table: DeltaTable): print(existing_table._table.table_uri()) old_table_data = existing_table.to_pyarrow_table() - new_data = pa.table({"utf8": pa.array(['bla', 'bli', 'blubb']), "new_x": pa.array([1, 2, 3], pa.int32()), "new_y": pa.array([1, 2, 3], pa.int32())}) + new_data = pa.table( + { + "utf8": pa.array(["bla", "bli", "blubb"]), + "new_x": pa.array([1, 2, 3], pa.int32()), + "new_y": pa.array([1, 2, 3], pa.int32()), + } + ) - write_deltalake(existing_table, new_data, mode="append", schema_mode="overwrite", 
engine="rust") + write_deltalake( + existing_table, new_data, mode="append", schema_mode="overwrite", engine="rust" + ) # adjust schema of old_table_data and new_data to match each other old_table_data = old_table_data.select(["utf8"]) - old_table_data=old_table_data.append_column(pa.field("new_x", pa.int32()), pa.nulls(old_table_data.num_rows, pa.int32())) - old_table_data=old_table_data.append_column(pa.field("new_y", pa.int32()), pa.nulls(old_table_data.num_rows, pa.int32())) - + old_table_data = old_table_data.append_column( + pa.field("new_x", pa.int32()), pa.nulls(old_table_data.num_rows, pa.int32()) + ) + old_table_data = old_table_data.append_column( + pa.field("new_y", pa.int32()), pa.nulls(old_table_data.num_rows, pa.int32()) + ) + # define sort order - read_data = existing_table.to_pyarrow_table().sort_by([("utf8", "ascending"), ("new_x", "ascending")]) + read_data = existing_table.to_pyarrow_table().sort_by( + [("utf8", "ascending"), ("new_x", "ascending")] + ) print(repr(read_data.to_pylist())) - concated = pa.concat_tables([old_table_data, new_data], promote_options="permissive") + concated = pa.concat_tables( + [old_table_data, new_data], promote_options="permissive" + ) print(repr(concated.to_pylist())) assert read_data == concated @@ -194,13 +225,25 @@ def test_overwrite_schema(existing_table: DeltaTable): assert existing_table.schema().to_pyarrow() == new_data.schema + def test_overwrite_schema_error(existing_table: DeltaTable): print(existing_table._table.table_uri()) - new_data = pa.table({"utf8": pa.array([1235, 546, 5645]), "new_x": pa.array([1, 2, 3], pa.int32()), "new_y": pa.array([1, 2, 3], pa.int32())}) - + new_data = pa.table( + { + "utf8": pa.array([1235, 546, 5645]), + "new_x": pa.array([1, 2, 3], pa.int32()), + "new_y": pa.array([1, 2, 3], pa.int32()), + } + ) + with pytest.raises(DeltaError): - write_deltalake(existing_table, new_data, mode="append", schema_mode="overwrite", engine="rust") - + write_deltalake( + existing_table, + new_data, + mode="append", + schema_mode="overwrite", + engine="rust", + ) def test_update_schema_rust_writer(existing_table: DeltaTable): From 46c084a947284c0277789740f1171588b3a4b3ac Mon Sep 17 00:00:00 2001 From: Adrian Ehrsam Date: Mon, 4 Mar 2024 10:30:49 +0100 Subject: [PATCH 24/40] thanks, clippy for your feedback --- crates/core/src/operations/cast.rs | 9 ++------- crates/core/src/operations/write.rs | 2 +- 2 files changed, 3 insertions(+), 8 deletions(-) diff --git a/crates/core/src/operations/cast.rs b/crates/core/src/operations/cast.rs index 432e385890..adb32f79f9 100644 --- a/crates/core/src/operations/cast.rs +++ b/crates/core/src/operations/cast.rs @@ -23,10 +23,7 @@ pub(crate) fn merge_field(left: &ArrowField, right: &ArrowField) -> Result Result<(), ArrowError> { for f in schema.fields() { if let Ok(other_field) = other.field_with_name(f.name()) { - if let Err(e) = merge_field(f.as_ref(), other_field) { - return Err(e); - } + merge_field(f.as_ref(), other_field)?; } } Ok(()) diff --git a/crates/core/src/operations/write.rs b/crates/core/src/operations/write.rs index 05ff9cf175..e6a2dc17ae 100644 --- a/crates/core/src/operations/write.rs +++ b/crates/core/src/operations/write.rs @@ -340,7 +340,7 @@ async fn write_execution_plan_with_predicate( safe_cast: bool, schema_mode: Option, ) -> DeltaResult> { - let schema: ArrowSchemaRef = if let Some(_) = schema_mode { + let schema: ArrowSchemaRef = if schema_mode.is_some() { plan.schema() } else { snapshot From e629f4c0ca01a7a0a98abeffe976c0d1e564f4c8 Mon Sep 17 00:00:00 2001 
From: Adrian Ehrsam Date: Mon, 4 Mar 2024 10:48:33 +0100 Subject: [PATCH 25/40] fix ruff and mypy version and format --- docs/src/python/check_constraints.py | 4 +++- python/deltalake/table.py | 6 +++--- python/deltalake/writer.py | 21 ++++++++----------- python/docs/source/_ext/edit_on_github.py | 6 +++--- python/pyproject.toml | 4 ++-- .../test_write_to_pyspark.py | 1 + .../test_writer_readable.py | 1 + 7 files changed, 22 insertions(+), 21 deletions(-) diff --git a/docs/src/python/check_constraints.py b/docs/src/python/check_constraints.py index 16fb8bf374..1bfa62d970 100644 --- a/docs/src/python/check_constraints.py +++ b/docs/src/python/check_constraints.py @@ -13,9 +13,11 @@ def add_constraint(): def add_data(): # --8<-- [start:add_data] - from deltalake import write_deltalake + from deltalake import write_deltalake, DeltaTable import pandas as pd + dt = DeltaTable("../rust/tests/data/simple_table") + df = pd.DataFrame({"id": [-1]}) write_deltalake(dt, df, mode="append", engine="rust") # _internal.DeltaProtocolError: Invariant violations: ["Check or Invariant (id > 0) violated by value in row: [-1]"] diff --git a/python/deltalake/table.py b/python/deltalake/table.py index d80aa8632f..86fffb2a55 100644 --- a/python/deltalake/table.py +++ b/python/deltalake/table.py @@ -1277,9 +1277,9 @@ def __init__( self.not_matched_insert_updates: Optional[List[Dict[str, str]]] = None self.not_matched_insert_predicate: Optional[List[Optional[str]]] = None self.not_matched_by_source_update_updates: Optional[List[Dict[str, str]]] = None - self.not_matched_by_source_update_predicate: Optional[ - List[Optional[str]] - ] = None + self.not_matched_by_source_update_predicate: Optional[List[Optional[str]]] = ( + None + ) self.not_matched_by_source_delete_predicate: Optional[List[str]] = None self.not_matched_by_source_delete_all: Optional[bool] = None diff --git a/python/deltalake/writer.py b/python/deltalake/writer.py index df76ded806..6ebc496436 100644 --- a/python/deltalake/writer.py +++ b/python/deltalake/writer.py @@ -100,8 +100,7 @@ def write_deltalake( large_dtypes: bool = ..., engine: Literal["pyarrow"] = ..., custom_metadata: Optional[Dict[str, str]] = ..., -) -> None: - ... +) -> None: ... @overload @@ -128,8 +127,7 @@ def write_deltalake( engine: Literal["rust"], writer_properties: WriterProperties = ..., custom_metadata: Optional[Dict[str, str]] = ..., -) -> None: - ... +) -> None: ... @overload @@ -157,8 +155,7 @@ def write_deltalake( engine: Literal["rust"], writer_properties: WriterProperties = ..., custom_metadata: Optional[Dict[str, str]] = ..., -) -> None: - ... +) -> None: ... 
def write_deltalake( @@ -421,12 +418,12 @@ def check_data_is_aligned_with_partition_filtering( ) -> None: if table is None: return - existed_partitions: FrozenSet[ - FrozenSet[Tuple[str, Optional[str]]] - ] = table._table.get_active_partitions() - allowed_partitions: FrozenSet[ - FrozenSet[Tuple[str, Optional[str]]] - ] = table._table.get_active_partitions(partition_filters) + existed_partitions: FrozenSet[FrozenSet[Tuple[str, Optional[str]]]] = ( + table._table.get_active_partitions() + ) + allowed_partitions: FrozenSet[FrozenSet[Tuple[str, Optional[str]]]] = ( + table._table.get_active_partitions(partition_filters) + ) partition_values = pa.RecordBatch.from_arrays( [ batch.column(column_name) diff --git a/python/docs/source/_ext/edit_on_github.py b/python/docs/source/_ext/edit_on_github.py index f7188f189a..241560877c 100644 --- a/python/docs/source/_ext/edit_on_github.py +++ b/python/docs/source/_ext/edit_on_github.py @@ -38,9 +38,9 @@ def html_page_context(app, pagename, templatename, context, doctree): context["display_github"] = True context["github_user"] = app.config.edit_on_github_project.split("/")[0] context["github_repo"] = app.config.edit_on_github_project.split("/")[1] - context[ - "github_version" - ] = f"{app.config.edit_on_github_branch}/{app.config.page_source_prefix}/" + context["github_version"] = ( + f"{app.config.edit_on_github_branch}/{app.config.page_source_prefix}/" + ) def setup(app): diff --git a/python/pyproject.toml b/python/pyproject.toml index e9fc7389af..9b74760948 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -27,8 +27,8 @@ pandas = [ "pandas" ] devel = [ - "mypy", - "ruff>=0.1.5", + "mypy~=1.8.0", + "ruff~=0.3.0", "packaging>=20", "pytest", "pytest-mock", diff --git a/python/tests/pyspark_integration/test_write_to_pyspark.py b/python/tests/pyspark_integration/test_write_to_pyspark.py index 8418f587ca..5cf6490a62 100644 --- a/python/tests/pyspark_integration/test_write_to_pyspark.py +++ b/python/tests/pyspark_integration/test_write_to_pyspark.py @@ -1,4 +1,5 @@ """Tests that deltalake(delta-rs) can write to tables written by PySpark.""" + import pathlib import pyarrow as pa diff --git a/python/tests/pyspark_integration/test_writer_readable.py b/python/tests/pyspark_integration/test_writer_readable.py index ea555074b8..3ade57c6e9 100644 --- a/python/tests/pyspark_integration/test_writer_readable.py +++ b/python/tests/pyspark_integration/test_writer_readable.py @@ -1,4 +1,5 @@ """Test that pyspark can read tables written by deltalake(delta-rs).""" + import pathlib import pyarrow as pa From d14b4b0447f2a1515cdfe2069fc0f68c4b2e0108 Mon Sep 17 00:00:00 2001 From: Adrian Ehrsam Date: Mon, 4 Mar 2024 13:22:01 +0100 Subject: [PATCH 26/40] validate schema if schema_mode not given --- crates/core/src/operations/write.rs | 49 ++++++++++++++++++++++++++++- 1 file changed, 48 insertions(+), 1 deletion(-) diff --git a/crates/core/src/operations/write.rs b/crates/core/src/operations/write.rs index e6a2dc17ae..aa2f2cd811 100644 --- a/crates/core/src/operations/write.rs +++ b/crates/core/src/operations/write.rs @@ -603,7 +603,7 @@ impl std::future::IntoFuture for WriteBuilder { .unwrap_or(schema.clone()); if !can_cast_batch(schema.fields(), table_schema.fields()) { - if this.mode == SaveMode::Overwrite { + if this.mode == SaveMode::Overwrite && this.schema_mode.is_some() { new_schema = None // we overwrite anyway, so no need to cast } else if this.schema_mode == Some(SchemaMode::Overwrite) { if let Err(err) = is_compatible_for_merge( @@ -621,6 +621,7 @@ 
impl std::future::IntoFuture for WriteBuilder { schema.as_ref().clone(), )?)); } else { + // this is a feature! Unless you specify a schema_mode explicity, we want to check the schema! return Err(DeltaTableError::Generic( "Schema of data does not match table schema".to_string(), )); @@ -1276,6 +1277,52 @@ mod tests { assert_eq!(names, vec!["id", "value", "inserted_by"]); } + #[tokio::test] + async fn test_overwrite_check() { + // If you do not pass a schema mode, we want to check the schema + let batch = get_record_batch(None, false); + let table = DeltaOps::new_in_memory() + .write(vec![batch.clone()]) + .with_save_mode(SaveMode::ErrorIfExists) + .await + .unwrap(); + assert_eq!(table.version(), 0); + + let mut new_schema_builder = arrow_schema::SchemaBuilder::new(); + + new_schema_builder.push(Field::new("inserted_by", DataType::Utf8, true)); + let new_schema = new_schema_builder.finish(); + let new_fields = new_schema.fields(); + let new_names = new_fields.iter().map(|f| f.name()).collect::>(); + assert_eq!(new_names, vec!["inserted_by"]); + let inserted_by = StringArray::from(vec![ + Some("A1"), + Some("B1"), + None, + Some("B2"), + Some("A3"), + Some("A4"), + None, + None, + Some("B4"), + Some("A5"), + Some("A7"), + ]); + let new_batch = RecordBatch::try_new( + Arc::new(new_schema), + vec![ + Arc::new(inserted_by), + ], + ) + .unwrap(); + + let table = DeltaOps(table) + .write(vec![new_batch]) + .with_save_mode(SaveMode::Append) + .await; + assert!(table.is_err()); + } + #[tokio::test] async fn test_check_invariants() { let batch = get_record_batch(None, false); From dc7177135bbcd150c1493c6661f6894859ae9168 Mon Sep 17 00:00:00 2001 From: Adrian Ehrsam Date: Mon, 4 Mar 2024 13:22:20 +0100 Subject: [PATCH 27/40] use new schema_mode parameter and refactor tests to match new behavior --- python/tests/test_writer.py | 73 ++++++++++++++++++++++--------------- 1 file changed, 43 insertions(+), 30 deletions(-) diff --git a/python/tests/test_writer.py b/python/tests/test_writer.py index e60219dacd..38bedc57ff 100644 --- a/python/tests/test_writer.py +++ b/python/tests/test_writer.py @@ -136,9 +136,9 @@ def test_update_schema(existing_table: DeltaTable): new_data = pa.table({"x": pa.array([1, 2, 3])}) with pytest.raises(ValueError): - write_deltalake(existing_table, new_data, mode="append", overwrite_schema=True) + write_deltalake(existing_table, new_data, mode="append", schema_mode="overwrite") - write_deltalake(existing_table, new_data, mode="overwrite", overwrite_schema=True) + write_deltalake(existing_table, new_data, mode="overwrite", schema_mode="overwrite") read_data = existing_table.to_pyarrow_table() assert new_data == read_data @@ -182,7 +182,7 @@ def test_merge_schema(existing_table: DeltaTable): print(repr(concated.to_pylist())) assert read_data == concated - write_deltalake(existing_table, new_data, mode="overwrite", overwrite_schema=True) + write_deltalake(existing_table, new_data, mode="overwrite", schema_mode="overwrite") assert existing_table.schema().to_pyarrow() == new_data.schema @@ -221,7 +221,7 @@ def test_overwrite_schema(existing_table: DeltaTable): print(repr(concated.to_pylist())) assert read_data == concated - write_deltalake(existing_table, new_data, mode="overwrite", overwrite_schema=True) + write_deltalake(existing_table, new_data, mode="overwrite", schema_mode="overwrite") assert existing_table.schema().to_pyarrow() == new_data.schema @@ -245,39 +245,47 @@ def test_overwrite_schema_error(existing_table: DeltaTable): engine="rust", ) - -def 
test_update_schema_rust_writer(existing_table: DeltaTable): - new_data = pa.table({"x": pa.array([1, 2, 3])}) - +def test_update_schema_rust_writer_append(existing_table: DeltaTable): with pytest.raises(DeltaError): + # It's illegal to do schema drift without correct schema_mode write_deltalake( existing_table, - new_data, + pa.table({"x4": pa.array([1, 2, 3])}), mode="append", - overwrite_schema=True, + schema_mode=None, engine="rust", ) + write_deltalake( + existing_table, + pa.table({"x1": pa.array([1, 2, 3])}), + mode="append", + schema_mode="overwrite", + engine="rust", + ) + write_deltalake( + existing_table, + pa.table({"x2": pa.array([1, 2, 3])}), + mode="append", + schema_mode="merge", + engine="rust", + ) + +def test_update_schema_rust_writer_invalid(existing_table: DeltaTable): + new_data = pa.table({"x5": pa.array([1, 2, 3])}) with pytest.raises(DeltaError): write_deltalake( existing_table, new_data, mode="overwrite", - overwrite_schema=False, - engine="rust", - ) - with pytest.raises(DeltaError): - write_deltalake( - existing_table, - new_data, - mode="append", - overwrite_schema=False, + schema_mode=None, engine="rust", ) + write_deltalake( existing_table, new_data, mode="overwrite", - overwrite_schema=True, + schema_mode="overwrite", engine="rust", ) @@ -761,36 +769,41 @@ def test_writer_with_options(tmp_path: pathlib.Path): def test_try_get_table_and_table_uri(tmp_path: pathlib.Path): + from typing import TypeVar + T = TypeVar("T") + def _normalize_path(t: tuple[T, str]): # who does not love Windows? ;) + return t[0], t[1].replace("\\", "/") if t[1] else t[1] + data = pa.table({"vals": pa.array(["1", "2", "3"])}) table_or_uri = tmp_path / "delta_table" write_deltalake(table_or_uri, data) delta_table = DeltaTable(table_or_uri) # table_or_uri as DeltaTable - assert try_get_table_and_table_uri(delta_table, None) == ( + assert _normalize_path(try_get_table_and_table_uri(delta_table, None)) == _normalize_path(( delta_table, str(tmp_path / "delta_table") + "/", - ) + )) # table_or_uri as str - assert try_get_table_and_table_uri(str(tmp_path / "delta_table"), None) == ( + assert _normalize_path(try_get_table_and_table_uri(str(tmp_path / "delta_table"), None)) == _normalize_path(( delta_table, str(tmp_path / "delta_table"), - ) - assert try_get_table_and_table_uri(str(tmp_path / "str"), None) == ( + )) + assert _normalize_path(try_get_table_and_table_uri(str(tmp_path / "str"), None)) == _normalize_path(( None, str(tmp_path / "str"), - ) + )) # table_or_uri as Path - assert try_get_table_and_table_uri(tmp_path / "delta_table", None) == ( + assert _normalize_path(try_get_table_and_table_uri(tmp_path / "delta_table", None)) == _normalize_path(( delta_table, str(tmp_path / "delta_table"), - ) - assert try_get_table_and_table_uri(tmp_path / "Path", None) == ( + )) + assert _normalize_path(try_get_table_and_table_uri(tmp_path / "Path", None)) == _normalize_path(( None, str(tmp_path / "Path"), - ) + )) # table_or_uri with invalid parameter type with pytest.raises(ValueError): From 4c7a9e1292de9fdbc7264627b6c12f33788e4a32 Mon Sep 17 00:00:00 2001 From: Adrian Ehrsam Date: Mon, 4 Mar 2024 13:28:43 +0100 Subject: [PATCH 28/40] docs --- docs/integrations/delta-lake-pandas.md | 6 ++++-- docs/usage/writing/index.md | 4 +++- python/docs/source/usage.rst | 2 +- 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/docs/integrations/delta-lake-pandas.md b/docs/integrations/delta-lake-pandas.md index b14c1bd45b..ca60362838 100644 --- a/docs/integrations/delta-lake-pandas.md +++ 
b/docs/integrations/delta-lake-pandas.md @@ -250,10 +250,10 @@ Schema enforcement protects your table from getting corrupted by appending data ## Overwriting schema of table -You can overwrite the table contents and schema by setting the `overwrite_schema` option. Here's how to overwrite the table contents: +You can overwrite the table contents and schema by setting the `schema_mode` option. Here's how to overwrite the table contents: ```python -write_deltalake("tmp/some-table", df, mode="overwrite", overwrite_schema=True) +write_deltalake("tmp/some-table", df, mode="overwrite", schema_mode="overwrite") ``` Here are the contents of the table after the values and schema have been overwritten: @@ -267,6 +267,8 @@ Here are the contents of the table after the values and schema have been overwri +-------+----------+ ``` +If you want the schema to be merged instead, specify schema_mode="merge". + ## In-memory vs. in-storage data changes It's important to distinguish between data stored in-memory and data stored on disk when understanding the functionality offered by Delta Lake. diff --git a/docs/usage/writing/index.md b/docs/usage/writing/index.md index dc8bb62389..9c8e4f08e0 100644 --- a/docs/usage/writing/index.md +++ b/docs/usage/writing/index.md @@ -23,7 +23,9 @@ of Spark's `pyspark.sql.DataFrameWriter.saveAsTable` DataFrame method. To overwr `write_deltalake` will raise `ValueError` if the schema of the data passed to it differs from the existing table's schema. If you wish to -alter the schema as part of an overwrite pass in `overwrite_schema=True`. +alter the schema as part of an overwrite pass in `schema_mode="overwrite"` or `schema_mode="merge"`. +`schema_mode="overwrite"` will completely overwrite the schema, even if columns are dropped; merge will append the new columns +and fill missing columns with `null`. ## Overwriting a partition diff --git a/python/docs/source/usage.rst b/python/docs/source/usage.rst index d0349a450c..baa26f275c 100644 --- a/python/docs/source/usage.rst +++ b/python/docs/source/usage.rst @@ -481,7 +481,7 @@ to append pass in ``mode='append'``: :py:meth:`write_deltalake` will raise :py:exc:`ValueError` if the schema of the data passed to it differs from the existing table's schema. If you wish to -alter the schema as part of an overwrite pass in ``overwrite_schema=True``. +alter the schema as part of an overwrite pass in ``schema_mode="overwrite"``. 
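A minimal sketch (illustration only, not part of the patch itself) of the schema-merge behaviour the doc changes above describe, assuming the `schema_mode` parameter introduced in this series and the Rust engine:

```python
import pyarrow as pa
from deltalake import DeltaTable, write_deltalake

# Existing table with a single column "x".
write_deltalake("tmp/some-table", pa.table({"x": pa.array([1, 2, 3])}))

# Append data carrying an extra column; schema_mode="merge" adds "y" to the
# table schema, and rows written earlier are read back with null in "y".
write_deltalake(
    "tmp/some-table",
    pa.table({"x": pa.array([4, 5]), "y": pa.array(["a", "b"])}),
    mode="append",
    schema_mode="merge",
    engine="rust",
)

print(DeltaTable("tmp/some-table").to_pyarrow_table())
```

This mirrors the "merge will append the new columns and fill missing columns with `null`" wording above; `schema_mode="overwrite"`, by contrast, is only accepted together with `mode="overwrite"`.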
Writing to s3 ~~~~~~~~~~~~~ From 9fbb9bba7af218a3c05487db6f3496ff521e0352 Mon Sep 17 00:00:00 2001 From: Adrian Ehrsam Date: Mon, 4 Mar 2024 13:29:19 +0100 Subject: [PATCH 29/40] fmt --- crates/core/src/operations/write.rs | 11 ++--- python/tests/test_writer.py | 72 ++++++++++++++++++++--------- 2 files changed, 52 insertions(+), 31 deletions(-) diff --git a/crates/core/src/operations/write.rs b/crates/core/src/operations/write.rs index aa2f2cd811..f5f12bc8f2 100644 --- a/crates/core/src/operations/write.rs +++ b/crates/core/src/operations/write.rs @@ -1289,7 +1289,7 @@ mod tests { assert_eq!(table.version(), 0); let mut new_schema_builder = arrow_schema::SchemaBuilder::new(); - + new_schema_builder.push(Field::new("inserted_by", DataType::Utf8, true)); let new_schema = new_schema_builder.finish(); let new_fields = new_schema.fields(); @@ -1308,13 +1308,8 @@ mod tests { Some("A5"), Some("A7"), ]); - let new_batch = RecordBatch::try_new( - Arc::new(new_schema), - vec![ - Arc::new(inserted_by), - ], - ) - .unwrap(); + let new_batch = + RecordBatch::try_new(Arc::new(new_schema), vec![Arc::new(inserted_by)]).unwrap(); let table = DeltaOps(table) .write(vec![new_batch]) diff --git a/python/tests/test_writer.py b/python/tests/test_writer.py index 38bedc57ff..001036be8f 100644 --- a/python/tests/test_writer.py +++ b/python/tests/test_writer.py @@ -136,7 +136,9 @@ def test_update_schema(existing_table: DeltaTable): new_data = pa.table({"x": pa.array([1, 2, 3])}) with pytest.raises(ValueError): - write_deltalake(existing_table, new_data, mode="append", schema_mode="overwrite") + write_deltalake( + existing_table, new_data, mode="append", schema_mode="overwrite" + ) write_deltalake(existing_table, new_data, mode="overwrite", schema_mode="overwrite") @@ -245,6 +247,7 @@ def test_overwrite_schema_error(existing_table: DeltaTable): engine="rust", ) + def test_update_schema_rust_writer_append(existing_table: DeltaTable): with pytest.raises(DeltaError): # It's illegal to do schema drift without correct schema_mode @@ -270,6 +273,7 @@ def test_update_schema_rust_writer_append(existing_table: DeltaTable): engine="rust", ) + def test_update_schema_rust_writer_invalid(existing_table: DeltaTable): new_data = pa.table({"x5": pa.array([1, 2, 3])}) with pytest.raises(DeltaError): @@ -280,7 +284,7 @@ def test_update_schema_rust_writer_invalid(existing_table: DeltaTable): schema_mode=None, engine="rust", ) - + write_deltalake( existing_table, new_data, @@ -770,8 +774,10 @@ def test_writer_with_options(tmp_path: pathlib.Path): def test_try_get_table_and_table_uri(tmp_path: pathlib.Path): from typing import TypeVar + T = TypeVar("T") - def _normalize_path(t: tuple[T, str]): # who does not love Windows? ;) + + def _normalize_path(t: tuple[T, str]): # who does not love Windows? ;) return t[0], t[1].replace("\\", "/") if t[1] else t[1] data = pa.table({"vals": pa.array(["1", "2", "3"])}) @@ -780,30 +786,50 @@ def _normalize_path(t: tuple[T, str]): # who does not love Windows? 
;) delta_table = DeltaTable(table_or_uri) # table_or_uri as DeltaTable - assert _normalize_path(try_get_table_and_table_uri(delta_table, None)) == _normalize_path(( - delta_table, - str(tmp_path / "delta_table") + "/", - )) + assert _normalize_path( + try_get_table_and_table_uri(delta_table, None) + ) == _normalize_path( + ( + delta_table, + str(tmp_path / "delta_table") + "/", + ) + ) # table_or_uri as str - assert _normalize_path(try_get_table_and_table_uri(str(tmp_path / "delta_table"), None)) == _normalize_path(( - delta_table, - str(tmp_path / "delta_table"), - )) - assert _normalize_path(try_get_table_and_table_uri(str(tmp_path / "str"), None)) == _normalize_path(( - None, - str(tmp_path / "str"), - )) + assert _normalize_path( + try_get_table_and_table_uri(str(tmp_path / "delta_table"), None) + ) == _normalize_path( + ( + delta_table, + str(tmp_path / "delta_table"), + ) + ) + assert _normalize_path( + try_get_table_and_table_uri(str(tmp_path / "str"), None) + ) == _normalize_path( + ( + None, + str(tmp_path / "str"), + ) + ) # table_or_uri as Path - assert _normalize_path(try_get_table_and_table_uri(tmp_path / "delta_table", None)) == _normalize_path(( - delta_table, - str(tmp_path / "delta_table"), - )) - assert _normalize_path(try_get_table_and_table_uri(tmp_path / "Path", None)) == _normalize_path(( - None, - str(tmp_path / "Path"), - )) + assert _normalize_path( + try_get_table_and_table_uri(tmp_path / "delta_table", None) + ) == _normalize_path( + ( + delta_table, + str(tmp_path / "delta_table"), + ) + ) + assert _normalize_path( + try_get_table_and_table_uri(tmp_path / "Path", None) + ) == _normalize_path( + ( + None, + str(tmp_path / "Path"), + ) + ) # table_or_uri with invalid parameter type with pytest.raises(ValueError): From d70b716521f31f5684c94cdcd834ad77e02070e8 Mon Sep 17 00:00:00 2001 From: Adrian Ehrsam Date: Mon, 4 Mar 2024 13:46:33 +0100 Subject: [PATCH 30/40] remove parameter that causes trouble with pyarrow 8 --- python/tests/test_writer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tests/test_writer.py b/python/tests/test_writer.py index 001036be8f..94b4def8ac 100644 --- a/python/tests/test_writer.py +++ b/python/tests/test_writer.py @@ -179,7 +179,7 @@ def test_merge_schema(existing_table: DeltaTable): ) print(repr(read_data.to_pylist())) concated = pa.concat_tables( - [old_table_data, new_data], promote_options="permissive" + [old_table_data, new_data] ) print(repr(concated.to_pylist())) assert read_data == concated @@ -218,7 +218,7 @@ def test_overwrite_schema(existing_table: DeltaTable): ) print(repr(read_data.to_pylist())) concated = pa.concat_tables( - [old_table_data, new_data], promote_options="permissive" + [old_table_data, new_data] ) print(repr(concated.to_pylist())) assert read_data == concated From e816061ed5d4e380c4f80aa6060b5895c0de9557 Mon Sep 17 00:00:00 2001 From: Adrian Ehrsam Date: Mon, 4 Mar 2024 13:49:16 +0100 Subject: [PATCH 31/40] format again :) --- python/tests/test_writer.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/python/tests/test_writer.py b/python/tests/test_writer.py index 94b4def8ac..8fceab8598 100644 --- a/python/tests/test_writer.py +++ b/python/tests/test_writer.py @@ -178,9 +178,7 @@ def test_merge_schema(existing_table: DeltaTable): [("utf8", "ascending"), ("new_x", "ascending")] ) print(repr(read_data.to_pylist())) - concated = pa.concat_tables( - [old_table_data, new_data] - ) + concated = pa.concat_tables([old_table_data, new_data]) 
print(repr(concated.to_pylist())) assert read_data == concated @@ -217,9 +215,7 @@ def test_overwrite_schema(existing_table: DeltaTable): [("utf8", "ascending"), ("new_x", "ascending")] ) print(repr(read_data.to_pylist())) - concated = pa.concat_tables( - [old_table_data, new_data] - ) + concated = pa.concat_tables([old_table_data, new_data]) print(repr(concated.to_pylist())) assert read_data == concated From b07f2190115c8acf99a28ebf57e00834414c0d83 Mon Sep 17 00:00:00 2001 From: Adrian Ehrsam Date: Mon, 4 Mar 2024 14:06:44 +0100 Subject: [PATCH 32/40] fighting with py 3.8 ;) --- python/tests/test_writer.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/python/tests/test_writer.py b/python/tests/test_writer.py index 8fceab8598..0436214abb 100644 --- a/python/tests/test_writer.py +++ b/python/tests/test_writer.py @@ -769,11 +769,7 @@ def test_writer_with_options(tmp_path: pathlib.Path): def test_try_get_table_and_table_uri(tmp_path: pathlib.Path): - from typing import TypeVar - - T = TypeVar("T") - - def _normalize_path(t: tuple[T, str]): # who does not love Windows? ;) + def _normalize_path(t): # who does not love Windows? ;) return t[0], t[1].replace("\\", "/") if t[1] else t[1] data = pa.table({"vals": pa.array(["1", "2", "3"])}) From a7ee4630bfa4e13c763b87aec17c34b8c9e34599 Mon Sep 17 00:00:00 2001 From: Adrian Ehrsam Date: Mon, 4 Mar 2024 21:21:44 +0100 Subject: [PATCH 33/40] address feedback --- crates/core/src/operations/cast.rs | 44 +++++----- crates/core/src/operations/write.rs | 131 +++++++++++++++------------- docs/usage/writing/index.md | 2 +- python/src/error.rs | 1 + python/tests/test_writer.py | 71 +++++++-------- 5 files changed, 126 insertions(+), 123 deletions(-) diff --git a/crates/core/src/operations/cast.rs b/crates/core/src/operations/cast.rs index adb32f79f9..96e4f4a849 100644 --- a/crates/core/src/operations/cast.rs +++ b/crates/core/src/operations/cast.rs @@ -7,6 +7,7 @@ use arrow_schema::{ ArrowError, DataType, Field as ArrowField, Fields, Schema as ArrowSchema, SchemaRef as ArrowSchemaRef, }; +use futures::SinkExt; use std::sync::Arc; use crate::DeltaResult; @@ -27,42 +28,45 @@ pub(crate) fn merge_field(left: &ArrowField, right: &ArrowField) -> Result Result<(), ArrowError> { - for f in schema.fields() { - if let Ok(other_field) = other.field_with_name(f.name()) { - merge_field(f.as_ref(), other_field)?; - } - } - Ok(()) -} - pub(crate) fn merge_schema( left: ArrowSchema, right: ArrowSchema, ) -> Result { - let left_fields: Result, ArrowError> = left + let mut errors = Vec::with_capacity(left.fields().len()); + let merged_fields: Result, ArrowError> = left .fields() .iter() .map(|field| { let right_field = right.field_with_name(field.name()); if let Ok(right_field) = right_field { - merge_field(field.as_ref(), right_field) + let field_or_not = merge_field(field.as_ref(), right_field); + match field_or_not { + Err(e) => { + errors.push(e.to_string()); + Err(e) + } + Ok(f) => Ok(f), + } } else { Ok(field.as_ref().clone()) } }) .collect(); - let mut fields = left_fields?; - for field in right.fields() { - if !left.field_with_name(field.name()).is_ok() { - fields.push(field.as_ref().clone()); + match merged_fields { + Ok(mut fields) => { + for field in right.fields() { + if !left.field_with_name(field.name()).is_ok() { + fields.push(field.as_ref().clone()); + } + } + + Ok(ArrowSchema::new(fields)) + } + Err(e) => { + errors.push(e.to_string()); + Err(ArrowError::SchemaError(errors.join("\n"))) } } - - Ok(ArrowSchema::new(fields)) } fn 
cast_struct( diff --git a/crates/core/src/operations/write.rs b/crates/core/src/operations/write.rs index f5f12bc8f2..dc6e2aa9aa 100644 --- a/crates/core/src/operations/write.rs +++ b/crates/core/src/operations/write.rs @@ -32,7 +32,7 @@ use std::vec; use arrow_array::RecordBatch; use arrow_cast::can_cast_types; -use arrow_schema::{DataType, Fields, SchemaRef as ArrowSchemaRef}; +use arrow_schema::{ArrowError, DataType, Fields, SchemaRef as ArrowSchemaRef}; use datafusion::execution::context::{SessionContext, SessionState, TaskContext}; use datafusion::physical_expr::create_physical_expr; use datafusion::physical_plan::filter::FilterExec; @@ -54,7 +54,7 @@ use crate::delta_datafusion::{find_files, register_store, DeltaScanBuilder}; use crate::errors::{DeltaResult, DeltaTableError}; use crate::kernel::{Action, Add, Metadata, PartitionsExt, Remove, StructType}; use crate::logstore::LogStoreRef; -use crate::operations::cast::{cast_record_batch, is_compatible_for_merge, merge_schema}; +use crate::operations::cast::{cast_record_batch, merge_schema}; use crate::protocol::{DeltaOperation, SaveMode}; use crate::storage::ObjectStoreRef; use crate::table::state::DeltaTableState; @@ -107,7 +107,7 @@ impl FromStr for SchemaMode { "overwrite" => Ok(SchemaMode::Overwrite), "merge" => Ok(SchemaMode::Merge), _ => Err(DeltaTableError::Generic(format!( - "Invalid schema write mode provided: {}, only these are supported: ['none', 'overwrite', 'merge']", + "Invalid schema write mode provided: {}, only these are supported: ['overwrite', 'merge']", s ))), } @@ -554,6 +554,11 @@ impl std::future::IntoFuture for WriteBuilder { PROTOCOL.check_append_only(snapshot)?; } } + if this.schema_mode == Some(SchemaMode::Overwrite) && this.mode != SaveMode::Overwrite { + return Err(DeltaTableError::Generic( + "Schema overwrite not supported for Append".to_string(), + )); + } // Create table actions to initialize table in case it does not yet exist and should be created let mut actions = this.check_preconditions().await?; @@ -580,7 +585,7 @@ impl std::future::IntoFuture for WriteBuilder { } else { Ok(this.partition_columns.unwrap_or_default()) }?; - + let mut schema_drift = false; let plan = if let Some(plan) = this.input { if this.schema_mode == Some(SchemaMode::Merge) { return Err(DeltaTableError::Generic( @@ -602,29 +607,27 @@ impl std::future::IntoFuture for WriteBuilder { .or_else(|_| snapshot.arrow_schema()) .unwrap_or(schema.clone()); - if !can_cast_batch(schema.fields(), table_schema.fields()) { + if let Err(schema_err) = + try_cast_batch(schema.fields(), table_schema.fields()) + { + schema_drift = true; if this.mode == SaveMode::Overwrite && this.schema_mode.is_some() { new_schema = None // we overwrite anyway, so no need to cast - } else if this.schema_mode == Some(SchemaMode::Overwrite) { - if let Err(err) = is_compatible_for_merge( - table_schema.as_ref().clone(), - schema.as_ref().clone(), - ) { - return Err(DeltaTableError::InvalidData { - violations: vec![format!("{:?}", err)], - }); - } - new_schema = None // we overwrite anyway, so no need to cast } else if this.schema_mode == Some(SchemaMode::Merge) { - new_schema = Some(Arc::new(merge_schema( - table_schema.as_ref().clone(), - schema.as_ref().clone(), - )?)); - } else { - // this is a feature! Unless you specify a schema_mode explicity, we want to check the schema! 
- return Err(DeltaTableError::Generic( - "Schema of data does not match table schema".to_string(), + new_schema = Some(Arc::new( + merge_schema( + table_schema.as_ref().clone(), + schema.as_ref().clone(), + ) + .map_err(|e| { + DeltaTableError::Generic(format!( + "Error merging schema {:?}", + e + )) + })?, )); + } else { + return Err(schema_err.into()); } } } @@ -686,25 +689,15 @@ impl std::future::IntoFuture for WriteBuilder { Err(WriteError::MissingData) }?; let schema = plan.schema(); - if this.schema_mode == Some(SchemaMode::Merge) - || (this.schema_mode == Some(SchemaMode::Overwrite) - && this.mode != SaveMode::Overwrite) - { + if this.schema_mode == Some(SchemaMode::Merge) && schema_drift { if let Some(snapshot) = &this.snapshot { - let table_schema = snapshot - .physical_arrow_schema(this.log_store.object_store().clone()) - .await - .or_else(|_| snapshot.arrow_schema()) - .unwrap_or(schema.clone()); - if !can_cast_batch(schema.fields(), table_schema.fields()) { - let schema_struct: StructType = schema.clone().try_into()?; - let schema_action = Action::Metadata(Metadata::try_new( - schema_struct, - partition_columns.clone(), - snapshot.metadata().configuration.clone(), - )?); - actions.push(schema_action); - } + let schema_struct: StructType = schema.clone().try_into()?; + let schema_action = Action::Metadata(Metadata::try_new( + schema_struct, + partition_columns.clone(), + snapshot.metadata().configuration.clone(), + )?); + actions.push(schema_action); } } let state = match this.state { @@ -831,24 +824,44 @@ impl std::future::IntoFuture for WriteBuilder { } } -fn can_cast_batch(from_fields: &Fields, to_fields: &Fields) -> bool { +fn try_cast_batch(from_fields: &Fields, to_fields: &Fields) -> Result<(), ArrowError> { if from_fields.len() != to_fields.len() { - return false; + return Err(ArrowError::SchemaError(format!( + "Cannot schema, number of fields does not match: {} vs {}", + from_fields.len(), + to_fields.len() + ))); } - from_fields.iter().all(|f| { - if let Some((_, target_field)) = to_fields.find(f.name()) { - if let (DataType::Struct(fields0), DataType::Struct(fields1)) = - (f.data_type(), target_field.data_type()) - { - can_cast_batch(fields0, fields1) + from_fields + .iter() + .map(|f| { + if let Some((_, target_field)) = to_fields.find(f.name()) { + if let (DataType::Struct(fields0), DataType::Struct(fields1)) = + (f.data_type(), target_field.data_type()) + { + try_cast_batch(fields0, fields1) + } else { + if !can_cast_types(f.data_type(), target_field.data_type()) { + Err(ArrowError::SchemaError(format!( + "Cannot cast field {} from {} to {}", + f.name(), + f.data_type(), + target_field.data_type() + ))) + } else { + Ok(()) + } + } } else { - can_cast_types(f.data_type(), target_field.data_type()) + Err(ArrowError::SchemaError(format!( + "Field {} not found in schema", + f.name() + ))) } - } else { - false - } - }) + }) + .collect::, _>>()?; + Ok(()) } #[cfg(test)] @@ -1267,14 +1280,8 @@ mod tests { .write(vec![new_batch]) .with_save_mode(SaveMode::Append) .with_schema_mode(SchemaMode::Overwrite) - .await - .unwrap(); - - assert_eq!(table.version(), 1); - let new_schema = table.metadata().unwrap().schema().unwrap(); - let fields = new_schema.fields(); - let names = fields.iter().map(|f| f.name()).collect::>(); - assert_eq!(names, vec!["id", "value", "inserted_by"]); + .await; + assert!(table.is_err()); } #[tokio::test] diff --git a/docs/usage/writing/index.md b/docs/usage/writing/index.md index 9c8e4f08e0..9e9e1bcbec 100644 --- a/docs/usage/writing/index.md +++ 
b/docs/usage/writing/index.md @@ -25,7 +25,7 @@ of Spark's `pyspark.sql.DataFrameWriter.saveAsTable` DataFrame method. To overwr passed to it differs from the existing table's schema. If you wish to alter the schema as part of an overwrite pass in `schema_mode="overwrite"` or `schema_mode="merge"`. `schema_mode="overwrite"` will completely overwrite the schema, even if columns are dropped; merge will append the new columns -and fill missing columns with `null`. +and fill missing columns with `null`. `schema_mode="merge"` is also supported on append operations. ## Overwriting a partition diff --git a/python/src/error.rs b/python/src/error.rs index a69160e3ec..241d0702eb 100644 --- a/python/src/error.rs +++ b/python/src/error.rs @@ -55,6 +55,7 @@ fn arrow_to_py(err: ArrowError) -> PyErr { ArrowError::DivideByZero => PyValueError::new_err("division by zero"), ArrowError::InvalidArgumentError(msg) => PyValueError::new_err(msg), ArrowError::NotYetImplemented(msg) => PyNotImplementedError::new_err(msg), + ArrowError::SchemaError(msg) => PyValueError::new_err(msg), other => PyException::new_err(other.to_string()), } } diff --git a/python/tests/test_writer.py b/python/tests/test_writer.py index 0436214abb..4e948c341b 100644 --- a/python/tests/test_writer.py +++ b/python/tests/test_writer.py @@ -188,52 +188,30 @@ def test_merge_schema(existing_table: DeltaTable): def test_overwrite_schema(existing_table: DeltaTable): - print(existing_table._table.table_uri()) - old_table_data = existing_table.to_pyarrow_table() - new_data = pa.table( + new_data_invalid = pa.table( { - "utf8": pa.array(["bla", "bli", "blubb"]), + "utf8": pa.array([1235, 546, 5645]), "new_x": pa.array([1, 2, 3], pa.int32()), "new_y": pa.array([1, 2, 3], pa.int32()), } ) - write_deltalake( - existing_table, new_data, mode="append", schema_mode="overwrite", engine="rust" - ) - # adjust schema of old_table_data and new_data to match each other - old_table_data = old_table_data.select(["utf8"]) - old_table_data = old_table_data.append_column( - pa.field("new_x", pa.int32()), pa.nulls(old_table_data.num_rows, pa.int32()) - ) - old_table_data = old_table_data.append_column( - pa.field("new_y", pa.int32()), pa.nulls(old_table_data.num_rows, pa.int32()) - ) - - # define sort order - read_data = existing_table.to_pyarrow_table().sort_by( - [("utf8", "ascending"), ("new_x", "ascending")] - ) - print(repr(read_data.to_pylist())) - concated = pa.concat_tables([old_table_data, new_data]) - print(repr(concated.to_pylist())) - assert read_data == concated - - write_deltalake(existing_table, new_data, mode="overwrite", schema_mode="overwrite") - - assert existing_table.schema().to_pyarrow() == new_data.schema - + with pytest.raises(DeltaError): + write_deltalake( + existing_table, + new_data_invalid, + mode="append", + schema_mode="overwrite", + engine="rust", + ) -def test_overwrite_schema_error(existing_table: DeltaTable): - print(existing_table._table.table_uri()) new_data = pa.table( { - "utf8": pa.array([1235, 546, 5645]), + "utf8": pa.array(["bla", "bli", "blubb"]), "new_x": pa.array([1, 2, 3], pa.int32()), "new_y": pa.array([1, 2, 3], pa.int32()), } ) - with pytest.raises(DeltaError): write_deltalake( existing_table, @@ -243,6 +221,10 @@ def test_overwrite_schema_error(existing_table: DeltaTable): engine="rust", ) + write_deltalake(existing_table, new_data, mode="overwrite", schema_mode="overwrite") + + assert existing_table.schema().to_pyarrow() == new_data.schema + def test_update_schema_rust_writer_append(existing_table: DeltaTable): with 
pytest.raises(DeltaError): @@ -254,13 +236,22 @@ def test_update_schema_rust_writer_append(existing_table: DeltaTable): schema_mode=None, engine="rust", ) - write_deltalake( - existing_table, - pa.table({"x1": pa.array([1, 2, 3])}), - mode="append", - schema_mode="overwrite", - engine="rust", - ) + with pytest.raises(DeltaError): + write_deltalake( + existing_table, + pa.table({"x1": pa.array([1, 2, 3])}), + mode="append", + schema_mode="overwrite", + engine="rust", + ) + with pytest.raises(DeltaError): + write_deltalake( + existing_table, + pa.table({"utf8": pa.array([1, 2, 3])}), + mode="append", + schema_mode="merge", + engine="rust", + ) write_deltalake( existing_table, pa.table({"x2": pa.array([1, 2, 3])}), From 6a9012bddfa4c9d4f9170e1b774e114fe5fe1d34 Mon Sep 17 00:00:00 2001 From: Adrian Ehrsam Date: Mon, 4 Mar 2024 21:32:31 +0100 Subject: [PATCH 34/40] clippy --- crates/core/src/operations/write.rs | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/crates/core/src/operations/write.rs b/crates/core/src/operations/write.rs index dc6e2aa9aa..2bbd220a17 100644 --- a/crates/core/src/operations/write.rs +++ b/crates/core/src/operations/write.rs @@ -841,17 +841,15 @@ fn try_cast_batch(from_fields: &Fields, to_fields: &Fields) -> Result<(), ArrowE (f.data_type(), target_field.data_type()) { try_cast_batch(fields0, fields1) + } else if !can_cast_types(f.data_type(), target_field.data_type()) { + Err(ArrowError::SchemaError(format!( + "Cannot cast field {} from {} to {}", + f.name(), + f.data_type(), + target_field.data_type() + ))) } else { - if !can_cast_types(f.data_type(), target_field.data_type()) { - Err(ArrowError::SchemaError(format!( - "Cannot cast field {} from {} to {}", - f.name(), - f.data_type(), - target_field.data_type() - ))) - } else { - Ok(()) - } + Ok(()) } } else { Err(ArrowError::SchemaError(format!( From 06eb8b3b2c85e2d402f06e7b7a5156080c81eda8 Mon Sep 17 00:00:00 2001 From: Adrian Ehrsam Date: Tue, 5 Mar 2024 06:23:40 +0100 Subject: [PATCH 35/40] errors --- crates/core/src/operations/write.rs | 10 ++-------- python/tests/test_writer.py | 10 +++++----- 2 files changed, 7 insertions(+), 13 deletions(-) diff --git a/crates/core/src/operations/write.rs b/crates/core/src/operations/write.rs index 2bbd220a17..2d3e7fa839 100644 --- a/crates/core/src/operations/write.rs +++ b/crates/core/src/operations/write.rs @@ -618,13 +618,7 @@ impl std::future::IntoFuture for WriteBuilder { merge_schema( table_schema.as_ref().clone(), schema.as_ref().clone(), - ) - .map_err(|e| { - DeltaTableError::Generic(format!( - "Error merging schema {:?}", - e - )) - })?, + )?, )); } else { return Err(schema_err.into()); @@ -827,7 +821,7 @@ impl std::future::IntoFuture for WriteBuilder { fn try_cast_batch(from_fields: &Fields, to_fields: &Fields) -> Result<(), ArrowError> { if from_fields.len() != to_fields.len() { return Err(ArrowError::SchemaError(format!( - "Cannot schema, number of fields does not match: {} vs {}", + "Cannot cast schema, number of fields does not match: {} vs {}", from_fields.len(), to_fields.len() ))); diff --git a/python/tests/test_writer.py b/python/tests/test_writer.py index 4e948c341b..417ccfbd18 100644 --- a/python/tests/test_writer.py +++ b/python/tests/test_writer.py @@ -124,11 +124,11 @@ def test_enforce_schema(existing_table: DeltaTable, mode: str): def test_enforce_schema_rust_writer(existing_table: DeltaTable, mode: str): bad_data = pa.table({"x": pa.array([1, 2, 3])}) - with pytest.raises(DeltaError): + with 
pytest.raises(ValueError): write_deltalake(existing_table, bad_data, mode=mode, engine="rust") table_uri = existing_table._table.table_uri() - with pytest.raises(DeltaError): + with pytest.raises(ValueError): write_deltalake(table_uri, bad_data, mode=mode, engine="rust") @@ -227,7 +227,7 @@ def test_overwrite_schema(existing_table: DeltaTable): def test_update_schema_rust_writer_append(existing_table: DeltaTable): - with pytest.raises(DeltaError): + with pytest.raises(ValueError): # It's illegal to do schema drift without correct schema_mode write_deltalake( existing_table, @@ -244,7 +244,7 @@ def test_update_schema_rust_writer_append(existing_table: DeltaTable): schema_mode="overwrite", engine="rust", ) - with pytest.raises(DeltaError): + with pytest.raises(ValueError): write_deltalake( existing_table, pa.table({"utf8": pa.array([1, 2, 3])}), @@ -263,7 +263,7 @@ def test_update_schema_rust_writer_append(existing_table: DeltaTable): def test_update_schema_rust_writer_invalid(existing_table: DeltaTable): new_data = pa.table({"x5": pa.array([1, 2, 3])}) - with pytest.raises(DeltaError): + with pytest.raises(ValueError): write_deltalake( existing_table, new_data, From 9b81041e54c46857e73873e18d1bcd9b98246702 Mon Sep 17 00:00:00 2001 From: Adrian Ehrsam Date: Tue, 5 Mar 2024 06:24:12 +0100 Subject: [PATCH 36/40] fmt --- crates/core/src/operations/write.rs | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/crates/core/src/operations/write.rs b/crates/core/src/operations/write.rs index 2d3e7fa839..2d08082ea9 100644 --- a/crates/core/src/operations/write.rs +++ b/crates/core/src/operations/write.rs @@ -614,12 +614,10 @@ impl std::future::IntoFuture for WriteBuilder { if this.mode == SaveMode::Overwrite && this.schema_mode.is_some() { new_schema = None // we overwrite anyway, so no need to cast } else if this.schema_mode == Some(SchemaMode::Merge) { - new_schema = Some(Arc::new( - merge_schema( - table_schema.as_ref().clone(), - schema.as_ref().clone(), - )?, - )); + new_schema = Some(Arc::new(merge_schema( + table_schema.as_ref().clone(), + schema.as_ref().clone(), + )?)); } else { return Err(schema_err.into()); } From e3f8b8bfe5c2a92788bd59d9f0e16fb2aacf9231 Mon Sep 17 00:00:00 2001 From: Adrian Ehrsam Date: Tue, 5 Mar 2024 06:25:30 +0100 Subject: [PATCH 37/40] unused import --- crates/core/src/operations/cast.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/crates/core/src/operations/cast.rs b/crates/core/src/operations/cast.rs index 96e4f4a849..33155dedd8 100644 --- a/crates/core/src/operations/cast.rs +++ b/crates/core/src/operations/cast.rs @@ -7,7 +7,6 @@ use arrow_schema::{ ArrowError, DataType, Field as ArrowField, Fields, Schema as ArrowSchema, SchemaRef as ArrowSchemaRef, }; -use futures::SinkExt; use std::sync::Arc; use crate::DeltaResult; From 15d4be39517eb54ac7d9363ef6d563ba3d930fdc Mon Sep 17 00:00:00 2001 From: Adrian Ehrsam Date: Tue, 5 Mar 2024 10:05:48 +0100 Subject: [PATCH 38/40] Better exception handling --- python/deltalake/exceptions.py | 1 + python/src/error.rs | 3 ++- python/src/lib.rs | 3 ++- python/tests/test_writer.py | 32 +++++++++++++++++++++++++------- 4 files changed, 30 insertions(+), 9 deletions(-) diff --git a/python/deltalake/exceptions.py b/python/deltalake/exceptions.py index bacd0af9f8..a2e5b1ba1e 100644 --- a/python/deltalake/exceptions.py +++ b/python/deltalake/exceptions.py @@ -1,4 +1,5 @@ from ._internal import CommitFailedError as CommitFailedError from ._internal import DeltaError as DeltaError from ._internal import 
DeltaProtocolError as DeltaProtocolError +from ._internal import SchemaMismatchError as SchemaMismatchError from ._internal import TableNotFoundError as TableNotFoundError diff --git a/python/src/error.rs b/python/src/error.rs index 241d0702eb..a54b1e60b4 100644 --- a/python/src/error.rs +++ b/python/src/error.rs @@ -10,6 +10,7 @@ create_exception!(_internal, DeltaError, PyException); create_exception!(_internal, TableNotFoundError, DeltaError); create_exception!(_internal, DeltaProtocolError, DeltaError); create_exception!(_internal, CommitFailedError, DeltaError); +create_exception!(_internal, SchemaMismatchError, DeltaError); fn inner_to_py_err(err: DeltaTableError) -> PyErr { match err { @@ -55,7 +56,7 @@ fn arrow_to_py(err: ArrowError) -> PyErr { ArrowError::DivideByZero => PyValueError::new_err("division by zero"), ArrowError::InvalidArgumentError(msg) => PyValueError::new_err(msg), ArrowError::NotYetImplemented(msg) => PyNotImplementedError::new_err(msg), - ArrowError::SchemaError(msg) => PyValueError::new_err(msg), + ArrowError::SchemaError(msg) => SchemaMismatchError::new_err(msg), other => PyException::new_err(other.to_string()), } } diff --git a/python/src/lib.rs b/python/src/lib.rs index 65d0ba5944..0800d42927 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -1624,7 +1624,7 @@ impl PyDeltaDataChecker { #[pymodule] // module name need to match project name fn _internal(py: Python, m: &PyModule) -> PyResult<()> { - use crate::error::{CommitFailedError, DeltaError, TableNotFoundError}; + use crate::error::{CommitFailedError, DeltaError, SchemaMismatchError, TableNotFoundError}; deltalake::aws::register_handlers(None); deltalake::azure::register_handlers(None); @@ -1634,6 +1634,7 @@ fn _internal(py: Python, m: &PyModule) -> PyResult<()> { m.add("CommitFailedError", py.get_type::())?; m.add("DeltaProtocolError", py.get_type::())?; m.add("TableNotFoundError", py.get_type::())?; + m.add("SchemaMismatchError", py.get_type::())?; env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("warn")).init(); m.add("__version__", env!("CARGO_PKG_VERSION"))?; diff --git a/python/tests/test_writer.py b/python/tests/test_writer.py index 417ccfbd18..0ee751c2d7 100644 --- a/python/tests/test_writer.py +++ b/python/tests/test_writer.py @@ -17,7 +17,12 @@ from pyarrow.lib import RecordBatchReader from deltalake import DeltaTable, Schema, write_deltalake -from deltalake.exceptions import CommitFailedError, DeltaError, DeltaProtocolError +from deltalake.exceptions import ( + CommitFailedError, + DeltaError, + DeltaProtocolError, + SchemaMismatchError, +) from deltalake.table import ProtocolVersions from deltalake.writer import try_get_table_and_table_uri @@ -124,11 +129,17 @@ def test_enforce_schema(existing_table: DeltaTable, mode: str): def test_enforce_schema_rust_writer(existing_table: DeltaTable, mode: str): bad_data = pa.table({"x": pa.array([1, 2, 3])}) - with pytest.raises(ValueError): + with pytest.raises( + SchemaMismatchError, + match=".*Cannot cast schema, number of fields does not match.*", + ): write_deltalake(existing_table, bad_data, mode=mode, engine="rust") table_uri = existing_table._table.table_uri() - with pytest.raises(ValueError): + with pytest.raises( + SchemaMismatchError, + match=".*Cannot cast schema, number of fields does not match.*", + ): write_deltalake(table_uri, bad_data, mode=mode, engine="rust") @@ -227,7 +238,9 @@ def test_overwrite_schema(existing_table: DeltaTable): def test_update_schema_rust_writer_append(existing_table: DeltaTable): - 
with pytest.raises(ValueError): + with pytest.raises( + SchemaMismatchError, match="Cannot cast schema, number of fields does not match" + ): # It's illegal to do schema drift without correct schema_mode write_deltalake( existing_table, @@ -237,14 +250,17 @@ def test_update_schema_rust_writer_append(existing_table: DeltaTable): engine="rust", ) with pytest.raises(DeltaError): - write_deltalake( + write_deltalake( # schema_mode overwrite is illegal with append existing_table, pa.table({"x1": pa.array([1, 2, 3])}), mode="append", schema_mode="overwrite", engine="rust", ) - with pytest.raises(ValueError): + with pytest.raises( + SchemaMismatchError, + match="Schema error: Fail to merge schema field 'utf8' because the from data_type = Int64 does not equal Utf8", + ): write_deltalake( existing_table, pa.table({"utf8": pa.array([1, 2, 3])}), @@ -263,7 +279,9 @@ def test_update_schema_rust_writer_append(existing_table: DeltaTable): def test_update_schema_rust_writer_invalid(existing_table: DeltaTable): new_data = pa.table({"x5": pa.array([1, 2, 3])}) - with pytest.raises(ValueError): + with pytest.raises( + SchemaMismatchError, match="Cannot cast schema, number of fields does not match" + ): write_deltalake( existing_table, new_data, From 2d3699945481feacc5ae67493977cf992e1c95cd Mon Sep 17 00:00:00 2001 From: Adrian Ehrsam Date: Tue, 5 Mar 2024 10:09:41 +0100 Subject: [PATCH 39/40] commit missing file --- python/deltalake/_internal.pyi | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/python/deltalake/_internal.pyi b/python/deltalake/_internal.pyi index 4300be52de..695d7e3322 100644 --- a/python/deltalake/_internal.pyi +++ b/python/deltalake/_internal.pyi @@ -795,6 +795,11 @@ class DeltaProtocolError(DeltaError): pass +class SchemaMismatchError(DeltaError): + """Raised when a schema mismatch is detected.""" + + pass + FilterLiteralType = Tuple[str, str, Any] FilterConjunctionType = List[FilterLiteralType] FilterDNFType = List[FilterConjunctionType] From 82a0233bda3d67bc320be37da6db4e2aca0e8af5 Mon Sep 17 00:00:00 2001 From: Adrian Ehrsam Date: Tue, 5 Mar 2024 13:22:03 +0100 Subject: [PATCH 40/40] do not use 0.9.1 of object_store for now --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index a6819cde0b..348cfd5310 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -42,7 +42,7 @@ arrow-ord = { version = "50" } arrow-row = { version = "50" } arrow-schema = { version = "50" } arrow-select = { version = "50" } -object_store = { version = "0.9" } +object_store = { version = "=0.9.0" } parquet = { version = "50" } # datafusion
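As a closing illustration (a sketch under the assumptions of this series, not shipped code), the `SchemaMismatchError` added to the Python bindings surfaces the `try_cast_batch` message exercised by the tests above when data drifts and no `schema_mode` is given:

```python
import pyarrow as pa
from deltalake import write_deltalake
from deltalake.exceptions import SchemaMismatchError

# Create a small table, then try to append data with a drifting schema while
# leaving schema_mode unset; the rust engine enforces the stored schema.
write_deltalake("tmp/strict-table", pa.table({"x": pa.array([1, 2, 3])}))

try:
    write_deltalake(
        "tmp/strict-table",
        pa.table({"x": pa.array([4]), "extra": pa.array(["?"])}),
        mode="append",
        engine="rust",
    )
except SchemaMismatchError as err:
    # e.g. "Cannot cast schema, number of fields does not match: 2 vs 1"
    print(f"rejected write: {err}")
```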