From 73d552a7cc794d0e3eaa3e5333e5bc1c98deeb45 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sun, 5 Jun 2022 02:00:44 -0700 Subject: [PATCH] Read and skip validity buffer of UnionType Array for V4 ipc message (#1789) * Read validity buffer for V4 ipc message * Add unit test * Fix clippy --- arrow-flight/src/utils.rs | 1 + arrow/src/ipc/reader.rs | 31 ++++++++++-- arrow/src/ipc/writer.rs | 48 +++++++++++++++++++ .../integration_test.rs | 1 + .../integration_test.rs | 10 +++- 5 files changed, 86 insertions(+), 5 deletions(-) diff --git a/arrow-flight/src/utils.rs b/arrow-flight/src/utils.rs index 77526917f22a..dda3fc7fe3db 100644 --- a/arrow-flight/src/utils.rs +++ b/arrow-flight/src/utils.rs @@ -71,6 +71,7 @@ pub fn flight_data_to_arrow_batch( schema, dictionaries_by_id, None, + &message.version(), ) })? } diff --git a/arrow/src/ipc/reader.rs b/arrow/src/ipc/reader.rs index 03a960c4c670..868098327092 100644 --- a/arrow/src/ipc/reader.rs +++ b/arrow/src/ipc/reader.rs @@ -52,6 +52,7 @@ fn read_buffer(buf: &ipc::Buffer, a_data: &[u8]) -> Buffer { /// - check if the bit width of non-64-bit numbers is 64, and /// - read the buffer as 64-bit (signed integer or float), and /// - cast the 64-bit array to the appropriate data type +#[allow(clippy::too_many_arguments)] fn create_array( nodes: &[ipc::FieldNode], field: &Field, @@ -60,6 +61,7 @@ fn create_array( dictionaries_by_id: &HashMap, mut node_index: usize, mut buffer_index: usize, + metadata: &ipc::MetadataVersion, ) -> Result<(ArrayRef, usize, usize)> { use DataType::*; let data_type = field.data_type(); @@ -106,6 +108,7 @@ fn create_array( dictionaries_by_id, node_index, buffer_index, + metadata, )?; node_index = triple.1; buffer_index = triple.2; @@ -128,6 +131,7 @@ fn create_array( dictionaries_by_id, node_index, buffer_index, + metadata, )?; node_index = triple.1; buffer_index = triple.2; @@ -153,6 +157,7 @@ fn create_array( dictionaries_by_id, node_index, buffer_index, + metadata, )?; node_index = triple.1; 
buffer_index = triple.2; @@ -201,6 +206,13 @@ fn create_array( let len = union_node.length() as usize; + // In V4, union types have a validity bitmap + // In V5 and later, union types have no validity bitmap + if metadata < &ipc::MetadataVersion::V5 { + read_buffer(&buffers[buffer_index], data); + buffer_index += 1; + } + let type_ids: Buffer = read_buffer(&buffers[buffer_index], data)[..len].into(); @@ -226,6 +238,7 @@ fn create_array( dictionaries_by_id, node_index, buffer_index, + metadata, )?; node_index = triple.1; @@ -582,6 +595,7 @@ pub fn read_record_batch( schema: SchemaRef, dictionaries_by_id: &HashMap, projection: Option<&[usize]>, + metadata: &ipc::MetadataVersion, ) -> Result { let buffers = batch.buffers().ok_or_else(|| { ArrowError::IoError("Unable to get buffers from IPC RecordBatch".to_string()) @@ -607,6 +621,7 @@ pub fn read_record_batch( dictionaries_by_id, node_index, buffer_index, + metadata, )?; node_index = triple.1; buffer_index = triple.2; @@ -640,6 +655,7 @@ pub fn read_record_batch( dictionaries_by_id, node_index, buffer_index, + metadata, )?; node_index = triple.1; buffer_index = triple.2; @@ -656,6 +672,7 @@ pub fn read_dictionary( batch: ipc::DictionaryBatch, schema: &Schema, dictionaries_by_id: &mut HashMap, + metadata: &ipc::MetadataVersion, ) -> Result<()> { if batch.isDelta() { return Err(ArrowError::IoError( @@ -686,6 +703,7 @@ pub fn read_dictionary( Arc::new(schema), dictionaries_by_id, None, + metadata, )?; Some(record_batch.column(0).clone()) } @@ -816,7 +834,13 @@ impl FileReader { ))?; reader.read_exact(&mut buf)?; - read_dictionary(&buf, batch, &schema, &mut dictionaries_by_id)?; + read_dictionary( + &buf, + batch, + &schema, + &mut dictionaries_by_id, + &message.version(), + )?; } t => { return Err(ArrowError::IoError(format!( @@ -925,6 +949,7 @@ impl FileReader { self.schema(), &self.dictionaries_by_id, self.projection.as_ref().map(|x| x.0.as_ref()), + &message.version() ).map(Some) } @@ -1099,7 +1124,7 @@ impl StreamReader 
{ let mut buf = vec![0; message.bodyLength() as usize]; self.reader.read_exact(&mut buf)?; - read_record_batch(&buf, batch, self.schema(), &self.dictionaries_by_id, self.projection.as_ref().map(|x| x.0.as_ref())).map(Some) + read_record_batch(&buf, batch, self.schema(), &self.dictionaries_by_id, self.projection.as_ref().map(|x| x.0.as_ref()), &message.version()).map(Some) } ipc::MessageHeader::DictionaryBatch => { let batch = message.header_as_dictionary_batch().ok_or_else(|| { @@ -1112,7 +1137,7 @@ impl StreamReader { self.reader.read_exact(&mut buf)?; read_dictionary( - &buf, batch, &self.schema, &mut self.dictionaries_by_id + &buf, batch, &self.schema, &mut self.dictionaries_by_id, &message.version() )?; // read the next message until we encounter a RecordBatch diff --git a/arrow/src/ipc/writer.rs b/arrow/src/ipc/writer.rs index c42c0fd97e7d..70e07acae98c 100644 --- a/arrow/src/ipc/writer.rs +++ b/arrow/src/ipc/writer.rs @@ -1385,4 +1385,52 @@ mod tests { // Dictionary with id 2 should have been written to the dict tracker assert!(dict_tracker.written.contains_key(&2)); } + + #[test] + fn read_union_017() { + let testdata = crate::util::test_util::arrow_test_data(); + let version = "0.17.1"; + let data_file = File::open(format!( + "{}/arrow-ipc-stream/integration/0.17.1/generated_union.stream", + testdata, + )) + .unwrap(); + + let reader = StreamReader::try_new(data_file, None).unwrap(); + + // read and rewrite the stream to a temp location + { + let file = File::create(format!( + "target/debug/testdata/{}-generated_union.stream", + version + )) + .unwrap(); + let mut writer = StreamWriter::try_new(file, &reader.schema()).unwrap(); + reader.for_each(|batch| { + writer.write(&batch.unwrap()).unwrap(); + }); + writer.finish().unwrap(); + } + + // Compare original file and rewritten file + let file = File::open(format!( + "target/debug/testdata/{}-generated_union.stream", + version + )) + .unwrap(); + let rewrite_reader = StreamReader::try_new(file, None).unwrap(); 
+ + let data_file = File::open(format!( + "{}/arrow-ipc-stream/integration/0.17.1/generated_union.stream", + testdata, + )) + .unwrap(); + let reader = StreamReader::try_new(data_file, None).unwrap(); + + reader.into_iter().zip(rewrite_reader.into_iter()).for_each( + |(batch1, batch2)| { + assert_eq!(batch1.unwrap(), batch2.unwrap()); + }, + ); + } } diff --git a/integration-testing/src/flight_client_scenarios/integration_test.rs b/integration-testing/src/flight_client_scenarios/integration_test.rs index 4158a7352140..62fe2b85d262 100644 --- a/integration-testing/src/flight_client_scenarios/integration_test.rs +++ b/integration-testing/src/flight_client_scenarios/integration_test.rs @@ -270,6 +270,7 @@ async fn receive_batch_flight_data( .expect("Error parsing dictionary"), &schema, dictionaries_by_id, + &message.version(), ) .expect("Error reading dictionary"); diff --git a/integration-testing/src/flight_server_scenarios/integration_test.rs b/integration-testing/src/flight_server_scenarios/integration_test.rs index 52086aade748..7ad3d18eb5ba 100644 --- a/integration-testing/src/flight_server_scenarios/integration_test.rs +++ b/integration-testing/src/flight_server_scenarios/integration_test.rs @@ -296,6 +296,7 @@ async fn record_batch_from_message( schema_ref, dictionaries_by_id, None, + &message.version(), ); arrow_batch_result.map_err(|e| { @@ -313,8 +314,13 @@ async fn dictionary_from_message( Status::internal("Could not parse message header as dictionary batch") })?; - let dictionary_batch_result = - reader::read_dictionary(data_body, ipc_batch, &schema_ref, dictionaries_by_id); + let dictionary_batch_result = reader::read_dictionary( + data_body, + ipc_batch, + &schema_ref, + dictionaries_by_id, + &message.version(), + ); dictionary_batch_result.map_err(|e| { Status::internal(format!("Could not convert to Dictionary: {:?}", e)) })