diff --git a/python/deltalake/writer.py b/python/deltalake/writer.py index 626fb1a5d9..2b4814f98b 100644 --- a/python/deltalake/writer.py +++ b/python/deltalake/writer.py @@ -21,6 +21,7 @@ ) from urllib.parse import unquote +from deltalake import Schema from deltalake.fs import DeltaStorageHandler from ._util import encode_partition_value @@ -81,7 +82,7 @@ def write_deltalake( RecordBatchReader, ], *, - schema: Optional[pa.Schema] = ..., + schema: Optional[Union[pa.Schema, Schema]] = ..., partition_by: Optional[Union[List[str], str]] = ..., filesystem: Optional[pa_fs.FileSystem] = None, mode: Literal["error", "append", "overwrite", "ignore"] = ..., @@ -115,7 +116,7 @@ def write_deltalake( RecordBatchReader, ], *, - schema: Optional[pa.Schema] = ..., + schema: Optional[Union[pa.Schema, Schema]] = ..., partition_by: Optional[Union[List[str], str]] = ..., mode: Literal["error", "append", "overwrite", "ignore"] = ..., max_rows_per_group: int = ..., @@ -142,7 +143,7 @@ def write_deltalake( RecordBatchReader, ], *, - schema: Optional[pa.Schema] = None, + schema: Optional[Union[pa.Schema, Schema]] = None, partition_by: Optional[Union[List[str], str]] = None, filesystem: Optional[pa_fs.FileSystem] = None, mode: Literal["error", "append", "overwrite", "ignore"] = "error", @@ -244,6 +245,9 @@ def write_deltalake( if isinstance(partition_by, str): partition_by = [partition_by] + if isinstance(schema, Schema): + schema = schema.to_pyarrow() + if isinstance(data, RecordBatchReader): data = convert_pyarrow_recordbatchreader(data, large_dtypes) elif isinstance(data, pa.RecordBatch): @@ -336,16 +340,19 @@ def _large_to_normal_dtype(dtype: pa.DataType) -> pa.DataType: return dtype if partition_by: + table_schema: pa.Schema = schema if PYARROW_MAJOR_VERSION < 12: partition_schema = pa.schema( [ - pa.field(name, _large_to_normal_dtype(schema.field(name).type)) + pa.field( + name, _large_to_normal_dtype(table_schema.field(name).type) + ) for name in partition_by ] ) else: partition_schema = pa.schema( - [schema.field(name) for name in partition_by] + [table_schema.field(name) for name in partition_by] ) partitioning = ds.partitioning(partition_schema, flavor="hive") else: diff --git a/python/tests/test_writer.py b/python/tests/test_writer.py index 0a63b16c70..49177782ff 100644 --- a/python/tests/test_writer.py +++ b/python/tests/test_writer.py @@ -16,7 +16,7 @@ from pyarrow.dataset import ParquetFileFormat, ParquetReadOptions from pyarrow.lib import RecordBatchReader -from deltalake import DeltaTable, write_deltalake +from deltalake import DeltaTable, Schema, write_deltalake from deltalake.exceptions import CommitFailedError, DeltaError, DeltaProtocolError from deltalake.table import ProtocolVersions from deltalake.writer import try_get_table_and_table_uri @@ -1176,3 +1176,11 @@ def test_float_values(tmp_path: pathlib.Path): assert actions["min"].field("x2")[0].as_py() is None assert actions["max"].field("x2")[0].as_py() == 1.0 assert actions["null_count"].field("x2")[0].as_py() == 1 + + +def test_with_deltalake_schema(tmp_path: pathlib.Path, sample_data: pa.Table): + write_deltalake( + tmp_path, sample_data, schema=Schema.from_pyarrow(sample_data.schema) + ) + delta_table = DeltaTable(tmp_path) + assert delta_table.schema().to_pyarrow() == sample_data.schema