From 5a86db3c47d692584f694701780b0d8944a5e984 Mon Sep 17 00:00:00 2001 From: Frederic Branczyk Date: Fri, 15 Nov 2024 14:44:48 -0700 Subject: [PATCH] File writer preserve dict bug (#6711) * arrow-ipc: Add failing test for IPC file writer not preserving dict ID * arrow-ipc: Fix footer schema in IPC file --- arrow-ipc/src/reader.rs | 37 ++++++++++++++++++++++++++++++++++++- arrow-ipc/src/writer.rs | 5 ++++- 2 files changed, 40 insertions(+), 2 deletions(-) diff --git a/arrow-ipc/src/reader.rs b/arrow-ipc/src/reader.rs index 0820e3590827..dcded32882fc 100644 --- a/arrow-ipc/src/reader.rs +++ b/arrow-ipc/src/reader.rs @@ -1395,7 +1395,7 @@ impl RecordBatchReader for StreamReader { #[cfg(test)] mod tests { - use crate::writer::{unslice_run_array, DictionaryTracker, IpcDataGenerator}; + use crate::writer::{unslice_run_array, DictionaryTracker, IpcDataGenerator, IpcWriteOptions}; use super::*; @@ -1702,6 +1702,41 @@ mod tests { assert_eq!(batch, roundtrip_ipc(&batch)); } + #[test] + fn test_roundtrip_nested_dict_no_preserve_dict_id() { + let inner: DictionaryArray = vec!["a", "b", "a"].into_iter().collect(); + + let array = Arc::new(inner) as ArrayRef; + + let dctfield = Arc::new(Field::new("dict", array.data_type().clone(), false)); + + let s = StructArray::from(vec![(dctfield, array)]); + let struct_array = Arc::new(s) as ArrayRef; + + let schema = Arc::new(Schema::new(vec![Field::new( + "struct", + struct_array.data_type().clone(), + false, + )])); + + let batch = RecordBatch::try_new(schema, vec![struct_array]).unwrap(); + + let mut buf = Vec::new(); + let mut writer = crate::writer::FileWriter::try_new_with_options( + &mut buf, + batch.schema_ref(), + IpcWriteOptions::default().with_preserve_dict_id(false), + ) + .unwrap(); + writer.write(&batch).unwrap(); + writer.finish().unwrap(); + drop(writer); + + let mut reader = FileReader::try_new(std::io::Cursor::new(buf), None).unwrap(); + + assert_eq!(batch, reader.next().unwrap().unwrap()); + } + fn check_union_with_builder(mut builder: UnionBuilder) { builder.append::("a", 1).unwrap(); builder.append_null::("a").unwrap(); diff --git a/arrow-ipc/src/writer.rs b/arrow-ipc/src/writer.rs index b5c4dd95ed9f..e6fc9d81df67 100644 --- a/arrow-ipc/src/writer.rs +++ b/arrow-ipc/src/writer.rs @@ -1012,8 +1012,11 @@ impl FileWriter { let mut fbb = FlatBufferBuilder::new(); let dictionaries = fbb.create_vector(&self.dictionary_blocks); let record_batches = fbb.create_vector(&self.record_blocks); + let preserve_dict_id = self.write_options.preserve_dict_id; + let mut dictionary_tracker = + DictionaryTracker::new_with_preserve_dict_id(true, preserve_dict_id); let schema = IpcSchemaEncoder::new() - .with_dictionary_tracker(&mut self.dictionary_tracker) + .with_dictionary_tracker(&mut dictionary_tracker) .schema_to_fb_offset(&mut fbb, &self.schema); let fb_custom_metadata = (!self.custom_metadata.is_empty()) .then(|| crate::convert::metadata_to_fb(&mut fbb, &self.custom_metadata));