From f12834e22ff01d6a6496f8e7ad0a16f3f76d4b7c Mon Sep 17 00:00:00 2001 From: Ion Koutsouris <15728914+ion-elgreco@users.noreply.github.com> Date: Mon, 22 Apr 2024 19:02:00 +0200 Subject: [PATCH] fix(python,rust): missing remove actions during `create_or_replace` (#2437) # Description The overwrite mode never added remove actions for the existing files, which left the table in an invalid state. --- crates/core/src/operations/create.rs | 60 +++++++++++++++++++++++++++- python/tests/test_create.py | 13 +++++- 2 files changed, 70 insertions(+), 3 deletions(-) diff --git a/crates/core/src/operations/create.rs b/crates/core/src/operations/create.rs index b2092cafe8..f9a7f62183 100644 --- a/crates/core/src/operations/create.rs +++ b/crates/core/src/operations/create.rs @@ -310,6 +310,7 @@ impl CreateBuilder { }; let mut actions = vec![Action::Protocol(protocol), Action::Metadata(metadata)]; + actions.extend( self.actions .into_iter() @@ -329,7 +330,7 @@ impl std::future::IntoFuture for CreateBuilder { Box::pin(async move { let mode = this.mode; let app_metadata = this.metadata.clone().unwrap_or_default(); - let (mut table, actions, operation) = this.into_table_and_actions()?; + let (mut table, mut actions, operation) = this.into_table_and_actions()?; let log_store = table.log_store(); let table_state = if log_store.is_delta_table_location().await? { @@ -342,6 +343,12 @@ impl std::future::IntoFuture for CreateBuilder { } SaveMode::Overwrite => { table.load().await?; + let remove_actions = table + .snapshot()? + .log_data() + .into_iter() + .map(|p| p.remove_action(true).into()); + actions.extend(remove_actions); Some(table.snapshot()?) 
} } @@ -371,7 +378,7 @@ mod tests { use super::*; use crate::operations::DeltaOps; use crate::table::config::DeltaConfigKey; - use crate::writer::test_utils::get_delta_schema; + use crate::writer::test_utils::{get_delta_schema, get_record_batch}; use tempfile::TempDir; #[tokio::test] @@ -518,4 +525,53 @@ mod tests { .unwrap(); assert_ne!(table.metadata().unwrap().id, first_id) } + + #[tokio::test] + async fn test_create_or_replace_existing_table() { + let batch = get_record_batch(None, false); + let schema = get_delta_schema(); + let table = DeltaOps::new_in_memory() + .write(vec![batch.clone()]) + .with_save_mode(SaveMode::ErrorIfExists) + .await + .unwrap(); + assert_eq!(table.version(), 0); + assert_eq!(table.get_files_count(), 1); + + let mut table = DeltaOps(table) + .create() + .with_columns(schema.fields().iter().cloned()) + .with_save_mode(SaveMode::Overwrite) + .await + .unwrap(); + table.load().await.unwrap(); + assert_eq!(table.version(), 1); + /// Checks if files got removed after overwrite + assert_eq!(table.get_files_count(), 0); + } + + #[tokio::test] + async fn test_create_or_replace_existing_table_partitioned() { + let batch = get_record_batch(None, false); + let schema = get_delta_schema(); + let table = DeltaOps::new_in_memory() + .write(vec![batch.clone()]) + .with_save_mode(SaveMode::ErrorIfExists) + .await + .unwrap(); + assert_eq!(table.version(), 0); + assert_eq!(table.get_files_count(), 1); + + let mut table = DeltaOps(table) + .create() + .with_columns(schema.fields().iter().cloned()) + .with_save_mode(SaveMode::Overwrite) + .with_partition_columns(vec!["id"]) + .await + .unwrap(); + table.load().await.unwrap(); + assert_eq!(table.version(), 1); + /// Checks if files got removed after overwrite + assert_eq!(table.get_files_count(), 0); + } } diff --git a/python/tests/test_create.py b/python/tests/test_create.py index 3852fc2bab..ceca8178c3 100644 --- a/python/tests/test_create.py +++ b/python/tests/test_create.py @@ -3,7 +3,7 @@ import 
pyarrow as pa import pytest -from deltalake import DeltaTable +from deltalake import DeltaTable, write_deltalake from deltalake.exceptions import DeltaError @@ -54,3 +54,14 @@ def test_create_schema(tmp_path: pathlib.Path, sample_data: pa.Table): ) assert dt.schema().to_pyarrow() == sample_data.schema + + +def test_create_or_replace_existing_table( + tmp_path: pathlib.Path, sample_data: pa.Table +): + write_deltalake(table_or_uri=tmp_path, data=sample_data) + dt = DeltaTable.create( + tmp_path, sample_data.schema, partition_by=["utf8"], mode="overwrite" + ) + + assert dt.files() == []