From 727aba14bad843a8a2598c3cfae641a5f86260a0 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sun, 30 Jan 2022 10:18:21 -0800 Subject: [PATCH 01/10] Implement DictionaryArray support in eq_dyn --- arrow/src/compute/kernels/comparison.rs | 144 +++++++++++++++++++++++- 1 file changed, 142 insertions(+), 2 deletions(-) diff --git a/arrow/src/compute/kernels/comparison.rs b/arrow/src/compute/kernels/comparison.rs index 12d214b60586..ed1a74d8fa65 100644 --- a/arrow/src/compute/kernels/comparison.rs +++ b/arrow/src/compute/kernels/comparison.rs @@ -1001,7 +1001,7 @@ macro_rules! try_to_type { } macro_rules! dyn_compare_scalar { - // Applies `LEFT OP RIGHT` when `LEFT` is a `DictionaryArray` + // Applies `LEFT OP RIGHT` when `LEFT` is a `PrimitiveArray` ($LEFT: expr, $RIGHT: expr, $OP: ident) => {{ match $LEFT.data_type() { DataType::Int8 => { @@ -2030,12 +2030,112 @@ macro_rules! typed_compares { }}; } +macro_rules! typed_dict_cmp { + ($LEFT: expr, $RIGHT: expr, $OP_PRIM: ident, $KT: tt) => {{ + match ($LEFT.value_type(), $RIGHT.value_type()) { + (DataType::Int8, DataType::Int8) => { + $OP_PRIM::<$KT, Int8Type>($LEFT, $RIGHT) + } + (t1, t2) if t1 == t2 => Err(ArrowError::NotYetImplemented(format!( + "Comparing dictionary arrays of value type {} is not yet implemented", + t1 + ))), + (t1, t2) => Err(ArrowError::CastError(format!( + "Cannot compare two dictionary arrays of different value types ({} and {})", + t1, t2 + ))), + } + }}; +} + +macro_rules! typed_dict_compares { + // Applies `LEFT OP RIGHT` when `LEFT` and `RIGHT` both are `DictionaryArray` + ($LEFT: expr, $RIGHT: expr, $OP_PRIM: ident) => {{ + match ($LEFT.data_type(), $RIGHT.data_type()) { + (DataType::Dictionary(left_key_type, _), DataType::Dictionary(right_key_type, _))=> { + match (left_key_type.as_ref(), right_key_type.as_ref()) { + (DataType::Int8, DataType::Int8) => { + let left = as_dictionary_array::($LEFT); + let right = as_dictionary_array::($RIGHT); + typed_dict_cmp!(left, right, $OP_PRIM, Int8Type) + } + (t1, t2) if t1 == t2 => Err(ArrowError::NotYetImplemented(format!( + "Comparing dictionary arrays of type {} is not yet implemented", + t1 + ))), + (t1, t2) => Err(ArrowError::CastError(format!( + "Cannot compare two dictionary arrays of different key types ({} and {})", + t1, t2 + ))), + } + } + (t1, t2) => Err(ArrowError::CastError(format!( + "Cannot compare dictionary array with non-dictionary array ({} and {})", + t1, t2 + ))), + } + }}; +} + +/// Perform `left == right` operation on two `DictionaryArray`s. +/// Only when two arrays are of the same type the comparison will happen otherwise it will err +/// with a casting error. +pub fn eq_dict( + left: &DictionaryArray, + right: &DictionaryArray, +) -> Result +where + K: ArrowNumericType, + T: ArrowNumericType, +{ + assert_eq!(left.keys().len(), right.keys().len()); + + let left_values = left + .values() + .as_any() + .downcast_ref::>() + .unwrap(); + let right_values = right + .values() + .as_any() + .downcast_ref::>() + .unwrap(); + + let mut result = vec![]; + + for idx in 0..left.keys().len() { + unsafe { + let left_key = left + .keys() + .value_unchecked(idx) + .to_usize() + .expect("Dictionary index not usize"); + let left_value = left_values.value_unchecked(left_key); + let right_key = right + .keys() + .value_unchecked(idx) + .to_usize() + .expect("Dictionary index not usize"); + let right_value = right_values.value_unchecked(right_key); + + result.push(left_value == right_value); + } + } + + Ok(BooleanArray::from(result)) +} + /// Perform `left == right` operation on two (dynamic) [`Array`]s. /// /// Only when two arrays are of the same type the comparison will happen otherwise it will err /// with a casting error. pub fn eq_dyn(left: &dyn Array, right: &dyn Array) -> Result { - typed_compares!(left, right, eq_bool, eq, eq_utf8, eq_binary) + match left.data_type() { + DataType::Dictionary(_, _) => { + typed_dict_compares!(left, right, eq_dict) + } + _ => typed_compares!(left, right, eq_bool, eq, eq_utf8, eq_binary), + } } /// Perform `left != right` operation on two (dynamic) [`Array`]s. @@ -2359,6 +2459,7 @@ mod tests { use super::*; use crate::datatypes::Int8Type; + use crate::datatypes::ToByteSlice; use crate::{array::Int32Array, array::Int64Array, datatypes::Field}; /// Evaluate `KERNEL` with two vectors as inputs and assert against the expected output. @@ -4296,4 +4397,43 @@ mod tests { BooleanArray::from(vec![Some(true), Some(false), Some(true)]) ); } + + fn get_dictionary_array(keys: Buffer) -> DictionaryArray { + // Construct a value array + let value_data = ArrayData::builder(DataType::Int8) + .len(8) + .add_buffer(Buffer::from( + &[10_i8, 11, 12, 13, 14, 15, 16, 17].to_byte_slice(), + )) + .build() + .unwrap(); + + // Construct a dictionary array from the above two + let key_type = DataType::Int8; + let value_type = DataType::Int8; + let dict_data_type = + DataType::Dictionary(Box::new(key_type), Box::new(value_type)); + let dict_data = ArrayData::builder(dict_data_type.clone()) + .len(3) + .add_buffer(keys.clone()) + .add_child_data(value_data.clone()) + .build() + .unwrap(); + Int8DictionaryArray::from(dict_data) + } + + #[test] + fn test_eq_dyn_dictionary_array() { + let keys1 = Buffer::from(&[2_i8, 3, 4].to_byte_slice()); + let keys2 = Buffer::from(&[2_i8, 4, 4].to_byte_slice()); + let dict_array1 = get_dictionary_array(keys1); + let dict_array2 = get_dictionary_array(keys2); + + let result = eq_dyn(&dict_array1, &dict_array2); + assert!(result.is_ok()); + assert_eq!( + result.unwrap(), + BooleanArray::from(vec![Some(true), Some(false), Some(true)]) + ); + } } From eff3e100b301ba8d82e2485137061a97cce6f0ac Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 8 Feb 2022 15:35:00 -0800 Subject: [PATCH 02/10] For review comment: make eq_dict as generic and rename to cmp_dict. Remove unsafeness. --- arrow/src/compute/kernels/comparison.rs | 67 ++++++++++++++----------- 1 file changed, 39 insertions(+), 28 deletions(-) diff --git a/arrow/src/compute/kernels/comparison.rs b/arrow/src/compute/kernels/comparison.rs index ed1a74d8fa65..52562cbca5f8 100644 --- a/arrow/src/compute/kernels/comparison.rs +++ b/arrow/src/compute/kernels/comparison.rs @@ -2031,10 +2031,10 @@ macro_rules! typed_compares { } macro_rules! typed_dict_cmp { - ($LEFT: expr, $RIGHT: expr, $OP_PRIM: ident, $KT: tt) => {{ + ($LEFT: expr, $RIGHT: expr, $OP_PRIM: expr, $KT: tt) => {{ match ($LEFT.value_type(), $RIGHT.value_type()) { (DataType::Int8, DataType::Int8) => { - $OP_PRIM::<$KT, Int8Type>($LEFT, $RIGHT) + cmp_dict::<$KT, Int8Type, _>($LEFT, $RIGHT, $OP_PRIM) } (t1, t2) if t1 == t2 => Err(ArrowError::NotYetImplemented(format!( "Comparing dictionary arrays of value type {} is not yet implemented", @@ -2050,7 +2050,7 @@ macro_rules! typed_dict_cmp { macro_rules! typed_dict_compares { // Applies `LEFT OP RIGHT` when `LEFT` and `RIGHT` both are `DictionaryArray` - ($LEFT: expr, $RIGHT: expr, $OP_PRIM: ident) => {{ + ($LEFT: expr, $RIGHT: expr, $OP_PRIM: expr) => {{ match ($LEFT.data_type(), $RIGHT.data_type()) { (DataType::Dictionary(left_key_type, _), DataType::Dictionary(right_key_type, _))=> { match (left_key_type.as_ref(), right_key_type.as_ref()) { @@ -2077,18 +2077,25 @@ macro_rules! typed_dict_compares { }}; } -/// Perform `left == right` operation on two `DictionaryArray`s. +/// Perform given operation on two `DictionaryArray`s. /// Only when two arrays are of the same type the comparison will happen otherwise it will err /// with a casting error. -pub fn eq_dict( +pub fn cmp_dict( left: &DictionaryArray, right: &DictionaryArray, + op: F, ) -> Result where K: ArrowNumericType, T: ArrowNumericType, + F: Fn(T::Native, T::Native) -> bool, { - assert_eq!(left.keys().len(), right.keys().len()); + if left.len() != right.len() { + return Err(ArrowError::ComputeError( + "Cannot perform comparison operation on arrays of different length" + .to_string(), + )); + } let left_values = left .values() @@ -2101,28 +2108,32 @@ where .downcast_ref::>() .unwrap(); - let mut result = vec![]; - - for idx in 0..left.keys().len() { - unsafe { - let left_key = left - .keys() - .value_unchecked(idx) - .to_usize() - .expect("Dictionary index not usize"); - let left_value = left_values.value_unchecked(left_key); - let right_key = right - .keys() - .value_unchecked(idx) - .to_usize() - .expect("Dictionary index not usize"); - let right_value = right_values.value_unchecked(right_key); - - result.push(left_value == right_value); - } - } + let result = left + .keys() + .iter() + .zip(right.keys().iter()) + .map(|(left_key, right_key)| { + if left_key.is_none() || right_key.is_none() { + None + } else { + let left_key = left_key + .unwrap() + .to_usize() + .expect("Dictionary index not usize"); + let right_key = right_key + .unwrap() + .to_usize() + .expect("Dictionary index not usize"); + unsafe { + let left_value = left_values.value_unchecked(left_key); + let right_value = right_values.value_unchecked(right_key); + Some(op(left_value, right_value)) + } + } + }) + .collect(); - Ok(BooleanArray::from(result)) + Ok(result) } /// Perform `left == right` operation on two (dynamic) [`Array`]s. @@ -2132,7 +2143,7 @@ where pub fn eq_dyn(left: &dyn Array, right: &dyn Array) -> Result { match left.data_type() { DataType::Dictionary(_, _) => { - typed_dict_compares!(left, right, eq_dict) + typed_dict_compares!(left, right, |a, b| a == b) } _ => typed_compares!(left, right, eq_bool, eq, eq_utf8, eq_binary), } From e8399e0dab38f6fe7b233d153b36e5b776b35f63 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 8 Feb 2022 16:10:38 -0800 Subject: [PATCH 03/10] Other integer types --- arrow/src/compute/kernels/comparison.rs | 129 ++++++++++++++++++++---- 1 file changed, 110 insertions(+), 19 deletions(-) diff --git a/arrow/src/compute/kernels/comparison.rs b/arrow/src/compute/kernels/comparison.rs index 52562cbca5f8..74be8e6ecada 100644 --- a/arrow/src/compute/kernels/comparison.rs +++ b/arrow/src/compute/kernels/comparison.rs @@ -2036,6 +2036,27 @@ macro_rules! typed_dict_cmp { (DataType::Int8, DataType::Int8) => { cmp_dict::<$KT, Int8Type, _>($LEFT, $RIGHT, $OP_PRIM) } + (DataType::Int16, DataType::Int16) => { + cmp_dict::<$KT, Int16Type, _>($LEFT, $RIGHT, $OP_PRIM) + } + (DataType::Int32, DataType::Int32) => { + cmp_dict::<$KT, Int32Type, _>($LEFT, $RIGHT, $OP_PRIM) + } + (DataType::Int64, DataType::Int64) => { + cmp_dict::<$KT, Int64Type, _>($LEFT, $RIGHT, $OP_PRIM) + } + (DataType::UInt8, DataType::UInt8) => { + cmp_dict::<$KT, UInt8Type, _>($LEFT, $RIGHT, $OP_PRIM) + } + (DataType::UInt16, DataType::UInt16) => { + cmp_dict::<$KT, UInt16Type, _>($LEFT, $RIGHT, $OP_PRIM) + } + (DataType::UInt32, DataType::UInt32) => { + cmp_dict::<$KT, UInt32Type, _>($LEFT, $RIGHT, $OP_PRIM) + } + (DataType::UInt64, DataType::UInt64) => { + cmp_dict::<$KT, UInt64Type, _>($LEFT, $RIGHT, $OP_PRIM) + } (t1, t2) if t1 == t2 => Err(ArrowError::NotYetImplemented(format!( "Comparing dictionary arrays of value type {} is not yet implemented", t1 @@ -2059,6 +2080,41 @@ macro_rules! typed_dict_compares { let right = as_dictionary_array::($RIGHT); typed_dict_cmp!(left, right, $OP_PRIM, Int8Type) } + (DataType::Int16, DataType::Int16) => { + let left = as_dictionary_array::($LEFT); + let right = as_dictionary_array::($RIGHT); + typed_dict_cmp!(left, right, $OP_PRIM, Int16Type) + } + (DataType::Int32, DataType::Int32) => { + let left = as_dictionary_array::($LEFT); + let right = as_dictionary_array::($RIGHT); + typed_dict_cmp!(left, right, $OP_PRIM, Int32Type) + } + (DataType::Int64, DataType::Int64) => { + let left = as_dictionary_array::($LEFT); + let right = as_dictionary_array::($RIGHT); + typed_dict_cmp!(left, right, $OP_PRIM, Int64Type) + } + (DataType::UInt8, DataType::UInt8) => { + let left = as_dictionary_array::($LEFT); + let right = as_dictionary_array::($RIGHT); + typed_dict_cmp!(left, right, $OP_PRIM, UInt8Type) + } + (DataType::UInt16, DataType::UInt16) => { + let left = as_dictionary_array::($LEFT); + let right = as_dictionary_array::($RIGHT); + typed_dict_cmp!(left, right, $OP_PRIM, UInt16Type) + } + (DataType::UInt32, DataType::UInt32) => { + let left = as_dictionary_array::($LEFT); + let right = as_dictionary_array::($RIGHT); + typed_dict_cmp!(left, right, $OP_PRIM, UInt32Type) + } + (DataType::UInt64, DataType::UInt64) => { + let left = as_dictionary_array::($LEFT); + let right = as_dictionary_array::($RIGHT); + typed_dict_cmp!(left, right, $OP_PRIM, UInt64Type) + } (t1, t2) if t1 == t2 => Err(ArrowError::NotYetImplemented(format!( "Comparing dictionary arrays of type {} is not yet implemented", t1 @@ -4409,36 +4465,42 @@ mod tests { ); } - fn get_dictionary_array(keys: Buffer) -> DictionaryArray { - // Construct a value array - let value_data = ArrayData::builder(DataType::Int8) - .len(8) - .add_buffer(Buffer::from( - &[10_i8, 11, 12, 13, 14, 15, 16, 17].to_byte_slice(), - )) - .build() - .unwrap(); - - // Construct a dictionary array from the above two - let key_type = DataType::Int8; - let value_type = DataType::Int8; + fn get_dict_arraydata( + keys: Buffer, + key_type: DataType, + value_data: ArrayData, + ) -> ArrayData { + let value_type = value_data.data_type().clone(); let dict_data_type = DataType::Dictionary(Box::new(key_type), Box::new(value_type)); - let dict_data = ArrayData::builder(dict_data_type.clone()) + ArrayData::builder(dict_data_type.clone()) .len(3) .add_buffer(keys.clone()) .add_child_data(value_data.clone()) .build() - .unwrap(); - Int8DictionaryArray::from(dict_data) + .unwrap() } #[test] - fn test_eq_dyn_dictionary_array() { + fn test_eq_dyn_dictionary_i8_array() { + let key_type = DataType::Int8; + // Construct a value array + let value_data = ArrayData::builder(DataType::Int8) + .len(8) + .add_buffer(Buffer::from( + &[10_i8, 11, 12, 13, 14, 15, 16, 17].to_byte_slice(), + )) + .build() + .unwrap(); + let keys1 = Buffer::from(&[2_i8, 3, 4].to_byte_slice()); let keys2 = Buffer::from(&[2_i8, 4, 4].to_byte_slice()); - let dict_array1 = get_dictionary_array(keys1); - let dict_array2 = get_dictionary_array(keys2); + let dict_array1: DictionaryArray = Int8DictionaryArray::from( + get_dict_arraydata(keys1, key_type.clone(), value_data.clone()), + ); + let dict_array2: DictionaryArray = Int8DictionaryArray::from( + get_dict_arraydata(keys2, key_type.clone(), value_data.clone()), + ); let result = eq_dyn(&dict_array1, &dict_array2); assert!(result.is_ok()); @@ -4447,4 +4509,33 @@ mod tests { BooleanArray::from(vec![Some(true), Some(false), Some(true)]) ); } + + #[test] + fn test_eq_dyn_dictionary_u64_array() { + let key_type = DataType::UInt64; + // Construct a value array + let value_data = ArrayData::builder(DataType::UInt64) + .len(8) + .add_buffer(Buffer::from( + &[10_u64, 11, 12, 13, 14, 15, 16, 17].to_byte_slice(), + )) + .build() + .unwrap(); + + let keys1 = Buffer::from(&[1_u64, 3, 4].to_byte_slice()); + let keys2 = Buffer::from(&[2_u64, 3, 5].to_byte_slice()); + let dict_array1: DictionaryArray = UInt64DictionaryArray::from( + get_dict_arraydata(keys1, key_type.clone(), value_data.clone()), + ); + let dict_array2: DictionaryArray = UInt64DictionaryArray::from( + get_dict_arraydata(keys2, key_type.clone(), value_data.clone()), + ); + + let result = eq_dyn(&dict_array1, &dict_array2); + assert!(result.is_ok()); + assert_eq!( + result.unwrap(), + BooleanArray::from(vec![Some(false), Some(true), Some(false)]) + ); + } } From 11326f6b65ebb304a929500b91a4f9252047b84a Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 8 Feb 2022 17:14:27 -0800 Subject: [PATCH 04/10] Fix clippy error --- arrow/src/compute/kernels/comparison.rs | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/arrow/src/compute/kernels/comparison.rs b/arrow/src/compute/kernels/comparison.rs index 74be8e6ecada..a9f26d8747aa 100644 --- a/arrow/src/compute/kernels/comparison.rs +++ b/arrow/src/compute/kernels/comparison.rs @@ -2169,15 +2169,11 @@ where .iter() .zip(right.keys().iter()) .map(|(left_key, right_key)| { - if left_key.is_none() || right_key.is_none() { - None - } else { - let left_key = left_key - .unwrap() + if let (Some(left_k), Some(right_k)) = (left_key, right_key) { + let left_key = left_k .to_usize() .expect("Dictionary index not usize"); - let right_key = right_key - .unwrap() + let right_key = right_k .to_usize() .expect("Dictionary index not usize"); unsafe { @@ -2185,7 +2181,7 @@ where let right_value = right_values.value_unchecked(right_key); Some(op(left_value, right_value)) } - } + } else { None } }) .collect(); From 9029d46d66e204e768be6031ff3ed4aa8ee9de8b Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 8 Feb 2022 17:20:21 -0800 Subject: [PATCH 05/10] Fix format --- arrow/src/compute/kernels/comparison.rs | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/arrow/src/compute/kernels/comparison.rs b/arrow/src/compute/kernels/comparison.rs index a9f26d8747aa..5af7fece55e3 100644 --- a/arrow/src/compute/kernels/comparison.rs +++ b/arrow/src/compute/kernels/comparison.rs @@ -2170,18 +2170,16 @@ where .zip(right.keys().iter()) .map(|(left_key, right_key)| { if let (Some(left_k), Some(right_k)) = (left_key, right_key) { - let left_key = left_k - .to_usize() - .expect("Dictionary index not usize"); - let right_key = right_k - .to_usize() - .expect("Dictionary index not usize"); + let left_key = left_k.to_usize().expect("Dictionary index not usize"); + let right_key = right_k.to_usize().expect("Dictionary index not usize"); unsafe { let left_value = left_values.value_unchecked(left_key); let right_value = right_values.value_unchecked(right_key); Some(op(left_value, right_value)) } - } else { None } + } else { + None + } }) .collect(); From debbabaff81c80ba5ce0c2db5e64d87c542bb3dc Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 8 Feb 2022 18:01:29 -0800 Subject: [PATCH 06/10] Fix clippy and format --- arrow/src/compute/kernels/comparison.rs | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/arrow/src/compute/kernels/comparison.rs b/arrow/src/compute/kernels/comparison.rs index a9298c41b80e..737aa125e72e 100644 --- a/arrow/src/compute/kernels/comparison.rs +++ b/arrow/src/compute/kernels/comparison.rs @@ -4545,10 +4545,10 @@ mod tests { let value_type = value_data.data_type().clone(); let dict_data_type = DataType::Dictionary(Box::new(key_type), Box::new(value_type)); - ArrayData::builder(dict_data_type.clone()) + ArrayData::builder(dict_data_type) .len(3) - .add_buffer(keys.clone()) - .add_child_data(value_data.clone()) + .add_buffer(keys) + .add_child_data(value_data) .build() .unwrap() } @@ -4570,9 +4570,8 @@ mod tests { let dict_array1: DictionaryArray = Int8DictionaryArray::from( get_dict_arraydata(keys1, key_type.clone(), value_data.clone()), ); - let dict_array2: DictionaryArray = Int8DictionaryArray::from( - get_dict_arraydata(keys2, key_type.clone(), value_data.clone()), - ); + let dict_array2: DictionaryArray = + Int8DictionaryArray::from(get_dict_arraydata(keys2, key_type, value_data)); let result = eq_dyn(&dict_array1, &dict_array2); assert!(result.is_ok()); @@ -4599,9 +4598,8 @@ mod tests { let dict_array1: DictionaryArray = UInt64DictionaryArray::from( get_dict_arraydata(keys1, key_type.clone(), value_data.clone()), ); - let dict_array2: DictionaryArray = UInt64DictionaryArray::from( - get_dict_arraydata(keys2, key_type.clone(), value_data.clone()), - ); + let dict_array2: DictionaryArray = + UInt64DictionaryArray::from(get_dict_arraydata(keys2, key_type, value_data)); let result = eq_dyn(&dict_array1, &dict_array2); assert!(result.is_ok()); From fefd9ba126d74c945c22046bdca2b9f8bb7acfd2 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 8 Feb 2022 19:58:20 -0800 Subject: [PATCH 07/10] Add cmp_dict_utf8 and cmp_dict_binary to cover the utf8/binary value array cases --- arrow/src/compute/kernels/comparison.rs | 174 ++++++++++++++++-------- 1 file changed, 120 insertions(+), 54 deletions(-) diff --git a/arrow/src/compute/kernels/comparison.rs b/arrow/src/compute/kernels/comparison.rs index 737aa125e72e..ba7fe09c405c 100644 --- a/arrow/src/compute/kernels/comparison.rs +++ b/arrow/src/compute/kernels/comparison.rs @@ -2031,31 +2031,43 @@ macro_rules! typed_compares { } macro_rules! typed_dict_cmp { - ($LEFT: expr, $RIGHT: expr, $OP_PRIM: expr, $KT: tt) => {{ + ($LEFT: expr, $RIGHT: expr, $OP: expr, $KT: tt) => {{ match ($LEFT.value_type(), $RIGHT.value_type()) { (DataType::Int8, DataType::Int8) => { - cmp_dict::<$KT, Int8Type, _>($LEFT, $RIGHT, $OP_PRIM) + cmp_dict::<$KT, Int8Type, _>($LEFT, $RIGHT, $OP) } (DataType::Int16, DataType::Int16) => { - cmp_dict::<$KT, Int16Type, _>($LEFT, $RIGHT, $OP_PRIM) + cmp_dict::<$KT, Int16Type, _>($LEFT, $RIGHT, $OP) } (DataType::Int32, DataType::Int32) => { - cmp_dict::<$KT, Int32Type, _>($LEFT, $RIGHT, $OP_PRIM) + cmp_dict::<$KT, Int32Type, _>($LEFT, $RIGHT, $OP) } (DataType::Int64, DataType::Int64) => { - cmp_dict::<$KT, Int64Type, _>($LEFT, $RIGHT, $OP_PRIM) + cmp_dict::<$KT, Int64Type, _>($LEFT, $RIGHT, $OP) } (DataType::UInt8, DataType::UInt8) => { - cmp_dict::<$KT, UInt8Type, _>($LEFT, $RIGHT, $OP_PRIM) + cmp_dict::<$KT, UInt8Type, _>($LEFT, $RIGHT, $OP) } (DataType::UInt16, DataType::UInt16) => { - cmp_dict::<$KT, UInt16Type, _>($LEFT, $RIGHT, $OP_PRIM) + cmp_dict::<$KT, UInt16Type, _>($LEFT, $RIGHT, $OP) } (DataType::UInt32, DataType::UInt32) => { - cmp_dict::<$KT, UInt32Type, _>($LEFT, $RIGHT, $OP_PRIM) + cmp_dict::<$KT, UInt32Type, _>($LEFT, $RIGHT, $OP) } (DataType::UInt64, DataType::UInt64) => { - cmp_dict::<$KT, UInt64Type, _>($LEFT, $RIGHT, $OP_PRIM) + cmp_dict::<$KT, UInt64Type, _>($LEFT, $RIGHT, $OP) + } + (DataType::Utf8, DataType::Utf8) => { + cmp_dict_utf8::<$KT, i32, _>($LEFT, $RIGHT, $OP) + } + (DataType::LargeUtf8, DataType::LargeUtf8) => { + cmp_dict_utf8::<$KT, i64, _>($LEFT, $RIGHT, $OP) + } + (DataType::Binary, DataType::Binary) => { + cmp_dict_binary::<$KT, i32, _>($LEFT, $RIGHT, $OP) + } + (DataType::LargeBinary, DataType::LargeBinary) => { + cmp_dict_binary::<$KT, i64, _>($LEFT, $RIGHT, $OP) } (t1, t2) if t1 == t2 => Err(ArrowError::NotYetImplemented(format!( "Comparing dictionary arrays of value type {} is not yet implemented", @@ -2071,49 +2083,49 @@ macro_rules! typed_dict_cmp { macro_rules! typed_dict_compares { // Applies `LEFT OP RIGHT` when `LEFT` and `RIGHT` both are `DictionaryArray` - ($LEFT: expr, $RIGHT: expr, $OP_PRIM: expr) => {{ + ($LEFT: expr, $RIGHT: expr, $OP: expr) => {{ match ($LEFT.data_type(), $RIGHT.data_type()) { (DataType::Dictionary(left_key_type, _), DataType::Dictionary(right_key_type, _))=> { match (left_key_type.as_ref(), right_key_type.as_ref()) { (DataType::Int8, DataType::Int8) => { let left = as_dictionary_array::($LEFT); let right = as_dictionary_array::($RIGHT); - typed_dict_cmp!(left, right, $OP_PRIM, Int8Type) + typed_dict_cmp!(left, right, $OP, Int8Type) } (DataType::Int16, DataType::Int16) => { let left = as_dictionary_array::($LEFT); let right = as_dictionary_array::($RIGHT); - typed_dict_cmp!(left, right, $OP_PRIM, Int16Type) + typed_dict_cmp!(left, right, $OP, Int16Type) } (DataType::Int32, DataType::Int32) => { let left = as_dictionary_array::($LEFT); let right = as_dictionary_array::($RIGHT); - typed_dict_cmp!(left, right, $OP_PRIM, Int32Type) + typed_dict_cmp!(left, right, $OP, Int32Type) } (DataType::Int64, DataType::Int64) => { let left = as_dictionary_array::($LEFT); let right = as_dictionary_array::($RIGHT); - typed_dict_cmp!(left, right, $OP_PRIM, Int64Type) + typed_dict_cmp!(left, right, $OP, Int64Type) } (DataType::UInt8, DataType::UInt8) => { let left = as_dictionary_array::($LEFT); let right = as_dictionary_array::($RIGHT); - typed_dict_cmp!(left, right, $OP_PRIM, UInt8Type) + typed_dict_cmp!(left, right, $OP, UInt8Type) } (DataType::UInt16, DataType::UInt16) => { let left = as_dictionary_array::($LEFT); let right = as_dictionary_array::($RIGHT); - typed_dict_cmp!(left, right, $OP_PRIM, UInt16Type) + typed_dict_cmp!(left, right, $OP, UInt16Type) } (DataType::UInt32, DataType::UInt32) => { let left = as_dictionary_array::($LEFT); let right = as_dictionary_array::($RIGHT); - typed_dict_cmp!(left, right, $OP_PRIM, UInt32Type) + typed_dict_cmp!(left, right, $OP, UInt32Type) } (DataType::UInt64, DataType::UInt64) => { let left = as_dictionary_array::($LEFT); let right = as_dictionary_array::($RIGHT); - typed_dict_cmp!(left, right, $OP_PRIM, UInt64Type) + typed_dict_cmp!(left, right, $OP, UInt64Type) } (t1, t2) if t1 == t2 => Err(ArrowError::NotYetImplemented(format!( "Comparing dictionary arrays of type {} is not yet implemented", @@ -2133,6 +2145,47 @@ macro_rules! typed_dict_compares { }}; } +/// Helper function to perform boolean lambda function on values from two dictionary arrays, this +/// version does not attempt to use SIMD. +macro_rules! compare_dict_op { + ($left: expr, $right:expr, $op:expr, $value_ty:ty) => {{ + if $left.len() != $right.len() { + return Err(ArrowError::ComputeError( + "Cannot perform comparison operation on arrays of different length" + .to_string(), + )); + } + let left_values = $left.values().as_any().downcast_ref::<$value_ty>().unwrap(); + let right_values = $right + .values() + .as_any() + .downcast_ref::<$value_ty>() + .unwrap(); + + let result = $left + .keys() + .iter() + .zip($right.keys().iter()) + .map(|(left_key, right_key)| { + if let (Some(left_k), Some(right_k)) = (left_key, right_key) { + let left_key = left_k.to_usize().expect("Dictionary index not usize"); + let right_key = + right_k.to_usize().expect("Dictionary index not usize"); + unsafe { + let left_value = left_values.value_unchecked(left_key); + let right_value = right_values.value_unchecked(right_key); + Some($op(left_value, right_value)) + } + } else { + None + } + }) + .collect(); + + Ok(result) + }}; +} + /// Perform given operation on two `DictionaryArray`s. /// Only when two arrays are of the same type the comparison will happen otherwise it will err /// with a casting error. @@ -2146,44 +2199,35 @@ where T: ArrowNumericType, F: Fn(T::Native, T::Native) -> bool, { - if left.len() != right.len() { - return Err(ArrowError::ComputeError( - "Cannot perform comparison operation on arrays of different length" - .to_string(), - )); - } + compare_dict_op!(left, right, op, PrimitiveArray) +} - let left_values = left - .values() - .as_any() - .downcast_ref::>() - .unwrap(); - let right_values = right - .values() - .as_any() - .downcast_ref::>() - .unwrap(); - - let result = left - .keys() - .iter() - .zip(right.keys().iter()) - .map(|(left_key, right_key)| { - if let (Some(left_k), Some(right_k)) = (left_key, right_key) { - let left_key = left_k.to_usize().expect("Dictionary index not usize"); - let right_key = right_k.to_usize().expect("Dictionary index not usize"); - unsafe { - let left_value = left_values.value_unchecked(left_key); - let right_value = right_values.value_unchecked(right_key); - Some(op(left_value, right_value)) - } - } else { - None - } - }) - .collect(); +/// Perform the given operation on two `DictionaryArray`s which value type is +/// `DataType::Utf8` or `DataType::LargeUtf8`. +pub fn cmp_dict_utf8( + left: &DictionaryArray, + right: &DictionaryArray, + op: F, +) -> Result +where + K: ArrowNumericType, + F: Fn(&str, &str) -> bool, +{ + compare_dict_op!(left, right, op, GenericStringArray) +} - Ok(result) +/// Perform the given operation on two `DictionaryArray`s which value type is +/// `DataType::Binary` or `DataType::LargeBinary`. +pub fn cmp_dict_binary( + left: &DictionaryArray, + right: &DictionaryArray, + op: F, +) -> Result +where + K: ArrowNumericType, + F: Fn(&[u8], &[u8]) -> bool, +{ + compare_dict_op!(left, right, op, GenericBinaryArray) } /// Perform `left == right` operation on two (dynamic) [`Array`]s. @@ -4608,4 +4652,26 @@ mod tests { BooleanArray::from(vec![Some(false), Some(true), Some(false)]) ); } + + #[test] + fn test_eq_dyn_dictionary_utf8_array() { + let test1 = vec!["a", "a", "b", "c"]; + let test2 = vec!["a", "b", "b", "c"]; + + let dict_array1: DictionaryArray = test1 + .iter() + .map(|&x| if x == "b" { None } else { Some(x) }) + .collect(); + let dict_array2: DictionaryArray = test2 + .iter() + .map(|&x| if x == "b" { None } else { Some(x) }) + .collect(); + + let result = eq_dyn(&dict_array1, &dict_array2); + assert!(result.is_ok()); + assert_eq!( + result.unwrap(), + BooleanArray::from(vec![Some(true), None, None, Some(true)]) + ); + } } From d769fd9c02cf0eeaf45d401440eec3c545eca56f Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 8 Feb 2022 20:48:07 -0800 Subject: [PATCH 08/10] Add binary test --- arrow/src/compute/kernels/comparison.rs | 34 +++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/arrow/src/compute/kernels/comparison.rs b/arrow/src/compute/kernels/comparison.rs index ba7fe09c405c..acf98c002d1b 100644 --- a/arrow/src/compute/kernels/comparison.rs +++ b/arrow/src/compute/kernels/comparison.rs @@ -4674,4 +4674,38 @@ mod tests { BooleanArray::from(vec![Some(true), None, None, Some(true)]) ); } + + #[test] + fn test_eq_dyn_dictionary_binary_array() { + let key_type = DataType::UInt64; + + // Construct a value array + let values: [u8; 12] = [ + b'h', b'e', b'l', b'l', b'o', b'p', b'a', b'r', b'q', b'u', b'e', b't', + ]; + let offsets: [i32; 4] = [0, 5, 5, 12]; + + // Array data: ["hello", "", "parquet"] + let value_data = ArrayData::builder(DataType::Binary) + .len(3) + .add_buffer(Buffer::from_slice_ref(&offsets)) + .add_buffer(Buffer::from_slice_ref(&values)) + .build() + .unwrap(); + + let keys1 = Buffer::from(&[0_u64, 1, 2].to_byte_slice()); + let keys2 = Buffer::from(&[0_u64, 2, 1].to_byte_slice()); + let dict_array1: DictionaryArray = UInt64DictionaryArray::from( + get_dict_arraydata(keys1, key_type.clone(), value_data.clone()), + ); + let dict_array2: DictionaryArray = + UInt64DictionaryArray::from(get_dict_arraydata(keys2, key_type, value_data)); + + let result = eq_dyn(&dict_array1, &dict_array2); + assert!(result.is_ok()); + assert_eq!( + result.unwrap(), + BooleanArray::from(vec![Some(true), Some(false), Some(false)]) + ); + } } From 65f286bb3770935d37b5dbc60cc2b65ba0b97c3f Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 9 Feb 2022 00:59:13 -0800 Subject: [PATCH 09/10] Add remaining types --- arrow/src/compute/kernels/comparison.rs | 137 ++++++++++++++++++++++-- 1 file changed, 131 insertions(+), 6 deletions(-) diff --git a/arrow/src/compute/kernels/comparison.rs b/arrow/src/compute/kernels/comparison.rs index acf98c002d1b..c7fdb4831b01 100644 --- a/arrow/src/compute/kernels/comparison.rs +++ b/arrow/src/compute/kernels/comparison.rs @@ -2033,6 +2033,9 @@ macro_rules! typed_compares { macro_rules! typed_dict_cmp { ($LEFT: expr, $RIGHT: expr, $OP: expr, $KT: tt) => {{ match ($LEFT.value_type(), $RIGHT.value_type()) { + (DataType::Boolean, DataType::Boolean) => { + cmp_dict_bool::<$KT, _>($LEFT, $RIGHT, $OP) + } (DataType::Int8, DataType::Int8) => { cmp_dict::<$KT, Int8Type, _>($LEFT, $RIGHT, $OP) } @@ -2069,6 +2072,54 @@ macro_rules! typed_dict_cmp { (DataType::LargeBinary, DataType::LargeBinary) => { cmp_dict_binary::<$KT, i64, _>($LEFT, $RIGHT, $OP) } + ( + DataType::Timestamp(TimeUnit::Nanosecond, _), + DataType::Timestamp(TimeUnit::Nanosecond, _), + ) => { + cmp_dict::<$KT, TimestampNanosecondType, _>($LEFT, $RIGHT, $OP) + } + ( + DataType::Timestamp(TimeUnit::Microsecond, _), + DataType::Timestamp(TimeUnit::Microsecond, _), + ) => { + cmp_dict::<$KT, TimestampMicrosecondType, _>($LEFT, $RIGHT, $OP) + } + ( + DataType::Timestamp(TimeUnit::Millisecond, _), + DataType::Timestamp(TimeUnit::Millisecond, _), + ) => { + cmp_dict::<$KT, TimestampMillisecondType, _>($LEFT, $RIGHT, $OP) + } + ( + DataType::Timestamp(TimeUnit::Second, _), + DataType::Timestamp(TimeUnit::Second, _), + ) => { + cmp_dict::<$KT, TimestampSecondType, _>($LEFT, $RIGHT, $OP) + } + (DataType::Date32, DataType::Date32) => { + cmp_dict::<$KT, Date32Type, _>($LEFT, $RIGHT, $OP) + } + (DataType::Date64, DataType::Date64) => { + cmp_dict::<$KT, Date64Type, _>($LEFT, $RIGHT, $OP) + } + ( + DataType::Interval(IntervalUnit::YearMonth), + DataType::Interval(IntervalUnit::YearMonth), + ) => { + cmp_dict::<$KT, IntervalYearMonthType, _>($LEFT, $RIGHT, $OP) + } + ( + DataType::Interval(IntervalUnit::DayTime), + DataType::Interval(IntervalUnit::DayTime), + ) => { + cmp_dict::<$KT, IntervalDayTimeType, _>($LEFT, $RIGHT, $OP) + } + ( + DataType::Interval(IntervalUnit::MonthDayNano), + DataType::Interval(IntervalUnit::MonthDayNano), + ) => { + cmp_dict::<$KT, IntervalMonthDayNanoType, _>($LEFT, $RIGHT, $OP) + } (t1, t2) if t1 == t2 => Err(ArrowError::NotYetImplemented(format!( "Comparing dictionary arrays of value type {} is not yet implemented", t1 @@ -2202,6 +2253,20 @@ where compare_dict_op!(left, right, op, PrimitiveArray) } +/// Perform the given operation on two `DictionaryArray`s which value type is +/// `DataType::Boolean`. +pub fn cmp_dict_bool( + left: &DictionaryArray, + right: &DictionaryArray, + op: F, +) -> Result +where + K: ArrowNumericType, + F: Fn(bool, bool) -> bool, +{ + compare_dict_op!(left, right, op, BooleanArray) +} + /// Perform the given operation on two `DictionaryArray`s which value type is /// `DataType::Utf8` or `DataType::LargeUtf8`. pub fn cmp_dict_utf8( @@ -4619,10 +4684,7 @@ mod tests { let result = eq_dyn(&dict_array1, &dict_array2); assert!(result.is_ok()); - assert_eq!( - result.unwrap(), - BooleanArray::from(vec![Some(true), Some(false), Some(true)]) - ); + assert_eq!(result.unwrap(), BooleanArray::from(vec![true, false, true])); } #[test] @@ -4649,7 +4711,7 @@ mod tests { assert!(result.is_ok()); assert_eq!( result.unwrap(), - BooleanArray::from(vec![Some(false), Some(true), Some(false)]) + BooleanArray::from(vec![false, true, false]) ); } @@ -4705,7 +4767,70 @@ mod tests { assert!(result.is_ok()); assert_eq!( result.unwrap(), - BooleanArray::from(vec![Some(true), Some(false), Some(false)]) + BooleanArray::from(vec![true, false, false]) + ); + } + + #[test] + fn test_eq_dyn_dictionary_interval_array() { + let key_type = DataType::UInt64; + + let value_array = IntervalDayTimeArray::from(vec![1, 6, 10, 2, 3, 5]); + let value_data = value_array.data().clone(); + + let keys1 = Buffer::from(&[1_u64, 0, 3].to_byte_slice()); + let keys2 = Buffer::from(&[2_u64, 0, 3].to_byte_slice()); + let dict_array1: DictionaryArray = UInt64DictionaryArray::from( + get_dict_arraydata(keys1, key_type.clone(), value_data.clone()), + ); + let dict_array2: DictionaryArray = + UInt64DictionaryArray::from(get_dict_arraydata(keys2, key_type, value_data)); + + let result = eq_dyn(&dict_array1, &dict_array2); + assert!(result.is_ok()); + assert_eq!(result.unwrap(), BooleanArray::from(vec![false, true, true])); + } + + #[test] + fn test_eq_dyn_dictionary_date_array() { + let key_type = DataType::UInt64; + + let value_array = Date32Array::from(vec![1, 6, 10, 2, 3, 5]); + let value_data = value_array.data().clone(); + + let keys1 = Buffer::from(&[1_u64, 0, 3].to_byte_slice()); + let keys2 = Buffer::from(&[2_u64, 0, 3].to_byte_slice()); + let dict_array1: DictionaryArray = UInt64DictionaryArray::from( + get_dict_arraydata(keys1, key_type.clone(), value_data.clone()), + ); + let dict_array2: DictionaryArray = + UInt64DictionaryArray::from(get_dict_arraydata(keys2, key_type, value_data)); + + let result = eq_dyn(&dict_array1, &dict_array2); + assert!(result.is_ok()); + assert_eq!(result.unwrap(), BooleanArray::from(vec![false, true, true])); + } + + #[test] + fn test_eq_dyn_dictionary_bool_array() { + let key_type = DataType::UInt64; + + let value_array = BooleanArray::from(vec![true, false]); + let value_data = value_array.data().clone(); + + let keys1 = Buffer::from(&[1_u64, 1, 1].to_byte_slice()); + let keys2 = Buffer::from(&[0_u64, 1, 0].to_byte_slice()); + let dict_array1: DictionaryArray = UInt64DictionaryArray::from( + get_dict_arraydata(keys1, key_type.clone(), value_data.clone()), + ); + let dict_array2: DictionaryArray = + UInt64DictionaryArray::from(get_dict_arraydata(keys2, key_type, value_data)); + + let result = eq_dyn(&dict_array1, &dict_array2); + assert!(result.is_ok()); + assert_eq!( + result.unwrap(), + BooleanArray::from(vec![false, true, false]) ); } } From d5ffe59f2234a4c964591f7c8136f612abbf4cde Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 10 Feb 2022 09:21:58 -0800 Subject: [PATCH 10/10] Add Float32 and Float64 and update a few comments. --- arrow/src/compute/kernels/comparison.rs | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/arrow/src/compute/kernels/comparison.rs b/arrow/src/compute/kernels/comparison.rs index c7fdb4831b01..b5331fdc662d 100644 --- a/arrow/src/compute/kernels/comparison.rs +++ b/arrow/src/compute/kernels/comparison.rs @@ -2030,6 +2030,7 @@ macro_rules! typed_compares { }}; } +/// Applies $OP to $LEFT and $RIGHT which are two dictionaries which have (the same) key type $KT macro_rules! typed_dict_cmp { ($LEFT: expr, $RIGHT: expr, $OP: expr, $KT: tt) => {{ match ($LEFT.value_type(), $RIGHT.value_type()) { @@ -2060,6 +2061,12 @@ macro_rules! typed_dict_cmp { (DataType::UInt64, DataType::UInt64) => { cmp_dict::<$KT, UInt64Type, _>($LEFT, $RIGHT, $OP) } + (DataType::Float32, DataType::Float32) => { + cmp_dict::<$KT, Float32Type, _>($LEFT, $RIGHT, $OP) + } + (DataType::Float64, DataType::Float64) => { + cmp_dict::<$KT, Float64Type, _>($LEFT, $RIGHT, $OP) + } (DataType::Utf8, DataType::Utf8) => { cmp_dict_utf8::<$KT, i32, _>($LEFT, $RIGHT, $OP) } @@ -2197,7 +2204,7 @@ macro_rules! typed_dict_compares { } /// Helper function to perform boolean lambda function on values from two dictionary arrays, this -/// version does not attempt to use SIMD. +/// version does not attempt to use SIMD explicitly (though the compiler may auto vectorize) macro_rules! compare_dict_op { ($left: expr, $right:expr, $op:expr, $value_ty:ty) => {{ if $left.len() != $right.len() { @@ -2238,8 +2245,7 @@ macro_rules! compare_dict_op { } /// Perform given operation on two `DictionaryArray`s. -/// Only when two arrays are of the same type the comparison will happen otherwise it will err -/// with a casting error. +/// Returns an error if the two arrays have different value type pub fn cmp_dict( left: &DictionaryArray, right: &DictionaryArray,