Skip to content

Commit

Permalink
Fix projection in IPC reader (#1736)
Browse files Browse the repository at this point in the history
* Fix projection in IPC reader

* Add test for projection of IPC reader

* Fix clippy error

* Fix typos

* Improve test for projection in IPC reader
  • Loading branch information
iyupeng authored May 26, 2022
1 parent 2ba1ef4 commit 7391710
Showing 1 changed file with 309 additions and 16 deletions.
325 changes: 309 additions & 16 deletions arrow/src/ipc/reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,120 @@ fn create_array(
Ok((array, node_index, buffer_index))
}

/// Advance `node_index` and `buffer_index` past one (possibly nested) field
/// without materializing an array.
///
/// Used by `read_record_batch` when a projection excludes a field: the IPC
/// message still contains the field's nodes and buffers, so the cursors must
/// be moved past them. The per-type advancement mirrors `create_array`.
///
/// Returns the updated `(node_index, buffer_index)` pair.
fn skip_field(
    nodes: &[ipc::FieldNode],
    field: &Field,
    data: &[u8],
    buffers: &[ipc::Buffer],
    dictionaries_by_id: &HashMap<i64, ArrayRef>,
    mut node_index: usize,
    mut buffer_index: usize,
) -> Result<(usize, usize)> {
    use DataType::*;
    match field.data_type() {
        // One node; validity + offsets + values buffers.
        Utf8 | Binary | LargeBinary | LargeUtf8 => Ok((node_index + 1, buffer_index + 3)),
        // One node; validity + values buffers.
        FixedSizeBinary(_) => Ok((node_index + 1, buffer_index + 2)),
        // Validity + offsets buffers, then the single child field.
        List(child) | LargeList(child) | Map(child, _) => skip_field(
            nodes,
            child,
            data,
            buffers,
            dictionaries_by_id,
            node_index + 1,
            buffer_index + 2,
        ),
        // Validity buffer only (no offsets), then the single child field.
        FixedSizeList(child, _) => skip_field(
            nodes,
            child,
            data,
            buffers,
            dictionaries_by_id,
            node_index + 1,
            buffer_index + 1,
        ),
        Struct(children) => {
            // Validity buffer for the struct itself, then each child in order.
            node_index += 1;
            buffer_index += 1;
            for child in children {
                let (next_node, next_buffer) = skip_field(
                    nodes,
                    child,
                    data,
                    buffers,
                    dictionaries_by_id,
                    node_index,
                    buffer_index,
                )?;
                node_index = next_node;
                buffer_index = next_buffer;
            }
            Ok((node_index, buffer_index))
        }
        // Validity + keys buffers; the dictionary values live in a separate
        // dictionary batch, so nothing further to skip here.
        Dictionary(_, _) => Ok((node_index + 1, buffer_index + 2)),
        Union(children, _type_ids, mode) => {
            node_index += 1;
            // Type-ids buffer is always present; dense unions carry an
            // additional offsets buffer.
            buffer_index += match mode {
                UnionMode::Dense => 2,
                UnionMode::Sparse => 1,
            };
            for child in children {
                let (next_node, next_buffer) = skip_field(
                    nodes,
                    child,
                    data,
                    buffers,
                    dictionaries_by_id,
                    node_index,
                    buffer_index,
                )?;
                node_index = next_node;
                buffer_index = next_buffer;
            }
            Ok((node_index, buffer_index))
        }
        // Null arrays have a node but no buffers at all.
        Null => Ok((node_index + 1, buffer_index)),
        // Default: primitive types with validity + values buffers.
        _ => Ok((node_index + 1, buffer_index + 2)),
    }
}

/// Reads the correct number of buffers based on data type and null_count, and creates a
/// primitive array ref
fn create_primitive_array(
Expand Down Expand Up @@ -493,21 +607,37 @@ pub fn read_record_batch(
let mut arrays = vec![];

if let Some(projection) = projection {
let fields = schema.fields();
for &index in projection {
let field = &fields[index];
let triple = create_array(
field_nodes,
field,
buf,
buffers,
dictionaries_by_id,
node_index,
buffer_index,
)?;
node_index = triple.1;
buffer_index = triple.2;
arrays.push(triple.0);
// project fields
for (idx, field) in schema.fields().iter().enumerate() {
// Create array for projected field
if projection.contains(&idx) {
let triple = create_array(
field_nodes,
field,
buf,
buffers,
dictionaries_by_id,
node_index,
buffer_index,
)?;
node_index = triple.1;
buffer_index = triple.2;
arrays.push(triple.0);
} else {
// Skip field.
// This must be called to advance `node_index` and `buffer_index`.
let tuple = skip_field(
field_nodes,
field,
buf,
buffers,
dictionaries_by_id,
node_index,
buffer_index,
)?;
node_index = tuple.0;
buffer_index = tuple.1;
}
}

RecordBatch::try_new(Arc::new(schema.project(projection)?), arrays)
Expand Down Expand Up @@ -1032,7 +1162,7 @@ mod tests {

use flate2::read::GzDecoder;

use crate::datatypes::{ArrowNativeType, Int8Type};
use crate::datatypes::{ArrowNativeType, Float64Type, Int32Type, Int8Type};
use crate::{datatypes, util::integration_util::*};

#[test]
Expand Down Expand Up @@ -1260,6 +1390,169 @@ mod tests {
});
}

/// Build a schema covering one field per data-type family exercised by the
/// IPC projection test (primitives, variable-length, nested, union, dict).
fn create_test_projection_schema() -> Schema {
    // Nested / parameterized field types.
    let list_type =
        DataType::List(Box::new(Field::new("item", DataType::Int32, true)));

    let fixed_size_list_type = DataType::FixedSizeList(
        Box::new(Field::new("item", DataType::Int32, false)),
        3,
    );

    let dict_type =
        DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8));

    let union_fields = vec![
        Field::new("a", DataType::Int32, false),
        Field::new("b", DataType::Float64, false),
    ];
    let union_type = DataType::Union(union_fields, vec![0, 1], UnionMode::Dense);

    let struct_type = DataType::Struct(vec![
        Field::new("id", DataType::Int32, false),
        Field::new(
            "list",
            DataType::List(Box::new(Field::new("item", DataType::Int8, true))),
            false,
        ),
    ]);

    // One field per type of interest, f0..f12.
    Schema::new(vec![
        Field::new("f0", DataType::UInt32, false),
        Field::new("f1", DataType::Utf8, false),
        Field::new("f2", DataType::Boolean, false),
        Field::new("f3", union_type, true),
        Field::new("f4", DataType::Null, true),
        Field::new("f5", DataType::Float64, true),
        Field::new("f6", list_type, false),
        Field::new("f7", DataType::FixedSizeBinary(3), true),
        Field::new("f8", fixed_size_list_type, false),
        Field::new("f9", struct_type, false),
        Field::new("f10", DataType::Boolean, false),
        Field::new("f11", dict_type, false),
        Field::new("f12", DataType::Utf8, false),
    ])
}

fn create_test_projection_batch_data(schema: &Schema) -> RecordBatch {
// set test data for each column
let array0 = UInt32Array::from(vec![1, 2, 3]);
let array1 = StringArray::from(vec!["foo", "bar", "baz"]);
let array2 = BooleanArray::from(vec![true, false, true]);

let mut union_builder = UnionBuilder::new_dense(3);
union_builder.append::<Int32Type>("a", 1).unwrap();
union_builder.append::<Float64Type>("b", 10.1).unwrap();
union_builder.append_null::<Float64Type>("b").unwrap();
let array3 = union_builder.build().unwrap();

let array4 = NullArray::new(3);
let array5 = Float64Array::from(vec![Some(1.1), None, Some(3.3)]);
let array6_values = vec![
Some(vec![Some(10), Some(10), Some(10)]),
Some(vec![Some(20), Some(20), Some(20)]),
Some(vec![Some(30), Some(30)]),
];
let array6 = ListArray::from_iter_primitive::<Int32Type, _, _>(array6_values);
let array7_values = vec![vec![11, 12, 13], vec![22, 23, 24], vec![33, 34, 35]];
let array7 =
FixedSizeBinaryArray::try_from_iter(array7_values.into_iter()).unwrap();

let array8_values = ArrayData::builder(DataType::Int32)
.len(9)
.add_buffer(Buffer::from_slice_ref(&[
40, 41, 42, 43, 44, 45, 46, 47, 48,
]))
.build()
.unwrap();
let array8_data = ArrayData::builder(schema.field(8).data_type().clone())
.len(3)
.add_child_data(array8_values)
.build()
.unwrap();
let array8 = FixedSizeListArray::from(array8_data);

let array9_id: ArrayRef = Arc::new(Int32Array::from(vec![1001, 1002, 1003]));
let array9_list: ArrayRef =
Arc::new(ListArray::from_iter_primitive::<Int8Type, _, _>(vec![
Some(vec![Some(-10)]),
Some(vec![Some(-20), Some(-20), Some(-20)]),
Some(vec![Some(-30)]),
]));
let array9 = ArrayDataBuilder::new(schema.field(9).data_type().clone())
.add_child_data(array9_id.data().clone())
.add_child_data(array9_list.data().clone())
.len(3)
.build()
.unwrap();
let array9: ArrayRef = Arc::new(StructArray::from(array9));

let array10 = BooleanArray::from(vec![false, false, true]);

let array11_values = StringArray::from(vec!["x", "yy", "zzz"]);
let array11_keys = Int8Array::from_iter_values([1, 1, 2]);
let array11 =
DictionaryArray::<Int8Type>::try_new(&array11_keys, &array11_values).unwrap();

let array12 = StringArray::from(vec!["a", "bb", "ccc"]);

// create record batch
RecordBatch::try_new(
Arc::new(schema.clone()),
vec![
Arc::new(array0),
Arc::new(array1),
Arc::new(array2),
Arc::new(array3),
Arc::new(array4),
Arc::new(array5),
Arc::new(array6),
Arc::new(array7),
Arc::new(array8),
Arc::new(array9),
Arc::new(array10),
Arc::new(array11),
Arc::new(array12),
],
)
.unwrap()
}

#[test]
fn test_projection_array_values() {
    // define schema
    let schema = create_test_projection_schema();

    // create record batch with test data
    let batch = create_test_projection_batch_data(&schema);

    // write record batch in IPC format
    let mut buf = Vec::new();
    {
        let mut writer = ipc::writer::FileWriter::try_new(&mut buf, &schema).unwrap();
        writer.write(&batch).unwrap();
        writer.finish().unwrap();
    }

    // Read the batch back once per column, projecting only that column.
    // Iterate over the actual field count rather than a hard-coded range:
    // the schema has 13 fields (f0..f12), and a `0..12` bound would silently
    // skip the last field.
    for index in 0..schema.fields().len() {
        let projection = vec![index];
        let reader =
            FileReader::try_new(std::io::Cursor::new(buf.clone()), Some(projection));
        let read_batch = reader.unwrap().next().unwrap().unwrap();
        let projected_column = read_batch.column(0);
        let expected_column = batch.column(index);

        // check the projected column equals the expected column
        assert_eq!(projected_column.as_ref(), expected_column.as_ref());
    }
}

#[test]
fn test_arrow_single_float_row() {
let schema = Schema::new(vec![
Expand Down

0 comments on commit 7391710

Please sign in to comment.