Ensure Parquet schema metadata is added to arrow table (#137)
* Add parquet schema metadata to arrow table

* Better findings

* fix exporting metadata

* Add tests for metadata preservation via ffi

* smaller diff

* Update parquet schema metadata

* Update parquet test

* update comment

* ensure valid with column projection
kylebarron authored Aug 15, 2024
1 parent e0088bc commit fcdf5b8
Showing 7 changed files with 88 additions and 7 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -1,3 +1,4 @@
*.parquet
*.whl

# Generated by Cargo
4 changes: 3 additions & 1 deletion arro3-io/python/arro3/io/_io.pyi
@@ -255,7 +255,7 @@ def read_parquet(file: Path | str) -> core.RecordBatchReader:

def write_parquet(
data: types.ArrowStreamExportable | types.ArrowArrayExportable,
-file: IO[bytes] | Path | str,
+file: str,
*,
bloom_filter_enabled: bool | None = None,
bloom_filter_fpp: float | None = None,
@@ -274,6 +274,7 @@ def write_parquet(
key_value_metadata: dict[str, str] | None = None,
max_row_group_size: int | None = None,
max_statistics_size: int | None = None,
skip_arrow_metadata: bool = False,
write_batch_size: int | None = None,
writer_version: Literal["parquet_1_0", "parquet_2_0"] | None = None,
) -> None:
@@ -338,6 +339,7 @@ def write_parquet(
key_value_metadata: Sets "key_value_metadata" property (defaults to `None`).
max_row_group_size: Sets maximum number of rows in a row group (defaults to `1024 * 1024`).
max_statistics_size: Sets default max statistics size for all columns (defaults to `4096`).
skip_arrow_metadata: Parquet files generated by this writer contain an embedded Arrow schema by default. Set `skip_arrow_metadata` to `True` to skip encoding the embedded metadata (defaults to `False`).
write_batch_size: Sets write batch size (defaults to `1024`).
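A minimal usage sketch of the new flag (not part of the diff): by default the writer embeds the Arrow schema under the `ARROW:schema` key, and `skip_arrow_metadata=True` omits it. The file names below are illustrative.

import pyarrow as pa
import pyarrow.parquet as pq
from arro3.io import write_parquet

table = pa.table({"a": [1, 2, 3]})

# Default: the Arrow schema is embedded under the "ARROW:schema" key.
write_parquet(table, "with_schema.parquet")
assert b"ARROW:schema" in pq.read_metadata("with_schema.parquet").metadata

# skip_arrow_metadata=True omits the embedded schema; explicit
# key_value_metadata is still written.
write_parquet(
    table,
    "without_schema.parquet",
    skip_arrow_metadata=True,
    key_value_metadata={"hello": "world"},
)
meta = pq.read_metadata("without_schema.parquet").metadata
assert b"ARROW:schema" not in meta
assert meta[b"hello"] == b"world"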
31 changes: 27 additions & 4 deletions arro3-io/src/parquet.rs
@@ -1,6 +1,9 @@
use std::collections::HashMap;
use std::fs::File;
use std::str::FromStr;
use std::sync::Arc;

use arrow_array::{RecordBatchIterator, RecordBatchReader};
use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder;
use parquet::arrow::arrow_writer::ArrowWriterOptions;
use parquet::arrow::ArrowWriter;
@@ -14,16 +17,30 @@ use pyo3_arrow::error::PyArrowResult;
use pyo3_arrow::input::AnyRecordBatch;
use pyo3_arrow::PyRecordBatchReader;

-use crate::utils::{FileReader, FileWriter};
+use crate::utils::FileReader;

#[pyfunction]
pub fn read_parquet(py: Python, file: FileReader) -> PyArrowResult<PyObject> {
match file {
FileReader::File(f) => {
let builder = ParquetRecordBatchReaderBuilder::try_new(f).unwrap();

let metadata = builder.schema().metadata().clone();
let reader = builder.build().unwrap();
-Ok(PyRecordBatchReader::new(Box::new(reader)).to_arro3(py)?)

// Add source schema metadata onto reader's schema. The original schema is not valid
// with a given column projection, but we want to persist the source's metadata.
let arrow_schema = Arc::new(reader.schema().as_ref().clone().with_metadata(metadata));

// Create a new iterator with the arrow schema specifically
//
// Passing ParquetRecordBatchReader directly to PyRecordBatchReader::new loses schema
// metadata
//
// https://docs.rs/parquet/latest/parquet/arrow/arrow_reader/struct.ParquetRecordBatchReader.html#method.schema
// https://github.com/apache/arrow-rs/pull/5135
let iter = Box::new(RecordBatchIterator::new(reader, arrow_schema));
Ok(PyRecordBatchReader::new(iter).to_arro3(py)?)
}
FileReader::FileLike(_) => {
Err(PyTypeError::new_err("File objects not yet supported for reading parquet").into())
@@ -105,13 +122,14 @@ impl<'py> FromPyObject<'py> for PyColumnPath {
key_value_metadata = None,
max_row_group_size = None,
max_statistics_size = None,
skip_arrow_metadata = false,
write_batch_size = None,
writer_version = None,
))]
#[allow(clippy::too_many_arguments)]
pub(crate) fn write_parquet(
data: AnyRecordBatch,
-file: FileWriter,
+file: String,
bloom_filter_enabled: Option<bool>,
bloom_filter_fpp: Option<f64>,
bloom_filter_ndv: Option<u64>,
@@ -129,9 +147,12 @@ pub(crate) fn write_parquet(
key_value_metadata: Option<HashMap<String, String>>,
max_row_group_size: Option<usize>,
max_statistics_size: Option<usize>,
skip_arrow_metadata: bool,
write_batch_size: Option<usize>,
writer_version: Option<PyWriterVersion>,
) -> PyArrowResult<()> {
let file = File::create(file).map_err(|err| PyValueError::new_err(err.to_string()))?;

let mut props = WriterProperties::builder();

if let Some(writer_version) = writer_version {
@@ -207,7 +228,9 @@ pub(crate) fn write_parquet(

let reader = data.into_reader()?;

-let writer_options = ArrowWriterOptions::new().with_properties(props.build());
+let writer_options = ArrowWriterOptions::new()
+    .with_properties(props.build())
+    .with_skip_arrow_metadata(skip_arrow_metadata);
let mut writer =
ArrowWriter::try_new_with_options(file, reader.schema(), writer_options).unwrap();
for batch in reader {
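Taken together with the pyo3-arrow changes below, the intended net effect seen from Python is roughly the following sketch (assuming the default write path embeds the Arrow schema; the file name is illustrative):

import pyarrow as pa
from arro3.io import read_parquet, write_parquet

table = pa.table({"a": [1, 2, 3]}).replace_schema_metadata({b"hello": b"world"})
write_parquet(table, "roundtrip.parquet")

# The reader returned by read_parquet re-attaches the source schema metadata,
# which the underlying ParquetRecordBatchReader would otherwise drop.
reader = read_parquet("roundtrip.parquet")
assert reader.schema.metadata[b"hello"] == b"world"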
4 changes: 3 additions & 1 deletion pyo3-arrow/src/record_batch_reader.rs
@@ -139,7 +139,9 @@ impl PyRecordBatchReader {
});
let array_reader = Box::new(ArrayIterator::new(
array_reader,
Field::new_struct("", schema.fields().clone(), false).into(),
Field::new_struct("", schema.fields().clone(), false)
.with_metadata(schema.metadata.clone())
.into(),
));
to_stream_pycapsule(py, array_reader, requested_schema)
}
4 changes: 3 additions & 1 deletion pyo3-arrow/src/table.rs
@@ -209,7 +209,9 @@ impl PyTable {
});
let array_reader = Box::new(ArrayIterator::new(
array_reader,
Field::new_struct("", field, false).into(),
Field::new_struct("", field, false)
.with_metadata(self.schema.metadata.clone())
.into(),
));
to_stream_pycapsule(py, array_reader, requested_schema)
}
25 changes: 25 additions & 0 deletions tests/core/test_ffi.py
@@ -47,3 +47,28 @@ def test_array_export_schema_request():

retour = Array.from_arrow_pycapsule(*capsules)
assert retour.type == DataType.large_utf8()


def test_table_metadata_preserved():
metadata = {b"hello": b"world"}
pa_table = pa.table({"a": [1, 2, 3]})
pa_table = pa_table.replace_schema_metadata(metadata)

arro3_table = Table(pa_table)
assert arro3_table.schema.metadata == metadata

pa_table_retour = pa.table(arro3_table)
assert pa_table_retour.schema.metadata == metadata


def test_record_batch_reader_metadata_preserved():
metadata = {b"hello": b"world"}
pa_table = pa.table({"a": [1, 2, 3]})
pa_table = pa_table.replace_schema_metadata(metadata)
pa_reader = pa.RecordBatchReader.from_stream(pa_table)

arro3_reader = RecordBatchReader.from_stream(pa_reader)
assert arro3_reader.schema.metadata == metadata

pa_reader_retour = pa.RecordBatchReader.from_stream(arro3_reader)
assert pa_reader_retour.schema.metadata == metadata
26 changes: 26 additions & 0 deletions tests/io/test_parquet.py
@@ -0,0 +1,26 @@
import pyarrow as pa
import pyarrow.parquet as pq
from arro3.io import read_parquet, write_parquet


def test_copy_parquet_kv_metadata():
metadata = {"hello": "world"}
table = pa.table({"a": [1, 2, 3]})
write_parquet(
table,
"test.parquet",
key_value_metadata=metadata,
skip_arrow_metadata=True,
)

# Assert metadata was written, but arrow schema was not
pq_meta = pq.read_metadata("test.parquet").metadata
assert pq_meta[b"hello"] == b"world"
assert b"ARROW:schema" not in pq_meta.keys()

# When reading with pyarrow, kv meta gets assigned to table
pa_table = pq.read_table("test.parquet")
assert pa_table.schema.metadata[b"hello"] == b"world"

reader = read_parquet("test.parquet")
assert reader.schema.metadata[b"hello"] == b"world"
