diff --git a/arrow/src/compute/kernels/comparison.rs b/arrow/src/compute/kernels/comparison.rs index d0a20daed330..b5331fdc662d 100644 --- a/arrow/src/compute/kernels/comparison.rs +++ b/arrow/src/compute/kernels/comparison.rs @@ -1001,7 +1001,7 @@ macro_rules! try_to_type { } macro_rules! dyn_compare_scalar { - // Applies `LEFT OP RIGHT` when `LEFT` is a `DictionaryArray` + // Applies `LEFT OP RIGHT` when `LEFT` is a `PrimitiveArray` ($LEFT: expr, $RIGHT: expr, $OP: ident) => {{ match $LEFT.data_type() { DataType::Int8 => { @@ -2030,6 +2030,277 @@ macro_rules! typed_compares { }}; } +/// Applies $OP to $LEFT and $RIGHT which are two dictionaries which have (the same) key type $KT +macro_rules! typed_dict_cmp { + ($LEFT: expr, $RIGHT: expr, $OP: expr, $KT: tt) => {{ + match ($LEFT.value_type(), $RIGHT.value_type()) { + (DataType::Boolean, DataType::Boolean) => { + cmp_dict_bool::<$KT, _>($LEFT, $RIGHT, $OP) + } + (DataType::Int8, DataType::Int8) => { + cmp_dict::<$KT, Int8Type, _>($LEFT, $RIGHT, $OP) + } + (DataType::Int16, DataType::Int16) => { + cmp_dict::<$KT, Int16Type, _>($LEFT, $RIGHT, $OP) + } + (DataType::Int32, DataType::Int32) => { + cmp_dict::<$KT, Int32Type, _>($LEFT, $RIGHT, $OP) + } + (DataType::Int64, DataType::Int64) => { + cmp_dict::<$KT, Int64Type, _>($LEFT, $RIGHT, $OP) + } + (DataType::UInt8, DataType::UInt8) => { + cmp_dict::<$KT, UInt8Type, _>($LEFT, $RIGHT, $OP) + } + (DataType::UInt16, DataType::UInt16) => { + cmp_dict::<$KT, UInt16Type, _>($LEFT, $RIGHT, $OP) + } + (DataType::UInt32, DataType::UInt32) => { + cmp_dict::<$KT, UInt32Type, _>($LEFT, $RIGHT, $OP) + } + (DataType::UInt64, DataType::UInt64) => { + cmp_dict::<$KT, UInt64Type, _>($LEFT, $RIGHT, $OP) + } + (DataType::Float32, DataType::Float32) => { + cmp_dict::<$KT, Float32Type, _>($LEFT, $RIGHT, $OP) + } + (DataType::Float64, DataType::Float64) => { + cmp_dict::<$KT, Float64Type, _>($LEFT, $RIGHT, $OP) + } + (DataType::Utf8, DataType::Utf8) => { + cmp_dict_utf8::<$KT, i32, _>($LEFT, $RIGHT, $OP) + } + (DataType::LargeUtf8, DataType::LargeUtf8) => { + cmp_dict_utf8::<$KT, i64, _>($LEFT, $RIGHT, $OP) + } + (DataType::Binary, DataType::Binary) => { + cmp_dict_binary::<$KT, i32, _>($LEFT, $RIGHT, $OP) + } + (DataType::LargeBinary, DataType::LargeBinary) => { + cmp_dict_binary::<$KT, i64, _>($LEFT, $RIGHT, $OP) + } + ( + DataType::Timestamp(TimeUnit::Nanosecond, _), + DataType::Timestamp(TimeUnit::Nanosecond, _), + ) => { + cmp_dict::<$KT, TimestampNanosecondType, _>($LEFT, $RIGHT, $OP) + } + ( + DataType::Timestamp(TimeUnit::Microsecond, _), + DataType::Timestamp(TimeUnit::Microsecond, _), + ) => { + cmp_dict::<$KT, TimestampMicrosecondType, _>($LEFT, $RIGHT, $OP) + } + ( + DataType::Timestamp(TimeUnit::Millisecond, _), + DataType::Timestamp(TimeUnit::Millisecond, _), + ) => { + cmp_dict::<$KT, TimestampMillisecondType, _>($LEFT, $RIGHT, $OP) + } + ( + DataType::Timestamp(TimeUnit::Second, _), + DataType::Timestamp(TimeUnit::Second, _), + ) => { + cmp_dict::<$KT, TimestampSecondType, _>($LEFT, $RIGHT, $OP) + } + (DataType::Date32, DataType::Date32) => { + cmp_dict::<$KT, Date32Type, _>($LEFT, $RIGHT, $OP) + } + (DataType::Date64, DataType::Date64) => { + cmp_dict::<$KT, Date64Type, _>($LEFT, $RIGHT, $OP) + } + ( + DataType::Interval(IntervalUnit::YearMonth), + DataType::Interval(IntervalUnit::YearMonth), + ) => { + cmp_dict::<$KT, IntervalYearMonthType, _>($LEFT, $RIGHT, $OP) + } + ( + DataType::Interval(IntervalUnit::DayTime), + DataType::Interval(IntervalUnit::DayTime), + ) => { + cmp_dict::<$KT, IntervalDayTimeType, _>($LEFT, $RIGHT, $OP) + } + ( + DataType::Interval(IntervalUnit::MonthDayNano), + DataType::Interval(IntervalUnit::MonthDayNano), + ) => { + cmp_dict::<$KT, IntervalMonthDayNanoType, _>($LEFT, $RIGHT, $OP) + } + (t1, t2) if t1 == t2 => Err(ArrowError::NotYetImplemented(format!( + "Comparing dictionary arrays of value type {} is not yet implemented", + t1 + ))), + (t1, t2) => Err(ArrowError::CastError(format!( + "Cannot compare two dictionary arrays of different value types ({} and {})", + t1, t2 + ))), + } + }}; +} + +macro_rules! typed_dict_compares { + // Applies `LEFT OP RIGHT` when `LEFT` and `RIGHT` both are `DictionaryArray` + ($LEFT: expr, $RIGHT: expr, $OP: expr) => {{ + match ($LEFT.data_type(), $RIGHT.data_type()) { + (DataType::Dictionary(left_key_type, _), DataType::Dictionary(right_key_type, _))=> { + match (left_key_type.as_ref(), right_key_type.as_ref()) { + (DataType::Int8, DataType::Int8) => { + let left = as_dictionary_array::($LEFT); + let right = as_dictionary_array::($RIGHT); + typed_dict_cmp!(left, right, $OP, Int8Type) + } + (DataType::Int16, DataType::Int16) => { + let left = as_dictionary_array::($LEFT); + let right = as_dictionary_array::($RIGHT); + typed_dict_cmp!(left, right, $OP, Int16Type) + } + (DataType::Int32, DataType::Int32) => { + let left = as_dictionary_array::($LEFT); + let right = as_dictionary_array::($RIGHT); + typed_dict_cmp!(left, right, $OP, Int32Type) + } + (DataType::Int64, DataType::Int64) => { + let left = as_dictionary_array::($LEFT); + let right = as_dictionary_array::($RIGHT); + typed_dict_cmp!(left, right, $OP, Int64Type) + } + (DataType::UInt8, DataType::UInt8) => { + let left = as_dictionary_array::($LEFT); + let right = as_dictionary_array::($RIGHT); + typed_dict_cmp!(left, right, $OP, UInt8Type) + } + (DataType::UInt16, DataType::UInt16) => { + let left = as_dictionary_array::($LEFT); + let right = as_dictionary_array::($RIGHT); + typed_dict_cmp!(left, right, $OP, UInt16Type) + } + (DataType::UInt32, DataType::UInt32) => { + let left = as_dictionary_array::($LEFT); + let right = as_dictionary_array::($RIGHT); + typed_dict_cmp!(left, right, $OP, UInt32Type) + } + (DataType::UInt64, DataType::UInt64) => { + let left = as_dictionary_array::($LEFT); + let right = as_dictionary_array::($RIGHT); + typed_dict_cmp!(left, right, $OP, UInt64Type) + } + (t1, t2) if t1 == t2 => Err(ArrowError::NotYetImplemented(format!( + "Comparing dictionary arrays of type {} is not yet implemented", + t1 + ))), + (t1, t2) => Err(ArrowError::CastError(format!( + "Cannot compare two dictionary arrays of different key types ({} and {})", + t1, t2 + ))), + } + } + (t1, t2) => Err(ArrowError::CastError(format!( + "Cannot compare dictionary array with non-dictionary array ({} and {})", + t1, t2 + ))), + } + }}; +} + +/// Helper function to perform boolean lambda function on values from two dictionary arrays, this +/// version does not attempt to use SIMD explicitly (though the compiler may auto vectorize) +macro_rules! compare_dict_op { + ($left: expr, $right:expr, $op:expr, $value_ty:ty) => {{ + if $left.len() != $right.len() { + return Err(ArrowError::ComputeError( + "Cannot perform comparison operation on arrays of different length" + .to_string(), + )); + } + let left_values = $left.values().as_any().downcast_ref::<$value_ty>().unwrap(); + let right_values = $right + .values() + .as_any() + .downcast_ref::<$value_ty>() + .unwrap(); + + let result = $left + .keys() + .iter() + .zip($right.keys().iter()) + .map(|(left_key, right_key)| { + if let (Some(left_k), Some(right_k)) = (left_key, right_key) { + let left_key = left_k.to_usize().expect("Dictionary index not usize"); + let right_key = + right_k.to_usize().expect("Dictionary index not usize"); + unsafe { + let left_value = left_values.value_unchecked(left_key); + let right_value = right_values.value_unchecked(right_key); + Some($op(left_value, right_value)) + } + } else { + None + } + }) + .collect(); + + Ok(result) + }}; +} + +/// Perform given operation on two `DictionaryArray`s. +/// Returns an error if the two arrays have different value type +pub fn cmp_dict( + left: &DictionaryArray, + right: &DictionaryArray, + op: F, +) -> Result +where + K: ArrowNumericType, + T: ArrowNumericType, + F: Fn(T::Native, T::Native) -> bool, +{ + compare_dict_op!(left, right, op, PrimitiveArray) +} + +/// Perform the given operation on two `DictionaryArray`s which value type is +/// `DataType::Boolean`. +pub fn cmp_dict_bool( + left: &DictionaryArray, + right: &DictionaryArray, + op: F, +) -> Result +where + K: ArrowNumericType, + F: Fn(bool, bool) -> bool, +{ + compare_dict_op!(left, right, op, BooleanArray) +} + +/// Perform the given operation on two `DictionaryArray`s which value type is +/// `DataType::Utf8` or `DataType::LargeUtf8`. +pub fn cmp_dict_utf8( + left: &DictionaryArray, + right: &DictionaryArray, + op: F, +) -> Result +where + K: ArrowNumericType, + F: Fn(&str, &str) -> bool, +{ + compare_dict_op!(left, right, op, GenericStringArray) +} + +/// Perform the given operation on two `DictionaryArray`s which value type is +/// `DataType::Binary` or `DataType::LargeBinary`. +pub fn cmp_dict_binary( + left: &DictionaryArray, + right: &DictionaryArray, + op: F, +) -> Result +where + K: ArrowNumericType, + F: Fn(&[u8], &[u8]) -> bool, +{ + compare_dict_op!(left, right, op, GenericBinaryArray) +} + /// Perform `left == right` operation on two (dynamic) [`Array`]s. /// /// Only when two arrays are of the same type the comparison will happen otherwise it will err @@ -2045,7 +2316,12 @@ macro_rules! typed_compares { /// assert_eq!(BooleanArray::from(vec![Some(true), None, Some(false)]), result); /// ``` pub fn eq_dyn(left: &dyn Array, right: &dyn Array) -> Result { - typed_compares!(left, right, eq_bool, eq, eq_utf8, eq_binary) + match left.data_type() { + DataType::Dictionary(_, _) => { + typed_dict_compares!(left, right, |a, b| a == b) + } + _ => typed_compares!(left, right, eq_bool, eq, eq_utf8, eq_binary), + } } /// Perform `left != right` operation on two (dynamic) [`Array`]s. @@ -2423,6 +2699,7 @@ mod tests { use super::*; use crate::datatypes::Int8Type; + use crate::datatypes::ToByteSlice; use crate::{array::Int32Array, array::Int64Array, datatypes::Field}; /// Evaluate `KERNEL` with two vectors as inputs and assert against the expected output. @@ -4374,4 +4651,192 @@ mod tests { BooleanArray::from(vec![Some(true), Some(false), Some(true)]) ); } + + fn get_dict_arraydata( + keys: Buffer, + key_type: DataType, + value_data: ArrayData, + ) -> ArrayData { + let value_type = value_data.data_type().clone(); + let dict_data_type = + DataType::Dictionary(Box::new(key_type), Box::new(value_type)); + ArrayData::builder(dict_data_type) + .len(3) + .add_buffer(keys) + .add_child_data(value_data) + .build() + .unwrap() + } + + #[test] + fn test_eq_dyn_dictionary_i8_array() { + let key_type = DataType::Int8; + // Construct a value array + let value_data = ArrayData::builder(DataType::Int8) + .len(8) + .add_buffer(Buffer::from( + &[10_i8, 11, 12, 13, 14, 15, 16, 17].to_byte_slice(), + )) + .build() + .unwrap(); + + let keys1 = Buffer::from(&[2_i8, 3, 4].to_byte_slice()); + let keys2 = Buffer::from(&[2_i8, 4, 4].to_byte_slice()); + let dict_array1: DictionaryArray = Int8DictionaryArray::from( + get_dict_arraydata(keys1, key_type.clone(), value_data.clone()), + ); + let dict_array2: DictionaryArray = + Int8DictionaryArray::from(get_dict_arraydata(keys2, key_type, value_data)); + + let result = eq_dyn(&dict_array1, &dict_array2); + assert!(result.is_ok()); + assert_eq!(result.unwrap(), BooleanArray::from(vec![true, false, true])); + } + + #[test] + fn test_eq_dyn_dictionary_u64_array() { + let key_type = DataType::UInt64; + // Construct a value array + let value_data = ArrayData::builder(DataType::UInt64) + .len(8) + .add_buffer(Buffer::from( + &[10_u64, 11, 12, 13, 14, 15, 16, 17].to_byte_slice(), + )) + .build() + .unwrap(); + + let keys1 = Buffer::from(&[1_u64, 3, 4].to_byte_slice()); + let keys2 = Buffer::from(&[2_u64, 3, 5].to_byte_slice()); + let dict_array1: DictionaryArray = UInt64DictionaryArray::from( + get_dict_arraydata(keys1, key_type.clone(), value_data.clone()), + ); + let dict_array2: DictionaryArray = + UInt64DictionaryArray::from(get_dict_arraydata(keys2, key_type, value_data)); + + let result = eq_dyn(&dict_array1, &dict_array2); + assert!(result.is_ok()); + assert_eq!( + result.unwrap(), + BooleanArray::from(vec![false, true, false]) + ); + } + + #[test] + fn test_eq_dyn_dictionary_utf8_array() { + let test1 = vec!["a", "a", "b", "c"]; + let test2 = vec!["a", "b", "b", "c"]; + + let dict_array1: DictionaryArray = test1 + .iter() + .map(|&x| if x == "b" { None } else { Some(x) }) + .collect(); + let dict_array2: DictionaryArray = test2 + .iter() + .map(|&x| if x == "b" { None } else { Some(x) }) + .collect(); + + let result = eq_dyn(&dict_array1, &dict_array2); + assert!(result.is_ok()); + assert_eq!( + result.unwrap(), + BooleanArray::from(vec![Some(true), None, None, Some(true)]) + ); + } + + #[test] + fn test_eq_dyn_dictionary_binary_array() { + let key_type = DataType::UInt64; + + // Construct a value array + let values: [u8; 12] = [ + b'h', b'e', b'l', b'l', b'o', b'p', b'a', b'r', b'q', b'u', b'e', b't', + ]; + let offsets: [i32; 4] = [0, 5, 5, 12]; + + // Array data: ["hello", "", "parquet"] + let value_data = ArrayData::builder(DataType::Binary) + .len(3) + .add_buffer(Buffer::from_slice_ref(&offsets)) + .add_buffer(Buffer::from_slice_ref(&values)) + .build() + .unwrap(); + + let keys1 = Buffer::from(&[0_u64, 1, 2].to_byte_slice()); + let keys2 = Buffer::from(&[0_u64, 2, 1].to_byte_slice()); + let dict_array1: DictionaryArray = UInt64DictionaryArray::from( + get_dict_arraydata(keys1, key_type.clone(), value_data.clone()), + ); + let dict_array2: DictionaryArray = + UInt64DictionaryArray::from(get_dict_arraydata(keys2, key_type, value_data)); + + let result = eq_dyn(&dict_array1, &dict_array2); + assert!(result.is_ok()); + assert_eq!( + result.unwrap(), + BooleanArray::from(vec![true, false, false]) + ); + } + + #[test] + fn test_eq_dyn_dictionary_interval_array() { + let key_type = DataType::UInt64; + + let value_array = IntervalDayTimeArray::from(vec![1, 6, 10, 2, 3, 5]); + let value_data = value_array.data().clone(); + + let keys1 = Buffer::from(&[1_u64, 0, 3].to_byte_slice()); + let keys2 = Buffer::from(&[2_u64, 0, 3].to_byte_slice()); + let dict_array1: DictionaryArray = UInt64DictionaryArray::from( + get_dict_arraydata(keys1, key_type.clone(), value_data.clone()), + ); + let dict_array2: DictionaryArray = + UInt64DictionaryArray::from(get_dict_arraydata(keys2, key_type, value_data)); + + let result = eq_dyn(&dict_array1, &dict_array2); + assert!(result.is_ok()); + assert_eq!(result.unwrap(), BooleanArray::from(vec![false, true, true])); + } + + #[test] + fn test_eq_dyn_dictionary_date_array() { + let key_type = DataType::UInt64; + + let value_array = Date32Array::from(vec![1, 6, 10, 2, 3, 5]); + let value_data = value_array.data().clone(); + + let keys1 = Buffer::from(&[1_u64, 0, 3].to_byte_slice()); + let keys2 = Buffer::from(&[2_u64, 0, 3].to_byte_slice()); + let dict_array1: DictionaryArray = UInt64DictionaryArray::from( + get_dict_arraydata(keys1, key_type.clone(), value_data.clone()), + ); + let dict_array2: DictionaryArray = + UInt64DictionaryArray::from(get_dict_arraydata(keys2, key_type, value_data)); + + let result = eq_dyn(&dict_array1, &dict_array2); + assert!(result.is_ok()); + assert_eq!(result.unwrap(), BooleanArray::from(vec![false, true, true])); + } + + #[test] + fn test_eq_dyn_dictionary_bool_array() { + let key_type = DataType::UInt64; + + let value_array = BooleanArray::from(vec![true, false]); + let value_data = value_array.data().clone(); + + let keys1 = Buffer::from(&[1_u64, 1, 1].to_byte_slice()); + let keys2 = Buffer::from(&[0_u64, 1, 0].to_byte_slice()); + let dict_array1: DictionaryArray = UInt64DictionaryArray::from( + get_dict_arraydata(keys1, key_type.clone(), value_data.clone()), + ); + let dict_array2: DictionaryArray = + UInt64DictionaryArray::from(get_dict_arraydata(keys2, key_type, value_data)); + + let result = eq_dyn(&dict_array1, &dict_array2); + assert!(result.is_ok()); + assert_eq!( + result.unwrap(), + BooleanArray::from(vec![false, true, false]) + ); + } }