From adf43f059687fa8dc04bba472b97c66128594dc0 Mon Sep 17 00:00:00 2001
From: ion-elgreco <15728914+ion-elgreco@users.noreply.github.com>
Date: Sat, 9 Mar 2024 17:48:01 +0100
Subject: [PATCH] prevent reading from returning a large arrow schema

---
 .../core/src/operations/transaction/state.rs | 14 ++++++---
 python/tests/test_delete.py                  | 24 ++++++++++++++
 python/tests/test_update.py                  | 31 +++++++++++++++++++
 python/tests/test_writer.py                  |  7 ++++-
 4 files changed, 70 insertions(+), 6 deletions(-)

diff --git a/crates/core/src/operations/transaction/state.rs b/crates/core/src/operations/transaction/state.rs
index 993d60f0eb..8f21018364 100644
--- a/crates/core/src/operations/transaction/state.rs
+++ b/crates/core/src/operations/transaction/state.rs
@@ -11,6 +11,7 @@ use datafusion_common::Column;
 use datafusion_expr::Expr;
 use itertools::Itertools;
 use object_store::ObjectStore;
+use parquet::arrow::arrow_reader::ArrowReaderOptions;
 use parquet::arrow::async_reader::{ParquetObjectReader, ParquetRecordBatchStreamBuilder};
 
 use crate::delta_datafusion::{
@@ -37,11 +38,14 @@ impl DeltaTableState {
         {
             let file_meta = add.try_into()?;
             let file_reader = ParquetObjectReader::new(object_store, file_meta);
-            let file_schema = ParquetRecordBatchStreamBuilder::new(file_reader)
-                .await?
-                .build()?
-                .schema()
-                .clone();
+            let file_schema = ParquetRecordBatchStreamBuilder::new_with_options(
+                file_reader,
+                ArrowReaderOptions::new().with_skip_arrow_metadata(true),
+            )
+            .await?
+            .build()?
+            .schema()
+            .clone();
 
             let table_schema = Arc::new(ArrowSchema::new(
                 self.arrow_schema()?
diff --git a/python/tests/test_delete.py b/python/tests/test_delete.py
index 4d3983a532..519af0c935 100644
--- a/python/tests/test_delete.py
+++ b/python/tests/test_delete.py
@@ -2,6 +2,7 @@
 
 import pyarrow as pa
 import pyarrow.compute as pc
+import pytest
 
 from deltalake.table import DeltaTable
 from deltalake.writer import write_deltalake
@@ -57,3 +58,26 @@ def test_delete_some_rows(existing_table: DeltaTable):
 
     table = existing_table.to_pyarrow_table()
     assert table.equals(expected_table)
+
+
+@pytest.mark.parametrize("engine", ["pyarrow", "rust"])
+def test_delete_large_dtypes(
+    tmp_path: pathlib.Path, sample_table: pa.Table, engine: str
+):
+    write_deltalake(tmp_path, sample_table, large_dtypes=True, engine=engine)  # type: ignore
+
+    dt = DeltaTable(tmp_path)
+    old_version = dt.version()
+
+    existing = dt.to_pyarrow_table()
+    mask = pc.invert(pc.is_in(existing["id"], pa.array(["1"])))
+    expected_table = existing.filter(mask)
+
+    dt.delete(predicate="id = '1'")
+
+    last_action = dt.history(1)[0]
+    assert last_action["operation"] == "DELETE"
+    assert dt.version() == old_version + 1
+
+    table = dt.to_pyarrow_table()
+    assert table.equals(expected_table)
diff --git a/python/tests/test_update.py b/python/tests/test_update.py
index c14f374937..fcc17cf027 100644
--- a/python/tests/test_update.py
+++ b/python/tests/test_update.py
@@ -52,6 +52,37 @@ def test_update_with_predicate(tmp_path: pathlib.Path, sample_table: pa.Table):
     assert result == expected
 
 
+def test_update_with_predicate_large_dtypes(
+    tmp_path: pathlib.Path, sample_table: pa.Table
+):
+    write_deltalake(tmp_path, sample_table, mode="append", large_dtypes=True)
+
+    dt = DeltaTable(tmp_path)
+
+    nrows = 5
+    expected = pa.table(
+        {
+            "id": pa.array(["1", "2", "3", "4", "5"]),
+            "price": pa.array(list(range(nrows)), pa.int64()),
+            "sold": pa.array(list(range(nrows)), pa.int64()),
+            "price_float": pa.array(list(range(nrows)), pa.float64()),
+            "items_in_bucket": pa.array([["item1", "item2", "item3"]] * nrows),
+            "deleted": pa.array([True, False, False, False, False]),
+        }
+    )
+
+    dt.update(
+        updates={"deleted": "True"},
+        predicate="id = '1'",
+    )
+
+    result = dt.to_pyarrow_table()
+    last_action = dt.history(1)[0]
+
+    assert last_action["operation"] == "UPDATE"
+    assert result == expected
+
+
 def test_update_wo_predicate(tmp_path: pathlib.Path, sample_table: pa.Table):
     write_deltalake(tmp_path, sample_table, mode="append")
 
diff --git a/python/tests/test_writer.py b/python/tests/test_writer.py
index 9d3dbeb77f..42f0cd825e 100644
--- a/python/tests/test_writer.py
+++ b/python/tests/test_writer.py
@@ -994,6 +994,7 @@ def test_partition_overwrite_unfiltered_data_fails(
     )
 
 
+@pytest.mark.parametrize("large_dtypes", [True, False])
 @pytest.mark.parametrize(
     "value_1,value_2,value_type,filter_string",
     [
@@ -1008,6 +1009,7 @@ def test_replace_where_overwrite(
     value_2: Any,
     value_type: pa.DataType,
     filter_string: str,
+    large_dtypes: bool,
 ):
     table_path = tmp_path
 
@@ -1018,7 +1020,9 @@ def test_replace_where_overwrite(
             "val": pa.array([1, 1, 1, 1], pa.int64()),
         }
     )
-    write_deltalake(table_path, sample_data, mode="overwrite")
+    write_deltalake(
+        table_path, sample_data, mode="overwrite", large_dtypes=large_dtypes
+    )
 
     delta_table = DeltaTable(table_path)
     assert (
@@ -1049,6 +1053,7 @@ def test_replace_where_overwrite(
         mode="overwrite",
         predicate="p1 = '1'",
         engine="rust",
+        large_dtypes=large_dtypes,
    )
 
     delta_table.update_incremental()