-
Notifications
You must be signed in to change notification settings - Fork 866
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Specialize filter for structs and sparse unions #6304
Changes from 4 commits
176c785
38ea4ae
568c5e3
f8c5b1a
d643a43
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -166,10 +166,29 @@ pub fn prep_null_mask_filter(filter: &BooleanArray) -> BooleanArray { | |
/// assert_eq!(c, &Int32Array::from(vec![5, 8])); | ||
/// ``` | ||
pub fn filter(values: &dyn Array, predicate: &BooleanArray) -> Result<ArrayRef, ArrowError> { | ||
let predicate = FilterBuilder::new(predicate).build(); | ||
let mut filter_builder = FilterBuilder::new(predicate); | ||
|
||
if multiple_arrays(values.data_type()) { | ||
// Only optimize if filtering more than one array | ||
// Otherwise, the overhead of optimization can be more than the benefit | ||
filter_builder = filter_builder.optimize(); | ||
Review comment: What happens if you call `optimize` unconditionally? Reply: I think the rationale is that "optimizing" the filter requires looking at the `BooleanArray`, which itself takes non-trivial time, so for certain operations the overhead of figuring out a better filter strategy is larger than the cost of simply running the filter. This is basically the same algorithm used in `filter_record_batch`. Users who want to always optimize can use a `FilterBuilder` directly. I made a PR to try to clarify this in the docs: #6317 |
||
} | ||
|
||
let predicate = filter_builder.build(); | ||
|
||
filter_array(values, &predicate) | ||
} | ||
|
||
fn multiple_arrays(data_type: &DataType) -> bool { | ||
match data_type { | ||
DataType::Struct(fields) => { | ||
fields.len() > 1 || fields.len() == 1 && multiple_arrays(fields[0].data_type()) | ||
} | ||
DataType::Union(fields, UnionMode::Sparse) => !fields.is_empty(), | ||
_ => false, | ||
} | ||
} | ||
|
||
/// Returns a new [RecordBatch] with arrays containing only values matching the filter. | ||
pub fn filter_record_batch( | ||
record_batch: &RecordBatch, | ||
|
@@ -358,6 +377,12 @@ fn filter_array(values: &dyn Array, predicate: &FilterPredicate) -> Result<Array | |
values => Ok(Arc::new(filter_dict(values, predicate))), | ||
t => unimplemented!("Filter not supported for dictionary type {:?}", t) | ||
} | ||
DataType::Struct(_) => { | ||
Ok(Arc::new(filter_struct(values.as_struct(), predicate)?)) | ||
} | ||
DataType::Union(_, UnionMode::Sparse) => { | ||
Ok(Arc::new(filter_sparse_union(values.as_union(), predicate)?)) | ||
} | ||
_ => { | ||
let data = values.to_data(); | ||
// fallback to using MutableArrayData | ||
|
@@ -782,6 +807,49 @@ where | |
DictionaryArray::from(unsafe { builder.build_unchecked() }) | ||
} | ||
|
||
/// `filter` implementation for structs | ||
fn filter_struct( | ||
array: &StructArray, | ||
predicate: &FilterPredicate, | ||
) -> Result<StructArray, ArrowError> { | ||
let columns = array | ||
.columns() | ||
.iter() | ||
.map(|column| filter_array(column, predicate)) | ||
.collect::<Result<_, _>>()?; | ||
|
||
let nulls = if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) { | ||
let buffer = BooleanBuffer::new(nulls, 0, predicate.count); | ||
|
||
Some(unsafe { NullBuffer::new_unchecked(buffer, null_count) }) | ||
} else { | ||
None | ||
}; | ||
|
||
Ok(unsafe { StructArray::new_unchecked(array.fields().clone(), columns, nulls) }) | ||
} | ||
|
||
/// `filter` implementation for sparse unions | ||
fn filter_sparse_union( | ||
array: &UnionArray, | ||
predicate: &FilterPredicate, | ||
) -> Result<UnionArray, ArrowError> { | ||
let DataType::Union(fields, UnionMode::Sparse) = array.data_type() else { | ||
unreachable!() | ||
}; | ||
|
||
let type_ids = filter_primitive(&Int8Array::new(array.type_ids().clone(), None), predicate); | ||
|
||
let children = fields | ||
.iter() | ||
.map(|(child_type_id, _)| filter_array(array.child(child_type_id), predicate)) | ||
.collect::<Result<_, _>>()?; | ||
|
||
Ok(unsafe { | ||
UnionArray::new_unchecked(fields.clone(), type_ids.into_parts().1, None, children) | ||
}) | ||
} | ||
|
||
#[cfg(test)] | ||
mod tests { | ||
use arrow_array::builder::*; | ||
|
@@ -1871,4 +1939,75 @@ mod tests { | |
} | ||
} | ||
} | ||
|
||
Review comment: Union tests already exist. Reply: (I double-checked and they are immediately above this -- `fn test_filter_union_array_sparse() { ... }`) 👍 |
||
/// Filters struct arrays with one and two children, each with and without a
/// struct-level null mask, and compares the result against a hand-built
/// expectation.
#[test]
fn test_filter_struct() {
    // Keeps rows 0 and 2 of every four-row input below.
    let predicate = BooleanArray::from(vec![true, false, true, false]);

    let a: Arc<dyn Array> = Arc::new(StringArray::from(vec!["hello", " ", "world", "!"]));
    let a_filtered: Arc<dyn Array> = Arc::new(StringArray::from(vec!["hello", "world"]));

    let b: Arc<dyn Array> = Arc::new(Int32Array::from(vec![5, 6, 7, 8]));
    let b_filtered: Arc<dyn Array> = Arc::new(Int32Array::from(vec![5, 7]));

    let null_mask = NullBuffer::from(vec![true, false, false, true]);
    let null_mask_filtered = NullBuffer::from(vec![true, false]);

    let a_field = Field::new("a", DataType::Utf8, false);
    let b_field = Field::new("b", DataType::Int32, false);

    // (fields, input columns, expected columns): single child, then two children.
    let shapes = [
        (
            vec![a_field.clone()],
            vec![a.clone()],
            vec![a_filtered.clone()],
        ),
        (
            vec![a_field, b_field],
            vec![a, b],
            vec![a_filtered, b_filtered],
        ),
    ];

    for (fields, columns, kept_columns) in shapes {
        // (input null mask, expected null mask): without, then with.
        let validity = [
            (None, None),
            (Some(null_mask.clone()), Some(null_mask_filtered.clone())),
        ];

        for (nulls, kept_nulls) in validity {
            let array = StructArray::new(fields.clone().into(), columns.clone(), nulls);
            let expected =
                StructArray::new(fields.clone().into(), kept_columns.clone(), kept_nulls);

            let result = filter(&array, &predicate).unwrap();

            assert_eq!(result.to_data(), expected.to_data());
        }
    }
}
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
👍