-
Notifications
You must be signed in to change notification settings - Fork 866
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add negate kernels (#4488) #4494
Changes from 3 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -74,6 +74,97 @@ pub fn rem(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<ArrayRef, ArrowError> { | |||||
arithmetic_op(Op::Rem, lhs, rhs) | ||||||
} | ||||||
|
||||||
macro_rules! neg_checked { | ||||||
($t:ty, $a:ident) => {{ | ||||||
let array = $a | ||||||
.as_primitive::<$t>() | ||||||
.try_unary::<_, $t, _>(|x| x.neg_checked())?; | ||||||
Ok(Arc::new(array)) | ||||||
}}; | ||||||
} | ||||||
|
||||||
macro_rules! neg_wrapping { | ||||||
($t:ty, $a:ident) => {{ | ||||||
let array = $a.as_primitive::<$t>().unary::<_, $t>(|x| x.neg_wrapping()); | ||||||
Ok(Arc::new(array)) | ||||||
}}; | ||||||
} | ||||||
|
||||||
/// Perform `!array`, returning an error on overflow | ||||||
/// | ||||||
/// Note: negation of unsigned arrays is not supported and will return in an error, | ||||||
/// for wrapping unsigned negation consider using [`neg_wrapping()`] | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is necessary to avoid ambiguity, will change to a link There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. sorry it just looked strange to me |
||||||
pub fn neg(array: &dyn Array) -> Result<ArrayRef, ArrowError> { | ||||||
use DataType::*; | ||||||
use IntervalUnit::*; | ||||||
use TimeUnit::*; | ||||||
|
||||||
match array.data_type() { | ||||||
Int8 => neg_checked!(Int8Type, array), | ||||||
Int16 => neg_checked!(Int16Type, array), | ||||||
Int32 => neg_checked!(Int32Type, array), | ||||||
Int64 => neg_checked!(Int64Type, array), | ||||||
Float16 => neg_wrapping!(Float16Type, array), | ||||||
Float32 => neg_wrapping!(Float32Type, array), | ||||||
Float64 => neg_wrapping!(Float64Type, array), | ||||||
Decimal128(p, s) => { | ||||||
let a = array | ||||||
.as_primitive::<Decimal128Type>() | ||||||
.try_unary::<_, Decimal128Type, _>(|x| x.neg_checked())?; | ||||||
|
||||||
Ok(Arc::new(a.with_precision_and_scale(*p, *s)?)) | ||||||
} | ||||||
Decimal256(p, s) => { | ||||||
let a = array | ||||||
.as_primitive::<Decimal256Type>() | ||||||
.try_unary::<_, Decimal256Type, _>(|x| x.neg_checked())?; | ||||||
|
||||||
Ok(Arc::new(a.with_precision_and_scale(*p, *s)?)) | ||||||
} | ||||||
Duration(Second) => neg_checked!(DurationSecondType, array), | ||||||
Duration(Millisecond) => neg_checked!(DurationMillisecondType, array), | ||||||
Duration(Microsecond) => neg_checked!(DurationMicrosecondType, array), | ||||||
Duration(Nanosecond) => neg_checked!(DurationNanosecondType, array), | ||||||
Interval(YearMonth) => neg_checked!(IntervalYearMonthType, array), | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I double checked that YearMonth intervals are stored as number of whole intervals and thus don't need to be treated field by field |
||||||
Interval(DayTime) => { | ||||||
let a = array | ||||||
.as_primitive::<IntervalDayTimeType>() | ||||||
.try_unary::<_, IntervalDayTimeType, ArrowError>(|x| { | ||||||
let (days, ms) = IntervalDayTimeType::to_parts(x); | ||||||
Ok(IntervalDayTimeType::make_value( | ||||||
days.neg_checked()?, | ||||||
ms.neg_checked()?, | ||||||
)) | ||||||
})?; | ||||||
Ok(Arc::new(a)) | ||||||
} | ||||||
Interval(MonthDayNano) => { | ||||||
let a = array | ||||||
.as_primitive::<IntervalMonthDayNanoType>() | ||||||
.try_unary::<_, IntervalMonthDayNanoType, ArrowError>(|x| { | ||||||
let (months, days, nanos) = IntervalMonthDayNanoType::to_parts(x); | ||||||
Ok(IntervalMonthDayNanoType::make_value( | ||||||
months.neg_checked()?, | ||||||
days.neg_checked()?, | ||||||
nanos.neg_checked()?, | ||||||
)) | ||||||
})?; | ||||||
Ok(Arc::new(a)) | ||||||
} | ||||||
t => Err(ArrowError::InvalidArgumentError(format!( | ||||||
"Invalid arithmetic operation: !{t}" | ||||||
))), | ||||||
} | ||||||
} | ||||||
|
||||||
/// Perform `!array`, wrapping on overflow for [`DataType::is_integer`] | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
pub fn neg_wrapping(array: &dyn Array) -> Result<ArrayRef, ArrowError> { | ||||||
downcast_integer! { | ||||||
array.data_type() => (neg_wrapping, array), | ||||||
_ => neg(array), | ||||||
} | ||||||
} | ||||||
|
||||||
/// An enumeration of arithmetic operations | ||||||
/// | ||||||
/// This allows sharing the type dispatch logic across the various kernels | ||||||
|
@@ -670,3 +761,148 @@ fn decimal_op<T: DecimalType>( | |||||
|
||||||
Ok(Arc::new(array)) | ||||||
} | ||||||
|
||||||
#[cfg(test)] | ||||||
mod tests { | ||||||
use super::*; | ||||||
use arrow_buffer::{i256, ScalarBuffer}; | ||||||
|
||||||
fn test_neg_primitive<T: ArrowPrimitiveType>( | ||||||
input: &[T::Native], | ||||||
out: Result<&[T::Native], &str>, | ||||||
) { | ||||||
let a = PrimitiveArray::<T>::new(ScalarBuffer::from(input.to_vec()), None); | ||||||
match out { | ||||||
Ok(expected) => { | ||||||
let result = neg(&a).unwrap(); | ||||||
assert_eq!(result.as_primitive::<T>().values(), expected); | ||||||
} | ||||||
Err(e) => { | ||||||
let err = neg(&a).unwrap_err().to_string(); | ||||||
assert_eq!(e, err); | ||||||
} | ||||||
} | ||||||
} | ||||||
|
||||||
#[test] | ||||||
fn test_neg() { | ||||||
let input = &[1, -5, 2, 693, 3929]; | ||||||
let output = &[-1, 5, -2, -693, -3929]; | ||||||
test_neg_primitive::<Int32Type>(input, Ok(output)); | ||||||
|
||||||
let input = &[1, -5, 2, 693, 3929]; | ||||||
let output = &[-1, 5, -2, -693, -3929]; | ||||||
test_neg_primitive::<Int64Type>(input, Ok(output)); | ||||||
test_neg_primitive::<DurationSecondType>(input, Ok(output)); | ||||||
test_neg_primitive::<DurationMillisecondType>(input, Ok(output)); | ||||||
test_neg_primitive::<DurationMicrosecondType>(input, Ok(output)); | ||||||
test_neg_primitive::<DurationNanosecondType>(input, Ok(output)); | ||||||
|
||||||
let input = &[f32::MAX, f32::MIN, f32::INFINITY, 1.3, 0.5]; | ||||||
let output = &[f32::MIN, f32::MAX, f32::NEG_INFINITY, -1.3, -0.5]; | ||||||
test_neg_primitive::<Float32Type>(input, Ok(output)); | ||||||
|
||||||
test_neg_primitive::<Int32Type>( | ||||||
&[i32::MIN], | ||||||
Err("Compute error: Overflow happened on: -2147483648"), | ||||||
); | ||||||
test_neg_primitive::<Int64Type>( | ||||||
&[i64::MIN], | ||||||
Err("Compute error: Overflow happened on: -9223372036854775808"), | ||||||
); | ||||||
test_neg_primitive::<DurationSecondType>( | ||||||
&[i64::MIN], | ||||||
Err("Compute error: Overflow happened on: -9223372036854775808"), | ||||||
); | ||||||
|
||||||
let r = neg_wrapping(&Int32Array::from(vec![i32::MIN])).unwrap(); | ||||||
assert_eq!(r.as_primitive::<Int32Type>().value(0), i32::MIN); | ||||||
|
||||||
let r = neg_wrapping(&Int64Array::from(vec![i64::MIN])).unwrap(); | ||||||
assert_eq!(r.as_primitive::<Int64Type>().value(0), i64::MIN); | ||||||
|
||||||
let err = neg_wrapping(&DurationSecondArray::from(vec![i64::MIN])) | ||||||
.unwrap_err() | ||||||
.to_string(); | ||||||
|
||||||
assert_eq!( | ||||||
err, | ||||||
"Compute error: Overflow happened on: -9223372036854775808" | ||||||
); | ||||||
|
||||||
let a = Decimal128Array::from(vec![1, 3, -44, 2, 4]) | ||||||
.with_precision_and_scale(9, 6) | ||||||
.unwrap(); | ||||||
|
||||||
let r = neg(&a).unwrap(); | ||||||
assert_eq!(r.data_type(), a.data_type()); | ||||||
assert_eq!( | ||||||
r.as_primitive::<Decimal128Type>().values(), | ||||||
&[-1, -3, 44, -2, -4] | ||||||
); | ||||||
|
||||||
let a = Decimal256Array::from(vec![ | ||||||
i256::from_i128(342), | ||||||
i256::from_i128(-4949), | ||||||
i256::from_i128(3), | ||||||
]) | ||||||
.with_precision_and_scale(9, 6) | ||||||
.unwrap(); | ||||||
|
||||||
let r = neg(&a).unwrap(); | ||||||
assert_eq!(r.data_type(), a.data_type()); | ||||||
assert_eq!( | ||||||
r.as_primitive::<Decimal256Type>().values(), | ||||||
&[ | ||||||
i256::from_i128(-342), | ||||||
i256::from_i128(4949), | ||||||
i256::from_i128(-3), | ||||||
] | ||||||
); | ||||||
|
||||||
let a = IntervalYearMonthArray::from(vec![ | ||||||
IntervalYearMonthType::make_value(2, 4), | ||||||
IntervalYearMonthType::make_value(2, -4), | ||||||
IntervalYearMonthType::make_value(-3, -5), | ||||||
]); | ||||||
let r = neg(&a).unwrap(); | ||||||
assert_eq!( | ||||||
r.as_primitive::<IntervalYearMonthType>().values(), | ||||||
&[ | ||||||
IntervalYearMonthType::make_value(-2, -4), | ||||||
IntervalYearMonthType::make_value(-2, 4), | ||||||
IntervalYearMonthType::make_value(3, 5), | ||||||
] | ||||||
); | ||||||
|
||||||
let a = IntervalDayTimeArray::from(vec![ | ||||||
IntervalDayTimeType::make_value(2, 4), | ||||||
IntervalDayTimeType::make_value(2, -4), | ||||||
IntervalDayTimeType::make_value(-3, -5), | ||||||
]); | ||||||
let r = neg(&a).unwrap(); | ||||||
assert_eq!( | ||||||
r.as_primitive::<IntervalDayTimeType>().values(), | ||||||
&[ | ||||||
IntervalDayTimeType::make_value(-2, -4), | ||||||
IntervalDayTimeType::make_value(-2, 4), | ||||||
IntervalDayTimeType::make_value(3, 5), | ||||||
] | ||||||
); | ||||||
|
||||||
let a = IntervalMonthDayNanoArray::from(vec![ | ||||||
IntervalMonthDayNanoType::make_value(2, 4, 5953394), | ||||||
IntervalMonthDayNanoType::make_value(2, -4, -45839), | ||||||
IntervalMonthDayNanoType::make_value(-3, -5, 6944), | ||||||
]); | ||||||
let r = neg(&a).unwrap(); | ||||||
assert_eq!( | ||||||
r.as_primitive::<IntervalMonthDayNanoType>().values(), | ||||||
&[ | ||||||
IntervalMonthDayNanoType::make_value(-2, -4, -5953394), | ||||||
IntervalMonthDayNanoType::make_value(-2, 4, 45839), | ||||||
IntervalMonthDayNanoType::make_value(3, 5, -6944), | ||||||
] | ||||||
); | ||||||
} | ||||||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am not really sure what "Performs
!array
" means