Skip to content

Commit

Permalink
Refactoring build_compare for decimal and using downcast_primitive (#3484)
Browse files Browse the repository at this point in the history

* Refactor build_compare for decimal and add dict support

* Simplify code using downcast_primitive
  • Loading branch information
viirya authored Jan 10, 2023
1 parent cada9ba commit e8cc351
Showing 1 changed file with 80 additions and 82 deletions.
162 changes: 80 additions & 82 deletions arrow-ord/src/ord.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,12 @@ where
})
}

/// Adapter macro so that `downcast_primitive!` can dispatch to `cmp_dict_primitive`.
///
/// Takes the primitive value type followed by the dictionary key type and the
/// two arrays, and expands to a call of `cmp_dict_primitive::<$value_type>`.
/// The expansion ends in `?`, so it must be used inside a function whose
/// return type supports error propagation.
macro_rules! cmp_dict_primitive_helper {
    ($value_type:ty, $key_type:expr, $lhs:expr, $rhs:expr) => {
        cmp_dict_primitive::<$value_type>($key_type, $lhs, $rhs)?
    };
}

/// Returns a comparison function that compares the value at one position in the
/// left array with the value at another position in the right array.
/// The arrays' types must be equal.
Expand Down Expand Up @@ -193,6 +199,12 @@ pub fn build_compare(
(Int64, Int64) => compare_primitives::<Int64Type>(left, right),
(Float32, Float32) => compare_float::<Float32Type>(left, right),
(Float64, Float64) => compare_float::<Float64Type>(left, right),
(Decimal128(_, _), Decimal128(_, _)) => {
compare_primitives::<Decimal128Type>(left, right)
}
(Decimal256(_, _), Decimal256(_, _)) => {
compare_primitives::<Decimal256Type>(left, right)
}
(Date32, Date32) => compare_primitives::<Date32Type>(left, right),
(Date64, Date64) => compare_primitives::<Date64Type>(left, right),
(Time32(Second), Time32(Second)) => {
Expand Down Expand Up @@ -253,83 +265,8 @@ pub fn build_compare(
}

let key_type_lhs = key_type_lhs.as_ref();

match value_type_lhs.as_ref() {
Int8 => cmp_dict_primitive::<Int8Type>(key_type_lhs, left, right)?,
Int16 => cmp_dict_primitive::<Int16Type>(key_type_lhs, left, right)?,
Int32 => cmp_dict_primitive::<Int32Type>(key_type_lhs, left, right)?,
Int64 => cmp_dict_primitive::<Int64Type>(key_type_lhs, left, right)?,
UInt8 => cmp_dict_primitive::<UInt8Type>(key_type_lhs, left, right)?,
UInt16 => cmp_dict_primitive::<UInt16Type>(key_type_lhs, left, right)?,
UInt32 => cmp_dict_primitive::<UInt32Type>(key_type_lhs, left, right)?,
UInt64 => cmp_dict_primitive::<UInt64Type>(key_type_lhs, left, right)?,
Float32 => cmp_dict_primitive::<Float32Type>(key_type_lhs, left, right)?,
Float64 => cmp_dict_primitive::<Float64Type>(key_type_lhs, left, right)?,
Date32 => cmp_dict_primitive::<Date32Type>(key_type_lhs, left, right)?,
Date64 => cmp_dict_primitive::<Date64Type>(key_type_lhs, left, right)?,
Time32(Second) => {
cmp_dict_primitive::<Time32SecondType>(key_type_lhs, left, right)?
}
Time32(Millisecond) => cmp_dict_primitive::<Time32MillisecondType>(
key_type_lhs,
left,
right,
)?,
Time64(Microsecond) => cmp_dict_primitive::<Time64MicrosecondType>(
key_type_lhs,
left,
right,
)?,
Time64(Nanosecond) => {
cmp_dict_primitive::<Time64NanosecondType>(key_type_lhs, left, right)?
}
Timestamp(Second, _) => {
cmp_dict_primitive::<TimestampSecondType>(key_type_lhs, left, right)?
}
Timestamp(Millisecond, _) => cmp_dict_primitive::<
TimestampMillisecondType,
>(key_type_lhs, left, right)?,
Timestamp(Microsecond, _) => cmp_dict_primitive::<
TimestampMicrosecondType,
>(key_type_lhs, left, right)?,
Timestamp(Nanosecond, _) => {
cmp_dict_primitive::<TimestampNanosecondType>(
key_type_lhs,
left,
right,
)?
}
Interval(YearMonth) => cmp_dict_primitive::<IntervalYearMonthType>(
key_type_lhs,
left,
right,
)?,
Interval(DayTime) => {
cmp_dict_primitive::<IntervalDayTimeType>(key_type_lhs, left, right)?
}
Interval(MonthDayNano) => cmp_dict_primitive::<IntervalMonthDayNanoType>(
key_type_lhs,
left,
right,
)?,
Duration(Second) => {
cmp_dict_primitive::<DurationSecondType>(key_type_lhs, left, right)?
}
Duration(Millisecond) => cmp_dict_primitive::<DurationMillisecondType>(
key_type_lhs,
left,
right,
)?,
Duration(Microsecond) => cmp_dict_primitive::<DurationMicrosecondType>(
key_type_lhs,
left,
right,
)?,
Duration(Nanosecond) => cmp_dict_primitive::<DurationNanosecondType>(
key_type_lhs,
left,
right,
)?,
downcast_primitive! {
value_type_lhs.as_ref() => (cmp_dict_primitive_helper, key_type_lhs, left, right),
Utf8 => match key_type_lhs {
UInt8 => compare_dict_string::<UInt8Type>(left, right),
UInt16 => compare_dict_string::<UInt16Type>(left, right),
Expand All @@ -354,11 +291,6 @@ pub fn build_compare(
}
}
}
(Decimal128(_, _), Decimal128(_, _)) => {
let left: Decimal128Array = Decimal128Array::from(left.data().clone());
let right: Decimal128Array = Decimal128Array::from(right.data().clone());
Box::new(move |i, j| left.value(i).cmp(&right.value(j)))
}
(FixedSizeBinary(_), FixedSizeBinary(_)) => {
let left: FixedSizeBinaryArray =
FixedSizeBinaryArray::from(left.data().clone());
Expand All @@ -380,6 +312,7 @@ pub fn build_compare(
pub mod tests {
use super::*;
use arrow_array::{FixedSizeBinaryArray, Float64Array, Int32Array};
use arrow_buffer::i256;
use std::cmp::Ordering;

#[test]
Expand Down Expand Up @@ -464,6 +397,23 @@ pub mod tests {
assert_eq!(Ordering::Greater, (cmp)(0, 2));
}

/// Checks that `build_compare` orders `Decimal256Array` values numerically.
#[test]
fn test_decimali256() {
    // Three non-null 256-bit decimals: [5, 2, 3].
    let decimals = vec![5_i128, 2_i128, 3_i128]
        .into_iter()
        .map(|raw| Some(i256::from_i128(raw)))
        .collect::<Decimal256Array>()
        .with_precision_and_scale(53, 6)
        .unwrap();

    // The same array is used on both sides of the comparison.
    let compare = build_compare(&decimals, &decimals).unwrap();

    // 2 < 5
    assert_eq!(Ordering::Less, compare(1, 0));
    // 5 > 3
    assert_eq!(Ordering::Greater, compare(0, 2));
}

#[test]
fn test_dict() {
let data = vec!["a", "b", "c", "a", "a", "c", "c"];
Expand Down Expand Up @@ -584,4 +534,52 @@ pub mod tests {
assert_eq!(Ordering::Greater, (cmp)(3, 1));
assert_eq!(Ordering::Greater, (cmp)(3, 2));
}

/// Checks that `build_compare` handles dictionary arrays whose values are
/// `Decimal128Array`s.
#[test]
fn test_decimal_dict() {
    // Keys [0, 0, 1, 3] into values [1, 0, 2, 5] select [1, 1, 0, 5].
    let lhs_values = Decimal128Array::from(vec![1, 0, 2, 5]);
    let lhs_keys = Int8Array::from_iter_values([0, 0, 1, 3]);
    let lhs = DictionaryArray::<Int8Type>::try_new(&lhs_keys, &lhs_values).unwrap();

    // Keys [0, 1, 1, 3] into values [2, 3, 4, 5] select [2, 3, 3, 5].
    let rhs_values = Decimal128Array::from(vec![2, 3, 4, 5]);
    let rhs_keys = Int8Array::from_iter_values([0, 1, 1, 3]);
    let rhs = DictionaryArray::<Int8Type>::try_new(&rhs_keys, &rhs_values).unwrap();

    let compare = build_compare(&lhs, &rhs).unwrap();

    // (index into lhs, index into rhs, expected ordering)
    let expectations = [
        (0, 0, Ordering::Less),
        (0, 3, Ordering::Less),
        (3, 3, Ordering::Equal),
        (3, 1, Ordering::Greater),
        (3, 2, Ordering::Greater),
    ];
    for (i, j, expected) in expectations {
        assert_eq!(expected, compare(i, j), "comparing lhs[{}] with rhs[{}]", i, j);
    }
}

/// Checks that `build_compare` handles dictionary arrays whose values are
/// `Decimal256Array`s.
#[test]
fn test_decimal256_dict() {
    // Keys [0, 0, 1, 3] into values [1, 0, 2, 5] select [1, 1, 0, 5].
    let lhs_values =
        Decimal256Array::from([1_i128, 0, 2, 5].map(i256::from_i128).to_vec());
    let lhs_keys = Int8Array::from_iter_values([0, 0, 1, 3]);
    let lhs = DictionaryArray::<Int8Type>::try_new(&lhs_keys, &lhs_values).unwrap();

    // Keys [0, 1, 1, 3] into values [2, 3, 4, 5] select [2, 3, 3, 5].
    let rhs_values =
        Decimal256Array::from([2_i128, 3, 4, 5].map(i256::from_i128).to_vec());
    let rhs_keys = Int8Array::from_iter_values([0, 1, 1, 3]);
    let rhs = DictionaryArray::<Int8Type>::try_new(&rhs_keys, &rhs_values).unwrap();

    let compare = build_compare(&lhs, &rhs).unwrap();

    // (index into lhs, index into rhs, expected ordering)
    let expectations = [
        (0, 0, Ordering::Less),
        (0, 3, Ordering::Less),
        (3, 3, Ordering::Equal),
        (3, 1, Ordering::Greater),
        (3, 2, Ordering::Greater),
    ];
    for (i, j, expected) in expectations {
        assert_eq!(expected, compare(i, j), "comparing lhs[{}] with rhs[{}]", i, j);
    }
}
}

0 comments on commit e8cc351

Please sign in to comment.