From c98d21a1e27534f2ef3d3432914baccffa233b56 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 17 Nov 2022 23:14:04 -0800 Subject: [PATCH 1/6] Add add_scalar_mut and add_scalar_checked_mut --- arrow-array/src/array/primitive_array.rs | 36 ++++++++++ .../src/builder/null_buffer_builder.rs | 5 ++ arrow-array/src/builder/primitive_builder.rs | 8 +++ arrow/src/compute/kernels/arithmetic.rs | 66 ++++++++++++++++++- arrow/src/compute/kernels/arity.rs | 27 ++++++++ 5 files changed, 141 insertions(+), 1 deletion(-) diff --git a/arrow-array/src/array/primitive_array.rs b/arrow-array/src/array/primitive_array.rs index 195e2dc19a1a..5bbda8954e7f 100644 --- a/arrow-array/src/array/primitive_array.rs +++ b/arrow-array/src/array/primitive_array.rs @@ -469,6 +469,42 @@ impl PrimitiveArray { }) } + /// Applies an unary and fallible function to all valid values in a mutable primitive array. + /// Mutable primitive array means that the buffer is not shared with other arrays. + /// As a result, this mutates the buffer directly without allocating new buffer. + /// + /// This is unlike [`Self::unary_mut`] which will apply an infallible function to all rows + /// regardless of validity, in many cases this will be significantly faster and should + /// be preferred if `op` is infallible. + /// + /// This returns an `Err` for two cases. First is input array is shared buffer with other + /// array. In the case, returned `Err` wraps a `Ok` of input array. Second, if the function + /// encounters an error during applying on values. In the case, returned `Err` wraps an + /// `Err` of the actual error. + /// + /// Note: LLVM is currently unable to effectively vectorize fallible operations + pub fn try_unary_mut( + self, + op: F, + ) -> Result, Result, E>> + where + F: Fn(T::Native) -> Result, + { + let len = self.len(); + let null_count = self.null_count(); + let mut builder = self.into_builder().map_err(|arr| Ok(arr))?; + + let (slice, null_buffer) = builder.as_slice(); + + try_for_each_valid_idx(len, 0, null_count, null_buffer, |idx| { + unsafe { *slice.get_unchecked_mut(idx) = op(*slice.get_unchecked(idx))? }; + Ok::<_, E>(()) + }) + .map_err(|err| Err(err))?; + + Ok(builder.finish()) + } + /// Applies a unary and nullable function to all valid values in a primitive array /// /// This is unlike [`Self::unary`] which will apply an infallible function to all rows diff --git a/arrow-array/src/builder/null_buffer_builder.rs b/arrow-array/src/builder/null_buffer_builder.rs index fef7214d5aa7..aaba7cdf212e 100644 --- a/arrow-array/src/builder/null_buffer_builder.rs +++ b/arrow-array/src/builder/null_buffer_builder.rs @@ -150,6 +150,11 @@ impl NullBufferBuilder { self.bitmap_builder = Some(b); } } + + #[inline] + pub fn as_slice(&self) -> Option<&[u8]> { + self.bitmap_builder.as_ref().map(|b| b.as_slice()) + } } impl NullBufferBuilder { diff --git a/arrow-array/src/builder/primitive_builder.rs b/arrow-array/src/builder/primitive_builder.rs index 55d8bac0189f..5c22b5ae03a2 100644 --- a/arrow-array/src/builder/primitive_builder.rs +++ b/arrow-array/src/builder/primitive_builder.rs @@ -228,6 +228,14 @@ impl PrimitiveBuilder { pub fn values_slice_mut(&mut self) -> &mut [T::Native] { self.values_builder.as_slice_mut() } + + /// Returns the current values buffer and null buffer as a slice + pub fn as_slice(&mut self) -> (&mut [T::Native], Option<&[u8]>) { + ( + self.values_builder.as_slice_mut(), + self.null_buffer_builder.as_slice(), + ) + } } #[cfg(test)] diff --git a/arrow/src/compute/kernels/arithmetic.rs b/arrow/src/compute/kernels/arithmetic.rs index a99a90204b7f..2b60279307cb 100644 --- a/arrow/src/compute/kernels/arithmetic.rs +++ b/arrow/src/compute/kernels/arithmetic.rs @@ -27,7 +27,8 @@ use crate::array::*; use crate::buffer::MutableBuffer; use crate::compute::kernels::arity::unary; use crate::compute::{ - binary, binary_opt, try_binary, try_unary, try_unary_dyn, unary_dyn, + binary, binary_opt, try_binary, try_unary, try_unary_dyn, try_unary_mut, unary_dyn, + unary_mut, }; use crate::datatypes::{ ArrowNativeTypeOp, ArrowNumericType, DataType, Date32Type, Date64Type, @@ -914,6 +915,47 @@ where Ok(unary(array, |value| value.add_wrapping(scalar))) } +/// Mutate an array by adding every value in an array by a scalar. If any value in the array +/// is null then the result is also null. +/// +/// This only mutates the array if it is not shared buffers with other arrays. For shared +/// array, it returns an `Err` which wraps input array. +/// +/// This doesn't detect overflow. Once overflowing, the result will wrap around. +/// For an overflow-checking variant, use `add_scalar_checked_mut` instead. +pub fn add_scalar_mut( + array: PrimitiveArray, + scalar: T::Native, +) -> std::result::Result, PrimitiveArray> +where + T: ArrowNumericType, + T::Native: ArrowNativeTypeOp, +{ + unary_mut(array, |value| value.add_wrapping(scalar)) +} + +/// Mutate an array by adding every value in an array by a scalar. If any value in the array +/// is null then the result is also null. +/// +/// This only mutates the array if it is not shared buffers with other arrays. For shared +/// array, it returns an `Err` which wraps input array with a `Ok`. +/// +/// This detects overflow and returns an `Err` which wraps an `Erro` of actual error. +/// For an non-overflow-checking variant, use `add_scalar_mut` instead. +pub fn add_scalar_checked_mut( + array: PrimitiveArray, + scalar: T::Native, +) -> std::result::Result< + PrimitiveArray, + std::result::Result, ArrowError>, +> +where + T: ArrowNumericType, + T::Native: ArrowNativeTypeOp, +{ + try_unary_mut(array, |value| value.add_checked(scalar)) +} + /// Add every value in an array by a scalar. If any value in the array is null then the /// result is also null. /// @@ -3098,4 +3140,26 @@ mod tests { assert_eq!(result.len(), 13); assert_eq!(result.null_count(), 13); } + + #[test] + fn test_primitive_array_add_scalar_mut() { + let a = Int32Array::from(vec![15, 14, 9, 8, 1]); + let b = 3; + let c = add_scalar_mut(a, b).unwrap(); + let expected = Int32Array::from(vec![18, 17, 12, 11, 4]); + assert_eq!(c, expected); + } + + #[test] + fn test_primitive_add_scalar_mut_wrapping_overflow() { + let a = Int32Array::from(vec![i32::MAX, i32::MIN]); + + let wrapped = add_scalar_mut(a, 1).unwrap(); + let expected = Int32Array::from(vec![-2147483648, -2147483647]); + assert_eq!(expected, wrapped); + + let a = Int32Array::from(vec![i32::MAX, i32::MIN]); + let overflow = add_scalar_checked_mut(a, 1); + let _ = overflow.expect_err("overflow should be detected"); + } } diff --git a/arrow/src/compute/kernels/arity.rs b/arrow/src/compute/kernels/arity.rs index c99d2b727b8d..848539b27f41 100644 --- a/arrow/src/compute/kernels/arity.rs +++ b/arrow/src/compute/kernels/arity.rs @@ -58,6 +58,18 @@ where array.unary(op) } +/// See [`PrimitiveArray::unary_mut`] +pub fn unary_mut( + array: PrimitiveArray, + op: F, +) -> std::result::Result, PrimitiveArray> +where + I: ArrowPrimitiveType, + F: Fn(I::Native) -> I::Native, +{ + array.unary_mut(op) +} + /// See [`PrimitiveArray::try_unary`] pub fn try_unary(array: &PrimitiveArray, op: F) -> Result> where @@ -68,6 +80,21 @@ where array.try_unary(op) } +/// See [`PrimitiveArray::try_unary_mut`] +pub fn try_unary_mut( + array: PrimitiveArray, + op: F, +) -> std::result::Result< + PrimitiveArray, + std::result::Result, ArrowError>, +> +where + I: ArrowPrimitiveType, + F: Fn(I::Native) -> Result, +{ + array.try_unary_mut(op) +} + /// A helper function that applies an infallible unary function to a dictionary array with primitive value type. fn unary_dict(array: &DictionaryArray, op: F) -> Result where From 1b68fdf183dc913db75ac9642bc3ba14084ff464 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 23 Nov 2022 14:21:01 +0800 Subject: [PATCH 2/6] Update slice related functions for completeness. --- arrow-array/src/array/primitive_array.rs | 4 ++-- .../src/builder/boolean_buffer_builder.rs | 5 +++++ arrow-array/src/builder/null_buffer_builder.rs | 5 +++++ arrow-array/src/builder/primitive_builder.rs | 16 ++++++++++++++-- 4 files changed, 26 insertions(+), 4 deletions(-) diff --git a/arrow-array/src/array/primitive_array.rs b/arrow-array/src/array/primitive_array.rs index 5bbda8954e7f..767cfcecb44b 100644 --- a/arrow-array/src/array/primitive_array.rs +++ b/arrow-array/src/array/primitive_array.rs @@ -494,9 +494,9 @@ impl PrimitiveArray { let null_count = self.null_count(); let mut builder = self.into_builder().map_err(|arr| Ok(arr))?; - let (slice, null_buffer) = builder.as_slice(); + let (slice, null_buffer) = builder.slices_mut(); - try_for_each_valid_idx(len, 0, null_count, null_buffer, |idx| { + try_for_each_valid_idx(len, 0, null_count, null_buffer.as_deref(), |idx| { unsafe { *slice.get_unchecked_mut(idx) = op(*slice.get_unchecked(idx))? }; Ok::<_, E>(()) }) diff --git a/arrow-array/src/builder/boolean_buffer_builder.rs b/arrow-array/src/builder/boolean_buffer_builder.rs index 2ab01ccfe40b..2e1fa7c88bbd 100644 --- a/arrow-array/src/builder/boolean_buffer_builder.rs +++ b/arrow-array/src/builder/boolean_buffer_builder.rs @@ -156,6 +156,11 @@ impl BooleanBufferBuilder { self.buffer.as_slice() } + /// Returns the packed bits + pub fn as_slice_mut(&mut self) -> &mut [u8] { + self.buffer.as_slice_mut() + } + #[inline] pub fn finish(&mut self) -> Buffer { let buf = std::mem::replace(&mut self.buffer, MutableBuffer::new(0)); diff --git a/arrow-array/src/builder/null_buffer_builder.rs b/arrow-array/src/builder/null_buffer_builder.rs index aaba7cdf212e..5192fa215669 100644 --- a/arrow-array/src/builder/null_buffer_builder.rs +++ b/arrow-array/src/builder/null_buffer_builder.rs @@ -152,9 +152,14 @@ impl NullBufferBuilder { } #[inline] + #[allow(dead_code)] pub fn as_slice(&self) -> Option<&[u8]> { self.bitmap_builder.as_ref().map(|b| b.as_slice()) } + + pub fn as_slice_mut(&mut self) -> Option<&mut [u8]> { + self.bitmap_builder.as_mut().map(|b| b.as_slice_mut()) + } } impl NullBufferBuilder { diff --git a/arrow-array/src/builder/primitive_builder.rs b/arrow-array/src/builder/primitive_builder.rs index 5c22b5ae03a2..92f4ec492d6b 100644 --- a/arrow-array/src/builder/primitive_builder.rs +++ b/arrow-array/src/builder/primitive_builder.rs @@ -229,11 +229,23 @@ impl PrimitiveBuilder { self.values_builder.as_slice_mut() } + /// Returns the current values buffer as a slice + #[allow(dead_code)] + pub fn validity_slice(&self) -> Option<&[u8]> { + self.null_buffer_builder.as_slice() + } + + /// Returns the current values buffer as a mutable slice + #[allow(dead_code)] + pub fn validity_slice_mut(&mut self) -> Option<&mut [u8]> { + self.null_buffer_builder.as_slice_mut() + } + /// Returns the current values buffer and null buffer as a slice - pub fn as_slice(&mut self) -> (&mut [T::Native], Option<&[u8]>) { + pub fn slices_mut(&mut self) -> (&mut [T::Native], Option<&mut [u8]>) { ( self.values_builder.as_slice_mut(), - self.null_buffer_builder.as_slice(), + self.null_buffer_builder.as_slice_mut(), ) } } From 338e60582da3fcbf3b2361677d0a89a3661a3720 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sat, 26 Nov 2022 17:55:00 -0800 Subject: [PATCH 3/6] Change result type --- arrow-array/src/array/primitive_array.rs | 14 ++++++++------ arrow/src/compute/kernels/arithmetic.rs | 4 ++-- arrow/src/compute/kernels/arity.rs | 2 +- 3 files changed, 11 insertions(+), 9 deletions(-) diff --git a/arrow-array/src/array/primitive_array.rs b/arrow-array/src/array/primitive_array.rs index 23af6e063d54..d6a7ff2389be 100644 --- a/arrow-array/src/array/primitive_array.rs +++ b/arrow-array/src/array/primitive_array.rs @@ -522,23 +522,25 @@ impl PrimitiveArray { pub fn try_unary_mut( self, op: F, - ) -> Result, Result, E>> + ) -> Result, E>, PrimitiveArray> where F: Fn(T::Native) -> Result, { let len = self.len(); let null_count = self.null_count(); - let mut builder = self.into_builder().map_err(|arr| Ok(arr))?; + let mut builder = self.into_builder()?; let (slice, null_buffer) = builder.slices_mut(); - try_for_each_valid_idx(len, 0, null_count, null_buffer.as_deref(), |idx| { + match try_for_each_valid_idx(len, 0, null_count, null_buffer.as_deref(), |idx| { unsafe { *slice.get_unchecked_mut(idx) = op(*slice.get_unchecked(idx))? }; Ok::<_, E>(()) - }) - .map_err(|err| Err(err))?; + }) { + Ok(_) => {} + Err(err) => return Ok(Err(err)), + }; - Ok(builder.finish()) + Ok(Ok(builder.finish())) } /// Applies a unary and nullable function to all valid values in a primitive array diff --git a/arrow/src/compute/kernels/arithmetic.rs b/arrow/src/compute/kernels/arithmetic.rs index 2b60279307cb..7f49b99dc139 100644 --- a/arrow/src/compute/kernels/arithmetic.rs +++ b/arrow/src/compute/kernels/arithmetic.rs @@ -946,8 +946,8 @@ pub fn add_scalar_checked_mut( array: PrimitiveArray, scalar: T::Native, ) -> std::result::Result< - PrimitiveArray, std::result::Result, ArrowError>, + PrimitiveArray, > where T: ArrowNumericType, @@ -3160,6 +3160,6 @@ mod tests { let a = Int32Array::from(vec![i32::MAX, i32::MIN]); let overflow = add_scalar_checked_mut(a, 1); - let _ = overflow.expect_err("overflow should be detected"); + let _ = overflow.unwrap().expect_err("overflow should be detected"); } } diff --git a/arrow/src/compute/kernels/arity.rs b/arrow/src/compute/kernels/arity.rs index 848539b27f41..946d15e9e984 100644 --- a/arrow/src/compute/kernels/arity.rs +++ b/arrow/src/compute/kernels/arity.rs @@ -85,8 +85,8 @@ pub fn try_unary_mut( array: PrimitiveArray, op: F, ) -> std::result::Result< - PrimitiveArray, std::result::Result, ArrowError>, + PrimitiveArray, > where I: ArrowPrimitiveType, From 1b3be81b0b71356d710a9f925439a0d25cd7a761 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sat, 26 Nov 2022 17:58:14 -0800 Subject: [PATCH 4/6] Update API doc --- arrow-array/src/array/primitive_array.rs | 8 ++++---- arrow/src/compute/kernels/arithmetic.rs | 6 +++--- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/arrow-array/src/array/primitive_array.rs b/arrow-array/src/array/primitive_array.rs index d6a7ff2389be..036ef0cdd52f 100644 --- a/arrow-array/src/array/primitive_array.rs +++ b/arrow-array/src/array/primitive_array.rs @@ -513,10 +513,10 @@ impl PrimitiveArray { /// regardless of validity, in many cases this will be significantly faster and should /// be preferred if `op` is infallible. /// - /// This returns an `Err` for two cases. First is input array is shared buffer with other - /// array. In the case, returned `Err` wraps a `Ok` of input array. Second, if the function - /// encounters an error during applying on values. In the case, returned `Err` wraps an - /// `Err` of the actual error. + /// This returns an `Err` when the input array is shared buffer with other + /// array. In the case, returned `Err` wraps input array. If the function + /// encounters an error during applying on values. In the case, this returns an `Err` within + /// an `Ok` which wraps the actual error. /// /// Note: LLVM is currently unable to effectively vectorize fallible operations pub fn try_unary_mut( diff --git a/arrow/src/compute/kernels/arithmetic.rs b/arrow/src/compute/kernels/arithmetic.rs index 7f49b99dc139..cde4da0df9d4 100644 --- a/arrow/src/compute/kernels/arithmetic.rs +++ b/arrow/src/compute/kernels/arithmetic.rs @@ -938,10 +938,10 @@ where /// is null then the result is also null. /// /// This only mutates the array if it is not shared buffers with other arrays. For shared -/// array, it returns an `Err` which wraps input array with a `Ok`. +/// array, it returns an `Err` which wraps input array. /// -/// This detects overflow and returns an `Err` which wraps an `Erro` of actual error. -/// For an non-overflow-checking variant, use `add_scalar_mut` instead. +/// This detects overflow and returns an `Err` within an `Ok` which wraps an `Error` of +/// actual error. For an non-overflow-checking variant, use `add_scalar_mut` instead. pub fn add_scalar_checked_mut( array: PrimitiveArray, scalar: T::Native, From 9f07fa6b894653eab9b9edbd159bfe6176f9b607 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 28 Nov 2022 09:02:20 -0800 Subject: [PATCH 5/6] Remove _mut arithmetic kernels --- arrow/src/compute/kernels/arithmetic.rs | 55 ++++--------------------- 1 file changed, 7 insertions(+), 48 deletions(-) diff --git a/arrow/src/compute/kernels/arithmetic.rs b/arrow/src/compute/kernels/arithmetic.rs index cde4da0df9d4..f9deada5389b 100644 --- a/arrow/src/compute/kernels/arithmetic.rs +++ b/arrow/src/compute/kernels/arithmetic.rs @@ -27,8 +27,7 @@ use crate::array::*; use crate::buffer::MutableBuffer; use crate::compute::kernels::arity::unary; use crate::compute::{ - binary, binary_opt, try_binary, try_unary, try_unary_dyn, try_unary_mut, unary_dyn, - unary_mut, + binary, binary_opt, try_binary, try_unary, try_unary_dyn, unary_dyn, }; use crate::datatypes::{ ArrowNativeTypeOp, ArrowNumericType, DataType, Date32Type, Date64Type, @@ -915,47 +914,6 @@ where Ok(unary(array, |value| value.add_wrapping(scalar))) } -/// Mutate an array by adding every value in an array by a scalar. If any value in the array -/// is null then the result is also null. -/// -/// This only mutates the array if it is not shared buffers with other arrays. For shared -/// array, it returns an `Err` which wraps input array. -/// -/// This doesn't detect overflow. Once overflowing, the result will wrap around. -/// For an overflow-checking variant, use `add_scalar_checked_mut` instead. -pub fn add_scalar_mut( - array: PrimitiveArray, - scalar: T::Native, -) -> std::result::Result, PrimitiveArray> -where - T: ArrowNumericType, - T::Native: ArrowNativeTypeOp, -{ - unary_mut(array, |value| value.add_wrapping(scalar)) -} - -/// Mutate an array by adding every value in an array by a scalar. If any value in the array -/// is null then the result is also null. -/// -/// This only mutates the array if it is not shared buffers with other arrays. For shared -/// array, it returns an `Err` which wraps input array. -/// -/// This detects overflow and returns an `Err` within an `Ok` which wraps an `Error` of -/// actual error. For an non-overflow-checking variant, use `add_scalar_mut` instead. -pub fn add_scalar_checked_mut( - array: PrimitiveArray, - scalar: T::Native, -) -> std::result::Result< - std::result::Result, ArrowError>, - PrimitiveArray, -> -where - T: ArrowNumericType, - T::Native: ArrowNativeTypeOp, -{ - try_unary_mut(array, |value| value.add_checked(scalar)) -} - /// Add every value in an array by a scalar. If any value in the array is null then the /// result is also null. /// @@ -1666,6 +1624,7 @@ where mod tests { use super::*; use crate::array::Int32Array; + use crate::compute::{try_unary_mut, unary_mut}; use crate::datatypes::{Date64Type, Int32Type, Int8Type}; use arrow_buffer::i256; use chrono::NaiveDate; @@ -3142,24 +3101,24 @@ mod tests { } #[test] - fn test_primitive_array_add_scalar_mut() { + fn test_primitive_add_scalar_by_unary_mut() { let a = Int32Array::from(vec![15, 14, 9, 8, 1]); let b = 3; - let c = add_scalar_mut(a, b).unwrap(); + let c = unary_mut(a, |value| value.add_wrapping(b)).unwrap(); let expected = Int32Array::from(vec![18, 17, 12, 11, 4]); assert_eq!(c, expected); } #[test] - fn test_primitive_add_scalar_mut_wrapping_overflow() { + fn test_primitive_add_scalar_overflow_by_try_unary_mut() { let a = Int32Array::from(vec![i32::MAX, i32::MIN]); - let wrapped = add_scalar_mut(a, 1).unwrap(); + let wrapped = unary_mut(a, |value| value.add_wrapping(1)).unwrap(); let expected = Int32Array::from(vec![-2147483648, -2147483647]); assert_eq!(expected, wrapped); let a = Int32Array::from(vec![i32::MAX, i32::MIN]); - let overflow = add_scalar_checked_mut(a, 1); + let overflow = try_unary_mut(a, |value| value.add_checked(1)); let _ = overflow.unwrap().expect_err("overflow should be detected"); } } From 9220b76a96a10b408efad85aea14de4260d55e4e Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 28 Nov 2022 09:36:55 -0800 Subject: [PATCH 6/6] For review --- arrow-array/src/builder/primitive_builder.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/arrow-array/src/builder/primitive_builder.rs b/arrow-array/src/builder/primitive_builder.rs index 73d8411ca0e6..fa1dc3ad1264 100644 --- a/arrow-array/src/builder/primitive_builder.rs +++ b/arrow-array/src/builder/primitive_builder.rs @@ -287,13 +287,11 @@ impl PrimitiveBuilder { } /// Returns the current values buffer as a slice - #[allow(dead_code)] pub fn validity_slice(&self) -> Option<&[u8]> { self.null_buffer_builder.as_slice() } /// Returns the current values buffer as a mutable slice - #[allow(dead_code)] pub fn validity_slice_mut(&mut self) -> Option<&mut [u8]> { self.null_buffer_builder.as_slice_mut() }