Skip to content

Commit

Permalink
Fix unary for decimal arithmetic computation
Browse files Browse the repository at this point in the history
  • Loading branch information
viirya committed Dec 15, 2022
1 parent a93859b commit ac7f7b8
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 7 deletions.
7 changes: 7 additions & 0 deletions arrow-schema/src/datatype.rs
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,13 @@ impl DataType {
)
}

/// Returns true if this type is decimal: (Decimal*).
#[inline]
pub fn is_decimal(&self) -> bool {
use DataType::*;
matches!(self, Decimal128(_, _) | Decimal256(_, _))
}

/// Returns true if this type is temporal: (Date*, Time*, Duration, or Interval).
#[inline]
pub fn is_temporal(&self) -> bool {
Expand Down
20 changes: 19 additions & 1 deletion arrow/src/compute/kernels/arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1633,7 +1633,7 @@ mod tests {
use super::*;
use crate::array::Int32Array;
use crate::compute::{binary_mut, try_binary_mut, try_unary_mut, unary_mut};
use crate::datatypes::{Date64Type, Int32Type, Int8Type};
use crate::datatypes::{Date64Type, Decimal128Type, Int32Type, Int8Type};
use arrow_buffer::i256;
use chrono::NaiveDate;
use half::f16;
Expand Down Expand Up @@ -3226,4 +3226,22 @@ mod tests {
])) as ArrayRef;
assert_eq!(&result, &expected);
}

#[test]
fn test_decimal_add_scalar_dyn() {
let a = Decimal128Array::from(vec![100, 210, 320])
.with_precision_and_scale(38, 2)
.unwrap();

let result = add_scalar_dyn::<Decimal128Type>(&a, 1).unwrap();
let result = as_primitive_array::<Decimal128Type>(&result)
.clone()
.with_precision_and_scale(38, 2)
.unwrap();
let expected = Decimal128Array::from(vec![101, 211, 321])
.with_precision_and_scale(38, 2)
.unwrap();

assert_eq!(&expected, &result);
}
}
17 changes: 11 additions & 6 deletions arrow/src/compute/kernels/arity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,9 +114,12 @@ where
T: ArrowPrimitiveType,
F: Fn(T::Native) -> Result<T::Native>,
{
if array.value_type() != T::DATA_TYPE {
if array.value_type() != T::DATA_TYPE
&& !(array.value_type().is_decimal() && T::DATA_TYPE.is_decimal())
{
return Err(ArrowError::CastError(format!(
"Cannot perform the unary operation on dictionary array of value type {}",
"Cannot perform the unary operation of type {} on dictionary array of value type {}",
T::DATA_TYPE,
array.value_type()
)));
}
Expand All @@ -135,14 +138,15 @@ where
downcast_dictionary_array! {
array => unary_dict::<_, F, T>(array, op),
t => {
if t == &T::DATA_TYPE {
if t == &T::DATA_TYPE || (t.is_decimal() && T::DATA_TYPE.is_decimal()) {
Ok(Arc::new(unary::<T, F, T>(
array.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap(),
op,
)))
} else {
Err(ArrowError::NotYetImplemented(format!(
"Cannot perform unary operation on array of type {}",
"Cannot perform unary operation of type {} on array of type {}",
T::DATA_TYPE,
t
)))
}
Expand All @@ -166,14 +170,15 @@ where
)))
},
t => {
if t == &T::DATA_TYPE {
if t == &T::DATA_TYPE || (t.is_decimal() && T::DATA_TYPE.is_decimal()) {
Ok(Arc::new(try_unary::<T, F, T>(
array.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap(),
op,
)?))
} else {
Err(ArrowError::NotYetImplemented(format!(
"Cannot perform unary operation on array of type {}",
"Cannot perform unary operation of type {} on array of type {}",
T::DATA_TYPE,
t
)))
}
Expand Down

0 comments on commit ac7f7b8

Please sign in to comment.