Skip to content

Commit

Permalink
Replace macro with TypedDictionaryArray (#2514)
Browse files Browse the repository at this point in the history
  • Loading branch information
viirya authored Aug 19, 2022
1 parent a90fc64 commit c20bb8a
Showing 1 changed file with 57 additions and 45 deletions.
102 changes: 57 additions & 45 deletions arrow/src/compute/kernels/comparison.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2175,49 +2175,39 @@ macro_rules! typed_dict_compares {

/// Helper function that applies a boolean lambda to the values of two dictionary arrays. This
/// version does not attempt to use SIMD explicitly (though the compiler may auto-vectorize).
macro_rules! compare_dict_op {
($left: expr, $right:expr, $op:expr, $value_ty:ty) => {{
if $left.len() != $right.len() {
return Err(ArrowError::ComputeError(
"Cannot perform comparison operation on arrays of different length"
.to_string(),
));
}

// Safety justification: Since the inputs are valid Arrow arrays, all values are
// valid indexes into the dictionary (which is verified during construction)

let left_iter = unsafe {
$left
.values()
.as_any()
.downcast_ref::<$value_ty>()
.unwrap()
.take_iter_unchecked($left.keys_iter())
};
fn compare_dict_op<'a, K, V, F>(
left: TypedDictionaryArray<'a, K, V>,
right: TypedDictionaryArray<'a, K, V>,
op: F,
) -> Result<BooleanArray>
where
K: ArrowNumericType,
V: Sync + Send,
&'a V: ArrayAccessor,
F: Fn(<&V as ArrayAccessor>::Item, <&V as ArrayAccessor>::Item) -> bool,
{
if left.len() != right.len() {
return Err(ArrowError::ComputeError(
"Cannot perform comparison operation on arrays of different length"
.to_string(),
));
}

let right_iter = unsafe {
$right
.values()
.as_any()
.downcast_ref::<$value_ty>()
.unwrap()
.take_iter_unchecked($right.keys_iter())
};
let left_iter = left.into_iter();
let right_iter = right.into_iter();

let result = left_iter
.zip(right_iter)
.map(|(left_value, right_value)| {
if let (Some(left), Some(right)) = (left_value, right_value) {
Some($op(left, right))
} else {
None
}
})
.collect();
let result = left_iter
.zip(right_iter)
.map(|(left_value, right_value)| {
if let (Some(left), Some(right)) = (left_value, right_value) {
Some(op(left, right))
} else {
None
}
})
.collect();

Ok(result)
}};
Ok(result)
}

/// Perform given operation on two `DictionaryArray`s.
Expand All @@ -2229,10 +2219,14 @@ pub fn cmp_dict<K, T, F>(
) -> Result<BooleanArray>
where
K: ArrowNumericType,
T: ArrowNumericType,
T: ArrowNumericType + Sync + Send,
F: Fn(T::Native, T::Native) -> bool,
{
compare_dict_op!(left, right, op, PrimitiveArray<T>)
compare_dict_op(
left.downcast_dict::<PrimitiveArray<T>>().unwrap(),
right.downcast_dict::<PrimitiveArray<T>>().unwrap(),
op,
)
}

/// Perform the given operation on two `DictionaryArray`s whose value type is
Expand All @@ -2246,7 +2240,11 @@ where
K: ArrowNumericType,
F: Fn(bool, bool) -> bool,
{
compare_dict_op!(left, right, op, BooleanArray)
compare_dict_op(
left.downcast_dict::<BooleanArray>().unwrap(),
right.downcast_dict::<BooleanArray>().unwrap(),
op,
)
}

/// Perform the given operation on two `DictionaryArray`s whose value type is
Expand All @@ -2260,7 +2258,14 @@ where
K: ArrowNumericType,
F: Fn(&str, &str) -> bool,
{
compare_dict_op!(left, right, op, GenericStringArray<OffsetSize>)
compare_dict_op(
left.downcast_dict::<GenericStringArray<OffsetSize>>()
.unwrap(),
right
.downcast_dict::<GenericStringArray<OffsetSize>>()
.unwrap(),
op,
)
}

/// Perform the given operation on two `DictionaryArray`s whose value type is
Expand All @@ -2274,7 +2279,14 @@ where
K: ArrowNumericType,
F: Fn(&[u8], &[u8]) -> bool,
{
compare_dict_op!(left, right, op, GenericBinaryArray<OffsetSize>)
compare_dict_op(
left.downcast_dict::<GenericBinaryArray<OffsetSize>>()
.unwrap(),
right
.downcast_dict::<GenericBinaryArray<OffsetSize>>()
.unwrap(),
op,
)
}

/// Perform `left == right` operation on two (dynamic) [`Array`]s.
Expand Down

0 comments on commit c20bb8a

Please sign in to comment.