diff --git a/arrow-array/src/cast.rs b/arrow-array/src/cast.rs
index 7b4b1d6eca4b..cda179b78c2e 100644
--- a/arrow-array/src/cast.rs
+++ b/arrow-array/src/cast.rs
@@ -815,6 +815,14 @@ pub trait AsArray: private::Sealed {
         self.as_struct_opt().expect("struct array")
     }
 
+    /// Downcast this to a [`UnionArray`] returning `None` if not possible
+    fn as_union_opt(&self) -> Option<&UnionArray>;
+
+    /// Downcast this to a [`UnionArray`] panicking if not possible
+    fn as_union(&self) -> &UnionArray {
+        self.as_union_opt().expect("union array")
+    }
+
     /// Downcast this to a [`GenericListArray`] returning `None` if not possible
     fn as_list_opt<O: OffsetSizeTrait>(&self) -> Option<&GenericListArray<O>>;
 
@@ -888,6 +896,10 @@ impl AsArray for dyn Array + '_ {
         self.as_any().downcast_ref()
     }
 
+    fn as_union_opt(&self) -> Option<&UnionArray> {
+        self.as_any().downcast_ref()
+    }
+
     fn as_list_opt<O: OffsetSizeTrait>(&self) -> Option<&GenericListArray<O>> {
         self.as_any().downcast_ref()
     }
@@ -939,6 +951,10 @@ impl AsArray for ArrayRef {
         self.as_ref().as_struct_opt()
     }
 
+    fn as_union_opt(&self) -> Option<&UnionArray> {
+        self.as_ref().as_union_opt()
+    }
+
     fn as_list_opt<O: OffsetSizeTrait>(&self) -> Option<&GenericListArray<O>> {
         self.as_ref().as_list_opt()
     }
diff --git a/arrow-select/src/filter.rs b/arrow-select/src/filter.rs
index c51f44a977f6..e07b03d1f276 100644
--- a/arrow-select/src/filter.rs
+++ b/arrow-select/src/filter.rs
@@ -169,10 +169,31 @@ pub fn prep_null_mask_filter(filter: &BooleanArray) -> BooleanArray {
 /// assert_eq!(c, &Int32Array::from(vec![5, 8]));
 /// ```
 pub fn filter(values: &dyn Array, predicate: &BooleanArray) -> Result<ArrayRef, ArrowError> {
-    let predicate = FilterBuilder::new(predicate).build();
+    let mut filter_builder = FilterBuilder::new(predicate);
+
+    if multiple_arrays(values.data_type()) {
+        // Only optimize if filtering more than one array
+        // Otherwise, the overhead of optimization can be more than the benefit
+        filter_builder = filter_builder.optimize();
+    }
+
+    let predicate = filter_builder.build();
+
     filter_array(values, &predicate)
 }
 
+/// Returns `true` for types whose filtering touches more than one array: a struct with several
+/// fields (or whose single field is itself such a type), or a non-empty sparse union.
+fn multiple_arrays(data_type: &DataType) -> bool {
+    match data_type {
+        DataType::Struct(fields) => {
+            fields.len() > 1 || (fields.len() == 1 && multiple_arrays(fields[0].data_type()))
+        }
+        DataType::Union(fields, UnionMode::Sparse) => !fields.is_empty(),
+        _ => false,
+    }
+}
+
 /// Returns a filtered [RecordBatch] where the corresponding elements of
 /// `predicate` are true.
 ///
@@ -365,6 +386,12 @@ fn filter_array(values: &dyn Array, predicate: &FilterPredicate) -> Result<ArrayRef, ArrowError> {
                 values => Ok(Arc::new(filter_dict(values, predicate))),
                 t => unimplemented!("Filter not supported for dictionary type {:?}", t)
             }
+            DataType::Struct(_) => {
+                Ok(Arc::new(filter_struct(values.as_struct(), predicate)?))
+            }
+            DataType::Union(_, UnionMode::Sparse) => {
+                Ok(Arc::new(filter_sparse_union(values.as_union(), predicate)?))
+            }
             _ => {
                 let data = values.to_data();
                 // fallback to using MutableArrayData
@@ -789,6 +816,61 @@ where
     DictionaryArray::from(unsafe { builder.build_unchecked() })
 }
 
+/// `filter` implementation for structs
+///
+/// Filters every child column and the struct-level null mask with the same `predicate`.
+fn filter_struct(
+    array: &StructArray,
+    predicate: &FilterPredicate,
+) -> Result<StructArray, ArrowError> {
+    let columns = array
+        .columns()
+        .iter()
+        .map(|column| filter_array(column, predicate))
+        .collect::<Result<_, _>>()?;
+
+    let nulls = if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
+        let buffer = BooleanBuffer::new(nulls, 0, predicate.count);
+
+        // SAFETY: `null_count` comes from the same `filter_null_mask` call as `nulls`, which is
+        // assumed to report the number of unset bits in the mask it returns.
+        Some(unsafe { NullBuffer::new_unchecked(buffer, null_count) })
+    } else {
+        None
+    };
+
+    // SAFETY: the columns and the null mask were all filtered with the same predicate, so they
+    // share one length; assumes `filter_array` keeps each column's data type, matching `fields`.
+    Ok(unsafe { StructArray::new_unchecked(array.fields().clone(), columns, nulls) })
+}
+
+/// `filter` implementation for sparse unions
+///
+/// Filters the type ids and every child array with the same `predicate`.
+fn filter_sparse_union(
+    array: &UnionArray,
+    predicate: &FilterPredicate,
+) -> Result<UnionArray, ArrowError> {
+    let DataType::Union(fields, UnionMode::Sparse) = array.data_type() else {
+        // `filter_array` only dispatches here for `DataType::Union(_, UnionMode::Sparse)`
+        unreachable!()
+    };
+
+    // Wrap the type ids in a non-null `Int8Array` so `filter_primitive` can be reused for them
+    let type_ids = filter_primitive(&Int8Array::new(array.type_ids().clone(), None), predicate);
+
+    let children = fields
+        .iter()
+        .map(|(child_type_id, _)| filter_array(array.child(child_type_id), predicate))
+        .collect::<Result<_, _>>()?;
+
+    // SAFETY: the type ids and children were all filtered with the same predicate, so they share
+    // one length; assumes `filter_array` keeps each child's data type, matching `fields`.
+    Ok(unsafe {
+        UnionArray::new_unchecked(fields.clone(), type_ids.into_parts().1, None, children)
+    })
+}
+
 #[cfg(test)]
 mod tests {
     use arrow_array::builder::*;
@@ -1878,4 +1960,75 @@ mod tests {
             }
         }
     }
+
+    #[test]
+    fn test_filter_struct() {
+        let predicate = BooleanArray::from(vec![true, false, true, false]);
+
+        let a = Arc::new(StringArray::from(vec!["hello", " ", "world", "!"]));
+        let a_filtered = Arc::new(StringArray::from(vec!["hello", "world"]));
+
+        let b = Arc::new(Int32Array::from(vec![5, 6, 7, 8]));
+        let b_filtered = Arc::new(Int32Array::from(vec![5, 7]));
+
+        let null_mask = NullBuffer::from(vec![true, false, false, true]);
+        let null_mask_filtered = NullBuffer::from(vec![true, false]);
+
+        let a_field = Field::new("a", DataType::Utf8, false);
+        let b_field = Field::new("b", DataType::Int32, false);
+
+        let array = StructArray::new(vec![a_field.clone()].into(), vec![a.clone()], None);
+        let expected =
+            StructArray::new(vec![a_field.clone()].into(), vec![a_filtered.clone()], None);
+
+        let result = filter(&array, &predicate).unwrap();
+
+        assert_eq!(result.to_data(), expected.to_data());
+
+        let array = StructArray::new(
+            vec![a_field.clone()].into(),
+            vec![a.clone()],
+            Some(null_mask.clone()),
+        );
+        let expected = StructArray::new(
+            vec![a_field.clone()].into(),
+            vec![a_filtered.clone()],
+            Some(null_mask_filtered.clone()),
+        );
+
+        let result = filter(&array, &predicate).unwrap();
+
+        assert_eq!(result.to_data(), expected.to_data());
+
+        let array = StructArray::new(
+            vec![a_field.clone(), b_field.clone()].into(),
+            vec![a.clone(), b.clone()],
+            None,
+        );
+        let expected = StructArray::new(
+            vec![a_field.clone(), b_field.clone()].into(),
+            vec![a_filtered.clone(), b_filtered.clone()],
+            None,
+        );
+
+        let result = filter(&array, &predicate).unwrap();
+
+        assert_eq!(result.to_data(), expected.to_data());
+
+        let array = StructArray::new(
+            vec![a_field.clone(), b_field.clone()].into(),
+            vec![a.clone(), b.clone()],
+            Some(null_mask.clone()),
+        );
+
+        let expected = StructArray::new(
+            vec![a_field.clone(), b_field.clone()].into(),
+            vec![a_filtered.clone(), b_filtered.clone()],
+            Some(null_mask_filtered.clone()),
+        );
+
+        let result = filter(&array, &predicate).unwrap();
+
+        assert_eq!(result.to_data(), expected.to_data());
+    }
 }