Skip to content

Commit

Permalink
Read and skip validity buffer of UnionType Array for V4 ipc message (#…
Browse files Browse the repository at this point in the history
…1789)

* Read validity buffer for V4 ipc message

* Add unit test

* Fix clippy
  • Loading branch information
viirya authored Jun 5, 2022
1 parent 940b5b5 commit 73d552a
Show file tree
Hide file tree
Showing 5 changed files with 86 additions and 5 deletions.
1 change: 1 addition & 0 deletions arrow-flight/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ pub fn flight_data_to_arrow_batch(
schema,
dictionaries_by_id,
None,
&message.version(),
)
})?
}
Expand Down
31 changes: 28 additions & 3 deletions arrow/src/ipc/reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ fn read_buffer(buf: &ipc::Buffer, a_data: &[u8]) -> Buffer {
/// - check if the bit width of non-64-bit numbers is 64, and
/// - read the buffer as 64-bit (signed integer or float), and
/// - cast the 64-bit array to the appropriate data type
#[allow(clippy::too_many_arguments)]
fn create_array(
nodes: &[ipc::FieldNode],
field: &Field,
Expand All @@ -60,6 +61,7 @@ fn create_array(
dictionaries_by_id: &HashMap<i64, ArrayRef>,
mut node_index: usize,
mut buffer_index: usize,
metadata: &ipc::MetadataVersion,
) -> Result<(ArrayRef, usize, usize)> {
use DataType::*;
let data_type = field.data_type();
Expand Down Expand Up @@ -106,6 +108,7 @@ fn create_array(
dictionaries_by_id,
node_index,
buffer_index,
metadata,
)?;
node_index = triple.1;
buffer_index = triple.2;
Expand All @@ -128,6 +131,7 @@ fn create_array(
dictionaries_by_id,
node_index,
buffer_index,
metadata,
)?;
node_index = triple.1;
buffer_index = triple.2;
Expand All @@ -153,6 +157,7 @@ fn create_array(
dictionaries_by_id,
node_index,
buffer_index,
metadata,
)?;
node_index = triple.1;
buffer_index = triple.2;
Expand Down Expand Up @@ -201,6 +206,13 @@ fn create_array(

let len = union_node.length() as usize;

// In V4, union types have a validity bitmap
// In V5 and later, union types have no validity bitmap
if metadata < &ipc::MetadataVersion::V5 {
read_buffer(&buffers[buffer_index], data);
buffer_index += 1;
}

let type_ids: Buffer =
read_buffer(&buffers[buffer_index], data)[..len].into();

Expand All @@ -226,6 +238,7 @@ fn create_array(
dictionaries_by_id,
node_index,
buffer_index,
metadata,
)?;

node_index = triple.1;
Expand Down Expand Up @@ -582,6 +595,7 @@ pub fn read_record_batch(
schema: SchemaRef,
dictionaries_by_id: &HashMap<i64, ArrayRef>,
projection: Option<&[usize]>,
metadata: &ipc::MetadataVersion,
) -> Result<RecordBatch> {
let buffers = batch.buffers().ok_or_else(|| {
ArrowError::IoError("Unable to get buffers from IPC RecordBatch".to_string())
Expand All @@ -607,6 +621,7 @@ pub fn read_record_batch(
dictionaries_by_id,
node_index,
buffer_index,
metadata,
)?;
node_index = triple.1;
buffer_index = triple.2;
Expand Down Expand Up @@ -640,6 +655,7 @@ pub fn read_record_batch(
dictionaries_by_id,
node_index,
buffer_index,
metadata,
)?;
node_index = triple.1;
buffer_index = triple.2;
Expand All @@ -656,6 +672,7 @@ pub fn read_dictionary(
batch: ipc::DictionaryBatch,
schema: &Schema,
dictionaries_by_id: &mut HashMap<i64, ArrayRef>,
metadata: &ipc::MetadataVersion,
) -> Result<()> {
if batch.isDelta() {
return Err(ArrowError::IoError(
Expand Down Expand Up @@ -686,6 +703,7 @@ pub fn read_dictionary(
Arc::new(schema),
dictionaries_by_id,
None,
metadata,
)?;
Some(record_batch.column(0).clone())
}
Expand Down Expand Up @@ -816,7 +834,13 @@ impl<R: Read + Seek> FileReader<R> {
))?;
reader.read_exact(&mut buf)?;

read_dictionary(&buf, batch, &schema, &mut dictionaries_by_id)?;
read_dictionary(
&buf,
batch,
&schema,
&mut dictionaries_by_id,
&message.version(),
)?;
}
t => {
return Err(ArrowError::IoError(format!(
Expand Down Expand Up @@ -925,6 +949,7 @@ impl<R: Read + Seek> FileReader<R> {
self.schema(),
&self.dictionaries_by_id,
self.projection.as_ref().map(|x| x.0.as_ref()),
&message.version()

).map(Some)
}
Expand Down Expand Up @@ -1099,7 +1124,7 @@ impl<R: Read> StreamReader<R> {
let mut buf = vec![0; message.bodyLength() as usize];
self.reader.read_exact(&mut buf)?;

read_record_batch(&buf, batch, self.schema(), &self.dictionaries_by_id, self.projection.as_ref().map(|x| x.0.as_ref())).map(Some)
read_record_batch(&buf, batch, self.schema(), &self.dictionaries_by_id, self.projection.as_ref().map(|x| x.0.as_ref()), &message.version()).map(Some)
}
ipc::MessageHeader::DictionaryBatch => {
let batch = message.header_as_dictionary_batch().ok_or_else(|| {
Expand All @@ -1112,7 +1137,7 @@ impl<R: Read> StreamReader<R> {
self.reader.read_exact(&mut buf)?;

read_dictionary(
&buf, batch, &self.schema, &mut self.dictionaries_by_id
&buf, batch, &self.schema, &mut self.dictionaries_by_id, &message.version()
)?;

// read the next message until we encounter a RecordBatch
Expand Down
48 changes: 48 additions & 0 deletions arrow/src/ipc/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1385,4 +1385,52 @@ mod tests {
// Dictionary with id 2 should have been written to the dict tracker
assert!(dict_tracker.written.contains_key(&2));
}

/// Round-trips the Arrow 0.17.1 `generated_union` integration stream:
/// reads it with `StreamReader`, rewrites it with `StreamWriter`, then checks
/// that re-reading the rewritten stream yields the same record batches as the
/// original file.
///
/// Panics (failing the test) if the test data file is missing, if any
/// read/write step returns an error, or if the two streams differ in batch
/// count or batch contents.
#[test]
fn read_union_017() {
    let testdata = crate::util::test_util::arrow_test_data();
    let version = "0.17.1";

    // Build both paths once so the input and output always refer to the same
    // `version`.
    let source_path = format!(
        "{}/arrow-ipc-stream/integration/{}/generated_union.stream",
        testdata, version
    );
    let output_dir = "target/debug/testdata";
    let rewritten_path = format!("{}/{}-generated_union.stream", output_dir, version);

    // Make sure the scratch directory exists; this is a no-op when it already does.
    std::fs::create_dir_all(output_dir).unwrap();

    // Read and rewrite the stream to a temp location. The inner scope drops
    // the writer (and its file handle) before the file is reopened below.
    {
        let data_file = File::open(&source_path).unwrap();
        let reader = StreamReader::try_new(data_file, None).unwrap();

        let file = File::create(&rewritten_path).unwrap();
        let mut writer = StreamWriter::try_new(file, &reader.schema()).unwrap();
        reader.for_each(|batch| {
            writer.write(&batch.unwrap()).unwrap();
        });
        writer.finish().unwrap();
    }

    // Collect every batch from the original file and from the rewritten file.
    let original_batches: Vec<_> =
        StreamReader::try_new(File::open(&source_path).unwrap(), None)
            .unwrap()
            .map(|batch| batch.unwrap())
            .collect();
    let rewritten_batches: Vec<_> =
        StreamReader::try_new(File::open(&rewritten_path).unwrap(), None)
            .unwrap()
            .map(|batch| batch.unwrap())
            .collect();

    // `zip` silently stops at the shorter side, so compare the batch counts
    // explicitly before comparing contents pairwise.
    assert_eq!(
        original_batches.len(),
        rewritten_batches.len(),
        "original and rewritten streams contain a different number of batches"
    );
    original_batches
        .iter()
        .zip(rewritten_batches.iter())
        .for_each(|(batch1, batch2)| {
            assert_eq!(batch1, batch2);
        });
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,7 @@ async fn receive_batch_flight_data(
.expect("Error parsing dictionary"),
&schema,
dictionaries_by_id,
&message.version(),
)
.expect("Error reading dictionary");

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,7 @@ async fn record_batch_from_message(
schema_ref,
dictionaries_by_id,
None,
&message.version(),
);

arrow_batch_result.map_err(|e| {
Expand All @@ -313,8 +314,13 @@ async fn dictionary_from_message(
Status::internal("Could not parse message header as dictionary batch")
})?;

let dictionary_batch_result =
reader::read_dictionary(data_body, ipc_batch, &schema_ref, dictionaries_by_id);
let dictionary_batch_result = reader::read_dictionary(
data_body,
ipc_batch,
&schema_ref,
dictionaries_by_id,
&message.version(),
);
dictionary_batch_result.map_err(|e| {
Status::internal(format!("Could not convert to Dictionary: {:?}", e))
})
Expand Down

0 comments on commit 73d552a

Please sign in to comment.