From c98d21a1e27534f2ef3d3432914baccffa233b56 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 17 Nov 2022 23:14:04 -0800 Subject: [PATCH] Add add_scalar_mut and add_scalar_checked_mut --- arrow-array/src/array/primitive_array.rs | 36 ++++++++++ .../src/builder/null_buffer_builder.rs | 5 ++ arrow-array/src/builder/primitive_builder.rs | 8 +++ arrow/src/compute/kernels/arithmetic.rs | 66 ++++++++++++++++++- arrow/src/compute/kernels/arity.rs | 27 ++++++++ 5 files changed, 141 insertions(+), 1 deletion(-) diff --git a/arrow-array/src/array/primitive_array.rs b/arrow-array/src/array/primitive_array.rs index 195e2dc19a1a..5bbda8954e7f 100644 --- a/arrow-array/src/array/primitive_array.rs +++ b/arrow-array/src/array/primitive_array.rs @@ -469,6 +469,42 @@ impl PrimitiveArray { }) } + /// Applies an unary and fallible function to all valid values in a mutable primitive array. + /// Mutable primitive array means that the buffer is not shared with other arrays. + /// As a result, this mutates the buffer directly without allocating new buffer. + /// + /// This is unlike [`Self::unary_mut`] which will apply an infallible function to all rows + /// regardless of validity, in many cases this will be significantly faster and should + /// be preferred if `op` is infallible. + /// + /// This returns an `Err` for two cases. First is input array is shared buffer with other + /// array. In the case, returned `Err` wraps a `Ok` of input array. Second, if the function + /// encounters an error during applying on values. In the case, returned `Err` wraps an + /// `Err` of the actual error. + /// + /// Note: LLVM is currently unable to effectively vectorize fallible operations + pub fn try_unary_mut( + self, + op: F, + ) -> Result, Result, E>> + where + F: Fn(T::Native) -> Result, + { + let len = self.len(); + let null_count = self.null_count(); + let mut builder = self.into_builder().map_err(|arr| Ok(arr))?; + + let (slice, null_buffer) = builder.as_slice(); + + try_for_each_valid_idx(len, 0, null_count, null_buffer, |idx| { + unsafe { *slice.get_unchecked_mut(idx) = op(*slice.get_unchecked(idx))? }; + Ok::<_, E>(()) + }) + .map_err(|err| Err(err))?; + + Ok(builder.finish()) + } + /// Applies a unary and nullable function to all valid values in a primitive array /// /// This is unlike [`Self::unary`] which will apply an infallible function to all rows diff --git a/arrow-array/src/builder/null_buffer_builder.rs b/arrow-array/src/builder/null_buffer_builder.rs index fef7214d5aa7..aaba7cdf212e 100644 --- a/arrow-array/src/builder/null_buffer_builder.rs +++ b/arrow-array/src/builder/null_buffer_builder.rs @@ -150,6 +150,11 @@ impl NullBufferBuilder { self.bitmap_builder = Some(b); } } + + #[inline] + pub fn as_slice(&self) -> Option<&[u8]> { + self.bitmap_builder.as_ref().map(|b| b.as_slice()) + } } impl NullBufferBuilder { diff --git a/arrow-array/src/builder/primitive_builder.rs b/arrow-array/src/builder/primitive_builder.rs index 55d8bac0189f..5c22b5ae03a2 100644 --- a/arrow-array/src/builder/primitive_builder.rs +++ b/arrow-array/src/builder/primitive_builder.rs @@ -228,6 +228,14 @@ impl PrimitiveBuilder { pub fn values_slice_mut(&mut self) -> &mut [T::Native] { self.values_builder.as_slice_mut() } + + /// Returns the current values buffer and null buffer as a slice + pub fn as_slice(&mut self) -> (&mut [T::Native], Option<&[u8]>) { + ( + self.values_builder.as_slice_mut(), + self.null_buffer_builder.as_slice(), + ) + } } #[cfg(test)] diff --git a/arrow/src/compute/kernels/arithmetic.rs b/arrow/src/compute/kernels/arithmetic.rs index a99a90204b7f..2b60279307cb 100644 --- a/arrow/src/compute/kernels/arithmetic.rs +++ b/arrow/src/compute/kernels/arithmetic.rs @@ -27,7 +27,8 @@ use crate::array::*; use crate::buffer::MutableBuffer; use crate::compute::kernels::arity::unary; use crate::compute::{ - binary, binary_opt, try_binary, try_unary, try_unary_dyn, unary_dyn, + binary, binary_opt, try_binary, try_unary, try_unary_dyn, try_unary_mut, unary_dyn, + unary_mut, }; use crate::datatypes::{ ArrowNativeTypeOp, ArrowNumericType, DataType, Date32Type, Date64Type, @@ -914,6 +915,47 @@ where Ok(unary(array, |value| value.add_wrapping(scalar))) } +/// Mutate an array by adding every value in an array by a scalar. If any value in the array +/// is null then the result is also null. +/// +/// This only mutates the array if it is not shared buffers with other arrays. For shared +/// array, it returns an `Err` which wraps input array. +/// +/// This doesn't detect overflow. Once overflowing, the result will wrap around. +/// For an overflow-checking variant, use `add_scalar_checked_mut` instead. +pub fn add_scalar_mut( + array: PrimitiveArray, + scalar: T::Native, +) -> std::result::Result, PrimitiveArray> +where + T: ArrowNumericType, + T::Native: ArrowNativeTypeOp, +{ + unary_mut(array, |value| value.add_wrapping(scalar)) +} + +/// Mutate an array by adding every value in an array by a scalar. If any value in the array +/// is null then the result is also null. +/// +/// This only mutates the array if it is not shared buffers with other arrays. For shared +/// array, it returns an `Err` which wraps input array with a `Ok`. +/// +/// This detects overflow and returns an `Err` which wraps an `Erro` of actual error. +/// For an non-overflow-checking variant, use `add_scalar_mut` instead. +pub fn add_scalar_checked_mut( + array: PrimitiveArray, + scalar: T::Native, +) -> std::result::Result< + PrimitiveArray, + std::result::Result, ArrowError>, +> +where + T: ArrowNumericType, + T::Native: ArrowNativeTypeOp, +{ + try_unary_mut(array, |value| value.add_checked(scalar)) +} + /// Add every value in an array by a scalar. If any value in the array is null then the /// result is also null. /// @@ -3098,4 +3140,26 @@ mod tests { assert_eq!(result.len(), 13); assert_eq!(result.null_count(), 13); } + + #[test] + fn test_primitive_array_add_scalar_mut() { + let a = Int32Array::from(vec![15, 14, 9, 8, 1]); + let b = 3; + let c = add_scalar_mut(a, b).unwrap(); + let expected = Int32Array::from(vec![18, 17, 12, 11, 4]); + assert_eq!(c, expected); + } + + #[test] + fn test_primitive_add_scalar_mut_wrapping_overflow() { + let a = Int32Array::from(vec![i32::MAX, i32::MIN]); + + let wrapped = add_scalar_mut(a, 1).unwrap(); + let expected = Int32Array::from(vec![-2147483648, -2147483647]); + assert_eq!(expected, wrapped); + + let a = Int32Array::from(vec![i32::MAX, i32::MIN]); + let overflow = add_scalar_checked_mut(a, 1); + let _ = overflow.expect_err("overflow should be detected"); + } } diff --git a/arrow/src/compute/kernels/arity.rs b/arrow/src/compute/kernels/arity.rs index c99d2b727b8d..848539b27f41 100644 --- a/arrow/src/compute/kernels/arity.rs +++ b/arrow/src/compute/kernels/arity.rs @@ -58,6 +58,18 @@ where array.unary(op) } +/// See [`PrimitiveArray::unary_mut`] +pub fn unary_mut( + array: PrimitiveArray, + op: F, +) -> std::result::Result, PrimitiveArray> +where + I: ArrowPrimitiveType, + F: Fn(I::Native) -> I::Native, +{ + array.unary_mut(op) +} + /// See [`PrimitiveArray::try_unary`] pub fn try_unary(array: &PrimitiveArray, op: F) -> Result> where @@ -68,6 +80,21 @@ where array.try_unary(op) } +/// See [`PrimitiveArray::try_unary_mut`] +pub fn try_unary_mut( + array: PrimitiveArray, + op: F, +) -> std::result::Result< + PrimitiveArray, + std::result::Result, ArrowError>, +> +where + I: ArrowPrimitiveType, + F: Fn(I::Native) -> Result, +{ + array.try_unary_mut(op) +} + /// A helper function that applies an infallible unary function to a dictionary array with primitive value type. fn unary_dict(array: &DictionaryArray, op: F) -> Result where