From 08b53e56e172ebf0887678a1c4159f572c912583 Mon Sep 17 00:00:00 2001 From: Jiayu Liu Date: Sun, 24 Oct 2021 20:17:05 +0800 Subject: [PATCH] implement eq_dyn and neq_dyn --- arrow/src/compute/kernels/comparison.rs | 188 +++++++++++++++++++++--- 1 file changed, 171 insertions(+), 17 deletions(-) diff --git a/arrow/src/compute/kernels/comparison.rs b/arrow/src/compute/kernels/comparison.rs index 81827b032b58..1f0cb1a39cd2 100644 --- a/arrow/src/compute/kernels/comparison.rs +++ b/arrow/src/compute/kernels/comparison.rs @@ -22,16 +22,19 @@ //! `RUSTFLAGS="-C target-feature=+avx2"` for example. See the documentation //! [here](https://doc.rust-lang.org/stable/core/arch/) for more information. -use regex::Regex; -use std::collections::HashMap; - use crate::array::*; use crate::buffer::{bitwise_bin_op_helper, buffer_unary_not, Buffer, MutableBuffer}; use crate::compute::binary_boolean_kernel; use crate::compute::util::combine_option_bitmap; -use crate::datatypes::{ArrowNumericType, DataType}; +use crate::datatypes::{ + ArrowNumericType, DataType, Float32Type, Float64Type, Int16Type, Int32Type, + Int64Type, Int8Type, UInt16Type, UInt32Type, UInt64Type, UInt8Type, +}; use crate::error::{ArrowError, Result}; use crate::util::bit_util; +use regex::Regex; +use std::any::type_name; +use std::collections::HashMap; /// Helper function to perform boolean lambda function on values from two arrays, this /// version does not attempt to use SIMD. @@ -974,7 +977,142 @@ where Ok(BooleanArray::from(data)) } -/// Perform `left == right` operation on two arrays. +macro_rules! typed_cmp { + ($LEFT: expr, $RIGHT: expr, $T: ident, $OP: ident) => {{ + let left = $LEFT.as_any().downcast_ref::<$T>().ok_or_else(|| { + ArrowError::CastError(format!( + "Left array cannot be cast to {}", + type_name::<$T>() + )) + })?; + let right = $RIGHT.as_any().downcast_ref::<$T>().ok_or_else(|| { + ArrowError::CastError(format!( + "Right array cannot be cast to {}", + type_name::<$T>(), + )) + })?; + $OP(left, right) + }}; + ($LEFT: expr, $RIGHT: expr, $T: ident, $OP: ident, $TT: tt) => {{ + let left = $LEFT.as_any().downcast_ref::<$T>().ok_or_else(|| { + ArrowError::CastError(format!( + "Left array cannot be cast to {}", + type_name::<$T>() + )) + })?; + let right = $RIGHT.as_any().downcast_ref::<$T>().ok_or_else(|| { + ArrowError::CastError(format!( + "Right array cannot be cast to {}", + type_name::<$T>(), + )) + })?; + $OP::<$TT>(left, right) + }}; +} + +macro_rules! typed_compares { + ($LEFT: expr, $RIGHT: expr, $OP_BOOL: ident, $OP_PRIM: ident, $OP_STR: ident) => {{ + match ($LEFT.data_type(), $RIGHT.data_type()) { + (DataType::Boolean, DataType::Boolean) => { + typed_cmp!($LEFT, $RIGHT, BooleanArray, $OP_BOOL) + } + (DataType::Int8, DataType::Int8) => { + typed_cmp!($LEFT, $RIGHT, Int8Array, $OP_PRIM, Int8Type) + } + (DataType::Int16, DataType::Int16) => { + typed_cmp!($LEFT, $RIGHT, Int16Array, $OP_PRIM, Int16Type) + } + (DataType::Int32, DataType::Int32) => { + typed_cmp!($LEFT, $RIGHT, Int32Array, $OP_PRIM, Int32Type) + } + (DataType::Int64, DataType::Int64) => { + typed_cmp!($LEFT, $RIGHT, Int64Array, $OP_PRIM, Int64Type) + } + (DataType::UInt8, DataType::UInt8) => { + typed_cmp!($LEFT, $RIGHT, UInt8Array, $OP_PRIM, UInt8Type) + } + (DataType::UInt16, DataType::UInt16) => { + typed_cmp!($LEFT, $RIGHT, UInt16Array, $OP_PRIM, UInt16Type) + } + (DataType::UInt32, DataType::UInt32) => { + typed_cmp!($LEFT, $RIGHT, UInt32Array, $OP_PRIM, UInt32Type) + } + (DataType::UInt64, DataType::UInt64) => { + typed_cmp!($LEFT, $RIGHT, UInt64Array, $OP_PRIM, UInt64Type) + } + (DataType::Float32, DataType::Float32) => { + typed_cmp!($LEFT, $RIGHT, Float32Array, $OP_PRIM, Float32Type) + } + (DataType::Float64, DataType::Float64) => { + typed_cmp!($LEFT, $RIGHT, Float64Array, $OP_PRIM, Float64Type) + } + (DataType::Utf8, DataType::Utf8) => { + typed_cmp!($LEFT, $RIGHT, StringArray, $OP_STR, i32) + } + (DataType::LargeUtf8, DataType::LargeUtf8) => { + typed_cmp!($LEFT, $RIGHT, LargeStringArray, $OP_STR, i64) + } + (t1, t2) if t1 == t2 => Err(ArrowError::NotYetImplemented(format!( + "Comparing arrays of type {} is not yet implemented", + t1 + ))), + (t1, t2) => Err(ArrowError::CastError(format!( + "Cannot compare two arrays of different types ({} and {})", + t1, t2 + ))), + } + }}; +} + +/// Perform `left == right` operation on two (dynamic) [`Array`]s. +/// +/// Only when two arrays are of the same type the comparison will happen otherwise it will err +/// with a casting error. +pub fn eq_dyn(left: &dyn Array, right: &dyn Array) -> Result { + typed_compares!(left, right, eq_bool, eq, eq_utf8) +} + +/// Perform `left != right` operation on two (dynamic) [`Array`]s. +/// +/// Only when two arrays are of the same type the comparison will happen otherwise it will err +/// with a casting error. +pub fn neq_dyn(left: &dyn Array, right: &dyn Array) -> Result { + typed_compares!(left, right, neq_bool, neq, neq_utf8) +} + +/// Perform `left < right` operation on two (dynamic) [`Array`]s. +/// +/// Only when two arrays are of the same type the comparison will happen otherwise it will err +/// with a casting error. +pub fn lt_dyn(left: &dyn Array, right: &dyn Array) -> Result { + typed_compares!(left, right, lt_bool, lt, lt_utf8) +} + +/// Perform `left <= right` operation on two (dynamic) [`Array`]s. +/// +/// Only when two arrays are of the same type the comparison will happen otherwise it will err +/// with a casting error. +pub fn lt_eq_dyn(left: &dyn Array, right: &dyn Array) -> Result { + typed_compares!(left, right, lt_eq_bool, lt_eq, lt_eq_utf8) +} + +/// Perform `left > right` operation on two (dynamic) [`Array`]s. +/// +/// Only when two arrays are of the same type the comparison will happen otherwise it will err +/// with a casting error. +pub fn gt_dyn(left: &dyn Array, right: &dyn Array) -> Result { + typed_compares!(left, right, gt_bool, gt, gt_utf8) +} + +/// Perform `left >= right` operation on two (dynamic) [`Array`]s. +/// +/// Only when two arrays are of the same type the comparison will happen otherwise it will err +/// with a casting error. +pub fn gt_eq_dyn(left: &dyn Array, right: &dyn Array) -> Result { + typed_compares!(left, right, gt_eq_bool, gt_eq, gt_eq_utf8) +} + +/// Perform `left == right` operation on two [`PrimitiveArray`]s. pub fn eq(left: &PrimitiveArray, right: &PrimitiveArray) -> Result where T: ArrowNumericType, @@ -985,7 +1123,7 @@ where return compare_op!(left, right, |a, b| a == b); } -/// Perform `left == right` operation on an array and a scalar value. +/// Perform `left == right` operation on a [`PrimitiveArray`] and a scalar value. pub fn eq_scalar(left: &PrimitiveArray, right: T::Native) -> Result where T: ArrowNumericType, @@ -996,7 +1134,7 @@ where return compare_op_scalar!(left, right, |a, b| a == b); } -/// Perform `left != right` operation on two arrays. +/// Perform `left != right` operation on two [`PrimitiveArray`]s. pub fn neq(left: &PrimitiveArray, right: &PrimitiveArray) -> Result where T: ArrowNumericType, @@ -1007,7 +1145,7 @@ where return compare_op!(left, right, |a, b| a != b); } -/// Perform `left != right` operation on an array and a scalar value. +/// Perform `left != right` operation on a [`PrimitiveArray`] and a scalar value. pub fn neq_scalar(left: &PrimitiveArray, right: T::Native) -> Result where T: ArrowNumericType, @@ -1018,7 +1156,7 @@ where return compare_op_scalar!(left, right, |a, b| a != b); } -/// Perform `left < right` operation on two arrays. Null values are less than non-null +/// Perform `left < right` operation on two [`PrimitiveArray`]s. Null values are less than non-null /// values. pub fn lt(left: &PrimitiveArray, right: &PrimitiveArray) -> Result where @@ -1030,7 +1168,7 @@ where return compare_op!(left, right, |a, b| a < b); } -/// Perform `left < right` operation on an array and a scalar value. +/// Perform `left < right` operation on a [`PrimitiveArray`] and a scalar value. /// Null values are less than non-null values. pub fn lt_scalar(left: &PrimitiveArray, right: T::Native) -> Result where @@ -1042,7 +1180,7 @@ where return compare_op_scalar!(left, right, |a, b| a < b); } -/// Perform `left <= right` operation on two arrays. Null values are less than non-null +/// Perform `left <= right` operation on two [`PrimitiveArray`]s. Null values are less than non-null /// values. pub fn lt_eq( left: &PrimitiveArray, @@ -1057,7 +1195,7 @@ where return compare_op!(left, right, |a, b| a <= b); } -/// Perform `left <= right` operation on an array and a scalar value. +/// Perform `left <= right` operation on a [`PrimitiveArray`] and a scalar value. /// Null values are less than non-null values. pub fn lt_eq_scalar(left: &PrimitiveArray, right: T::Native) -> Result where @@ -1069,7 +1207,7 @@ where return compare_op_scalar!(left, right, |a, b| a <= b); } -/// Perform `left > right` operation on two arrays. Non-null values are greater than null +/// Perform `left > right` operation on two [`PrimitiveArray`]s. Non-null values are greater than null /// values. pub fn gt(left: &PrimitiveArray, right: &PrimitiveArray) -> Result where @@ -1081,7 +1219,7 @@ where return compare_op!(left, right, |a, b| a > b); } -/// Perform `left > right` operation on an array and a scalar value. +/// Perform `left > right` operation on a [`PrimitiveArray`] and a scalar value. /// Non-null values are greater than null values. pub fn gt_scalar(left: &PrimitiveArray, right: T::Native) -> Result where @@ -1093,7 +1231,7 @@ where return compare_op_scalar!(left, right, |a, b| a > b); } -/// Perform `left >= right` operation on two arrays. Non-null values are greater than null +/// Perform `left >= right` operation on two [`PrimitiveArray`]s. Non-null values are greater than null /// values. pub fn gt_eq( left: &PrimitiveArray, @@ -1108,7 +1246,7 @@ where return compare_op!(left, right, |a, b| a >= b); } -/// Perform `left >= right` operation on an array and a scalar value. +/// Perform `left >= right` operation on a [`PrimitiveArray`] and a scalar value. /// Non-null values are greater than null values. pub fn gt_eq_scalar(left: &PrimitiveArray, right: T::Native) -> Result where @@ -1260,11 +1398,17 @@ mod tests { /// `EXPECTED` can be either `Vec` or `Vec>`. /// The main reason for this macro is that inputs and outputs align nicely after `cargo fmt`. macro_rules! cmp_i64 { - ($KERNEL:ident, $A_VEC:expr, $B_VEC:expr, $EXPECTED:expr) => { + ($KERNEL:ident, $DYN_KERNEL:ident, $A_VEC:expr, $B_VEC:expr, $EXPECTED:expr) => { let a = Int64Array::from($A_VEC); let b = Int64Array::from($B_VEC); let c = $KERNEL(&a, &b).unwrap(); assert_eq!(BooleanArray::from($EXPECTED), c); + + // slice and test if the dynamic array works + let a = a.slice(0, a.len()); + let b = b.slice(0, b.len()); + let c = $DYN_KERNEL(a.as_ref(), b.as_ref()).unwrap(); + assert_eq!(BooleanArray::from($EXPECTED), c); }; } @@ -1284,6 +1428,7 @@ mod tests { fn test_primitive_array_eq() { cmp_i64!( eq, + eq_dyn, vec![8, 8, 8, 8, 8, 8, 8, 8, 8, 8], vec![6, 7, 8, 9, 10, 6, 7, 8, 9, 10], vec![false, false, true, false, false, false, false, true, false, false] @@ -1330,6 +1475,7 @@ mod tests { fn test_primitive_array_neq() { cmp_i64!( neq, + neq_dyn, vec![8, 8, 8, 8, 8, 8, 8, 8, 8, 8], vec![6, 7, 8, 9, 10, 6, 7, 8, 9, 10], vec![true, true, false, true, true, true, true, false, true, true] @@ -1479,6 +1625,7 @@ mod tests { fn test_primitive_array_lt() { cmp_i64!( lt, + lt_dyn, vec![8, 8, 8, 8, 8, 8, 8, 8, 8, 8], vec![6, 7, 8, 9, 10, 6, 7, 8, 9, 10], vec![false, false, false, true, true, false, false, false, true, true] @@ -1499,6 +1646,7 @@ mod tests { fn test_primitive_array_lt_nulls() { cmp_i64!( lt, + lt_dyn, vec![None, None, Some(1), Some(1), None, None, Some(2), Some(2),], vec![None, Some(1), None, Some(1), None, Some(3), None, Some(3),], vec![None, None, None, Some(false), None, None, None, Some(true)] @@ -1519,6 +1667,7 @@ mod tests { fn test_primitive_array_lt_eq() { cmp_i64!( lt_eq, + lt_eq_dyn, vec![8, 8, 8, 8, 8, 8, 8, 8, 8, 8], vec![6, 7, 8, 9, 10, 6, 7, 8, 9, 10], vec![false, false, true, true, true, false, false, true, true, true] @@ -1539,6 +1688,7 @@ mod tests { fn test_primitive_array_lt_eq_nulls() { cmp_i64!( lt_eq, + lt_eq_dyn, vec![None, None, Some(1), None, None, Some(1), None, None, Some(1)], vec![None, Some(1), Some(0), None, Some(1), Some(2), None, None, Some(3)], vec![None, None, Some(false), None, None, Some(true), None, None, Some(true)] @@ -1559,6 +1709,7 @@ mod tests { fn test_primitive_array_gt() { cmp_i64!( gt, + gt_dyn, vec![8, 8, 8, 8, 8, 8, 8, 8, 8, 8], vec![6, 7, 8, 9, 10, 6, 7, 8, 9, 10], vec![true, true, false, false, false, true, true, false, false, false] @@ -1579,6 +1730,7 @@ mod tests { fn test_primitive_array_gt_nulls() { cmp_i64!( gt, + gt_dyn, vec![None, None, Some(1), None, None, Some(2), None, None, Some(3)], vec![None, Some(1), Some(1), None, Some(1), Some(1), None, Some(1), Some(1)], vec![None, None, Some(false), None, None, Some(true), None, None, Some(true)] @@ -1599,6 +1751,7 @@ mod tests { fn test_primitive_array_gt_eq() { cmp_i64!( gt_eq, + gt_eq_dyn, vec![8, 8, 8, 8, 8, 8, 8, 8, 8, 8], vec![6, 7, 8, 9, 10, 6, 7, 8, 9, 10], vec![true, true, true, false, false, true, true, true, false, false] @@ -1619,6 +1772,7 @@ mod tests { fn test_primitive_array_gt_eq_nulls() { cmp_i64!( gt_eq, + gt_eq_dyn, vec![None, None, Some(1), None, Some(1), Some(2), None, None, Some(1)], vec![None, Some(1), None, None, Some(1), Some(1), None, Some(2), Some(2)], vec![None, None, None, None, Some(true), Some(true), None, None, Some(false)]