From f7c6a05263445ba124626447b9b557764f7f3f13 Mon Sep 17 00:00:00 2001 From: nameexhaustion Date: Mon, 14 Oct 2024 20:42:40 +1100 Subject: [PATCH] feat: Improved list arithmetic support (#19162) --- crates/polars-arrow/src/bitmap/immutable.rs | 1 + crates/polars-arrow/src/bitmap/mutable.rs | 1 + crates/polars-arrow/src/offset.rs | 49 +- crates/polars-compute/src/arithmetic/mod.rs | 1 + .../polars-compute/src/arithmetic/pl_num.rs | 229 ++++ .../src/series/arithmetic/borrowed.rs | 61 +- .../src/series/arithmetic/list_borrowed.rs | 1147 ++++++++++++++--- .../polars-core/src/series/arithmetic/mod.rs | 1 + crates/polars-core/src/series/ops/reshape.rs | 22 +- crates/polars-core/src/series/series_trait.rs | 21 + crates/polars-expr/src/expressions/binary.rs | 7 +- .../polars-ops/src/series/ops/floor_divide.rs | 4 + crates/polars-plan/src/plans/aexpr/schema.rs | 113 +- .../plans/conversion/type_coercion/binary.rs | 58 +- py-polars/polars/dataframe/frame.py | 5 +- py-polars/polars/series/series.py | 17 +- .../operations/arithmetic/test_arithmetic.py | 50 +- .../arithmetic/test_list_arithmetic.py | 530 ++++++++ 18 files changed, 2034 insertions(+), 283 deletions(-) create mode 100644 crates/polars-compute/src/arithmetic/pl_num.rs create mode 100644 py-polars/tests/unit/operations/arithmetic/test_list_arithmetic.py diff --git a/crates/polars-arrow/src/bitmap/immutable.rs b/crates/polars-arrow/src/bitmap/immutable.rs index 127b64b800a5..2ba89e68568a 100644 --- a/crates/polars-arrow/src/bitmap/immutable.rs +++ b/crates/polars-arrow/src/bitmap/immutable.rs @@ -344,6 +344,7 @@ impl Bitmap { /// Unsound iff `i >= self.len()`. #[inline] pub unsafe fn get_bit_unchecked(&self, i: usize) -> bool { + debug_assert!(i < self.len()); get_bit_unchecked(&self.storage, self.offset + i) } diff --git a/crates/polars-arrow/src/bitmap/mutable.rs b/crates/polars-arrow/src/bitmap/mutable.rs index 900722a9e0b7..05bbbe5dd976 100644 --- a/crates/polars-arrow/src/bitmap/mutable.rs +++ b/crates/polars-arrow/src/bitmap/mutable.rs @@ -362,6 +362,7 @@ impl MutableBitmap { /// Caller must ensure that `index < self.len()` #[inline] pub unsafe fn set_unchecked(&mut self, index: usize, value: bool) { + debug_assert!(index < self.len()); let byte = self.buffer.get_unchecked_mut(index / 8); *byte = set_bit_in_byte(*byte, index % 8, value); } diff --git a/crates/polars-arrow/src/offset.rs b/crates/polars-arrow/src/offset.rs index ae4583dfe6f4..ca148655c2e3 100644 --- a/crates/polars-arrow/src/offset.rs +++ b/crates/polars-arrow/src/offset.rs @@ -415,7 +415,7 @@ impl OffsetsBuffer { &self.0 } - /// Returns the length an array with these offsets would be. + /// Returns what the length an array with these offsets would be. #[inline] pub fn len_proxy(&self) -> usize { self.0.len() - 1 @@ -513,6 +513,53 @@ impl OffsetsBuffer { self.0.windows(2).map(|w| (w[1] - w[0]).to_usize()) } + /// Returns `(offset, len)` pairs. + #[inline] + pub fn offset_and_length_iter(&self) -> impl Iterator + '_ { + self.windows(2).map(|x| { + let [l, r] = x else { unreachable!() }; + let l = l.to_usize(); + let r = r.to_usize(); + (l, r - l) + }) + } + + /// Offset and length of the primitive (leaf) array for a double+ nested list for every outer + /// row. + pub fn leaf_ranges_iter( + offsets: &[Self], + ) -> impl Iterator> + '_ { + let others = &offsets[1..]; + + offsets[0].windows(2).map(move |x| { + let [l, r] = x else { unreachable!() }; + let mut l = l.to_usize(); + let mut r = r.to_usize(); + + for o in others { + let slc = o.as_slice(); + l = slc[l].to_usize(); + r = slc[r].to_usize(); + } + + l..r + }) + } + + /// Return the full range of the leaf array used by the list. + pub fn leaf_full_start_end(offsets: &[Self]) -> core::ops::Range { + let mut l = offsets[0].first().to_usize(); + let mut r = offsets[0].last().to_usize(); + + for o in &offsets[1..] { + let slc = o.as_slice(); + l = slc[l].to_usize(); + r = slc[r].to_usize(); + } + + l..r + } + /// Returns the inner [`Buffer`]. #[inline] pub fn into_inner(self) -> Buffer { diff --git a/crates/polars-compute/src/arithmetic/mod.rs b/crates/polars-compute/src/arithmetic/mod.rs index 83471f219d68..cb74881ed4a5 100644 --- a/crates/polars-compute/src/arithmetic/mod.rs +++ b/crates/polars-compute/src/arithmetic/mod.rs @@ -141,5 +141,6 @@ impl ArithmeticKernel for PrimitiveArray { } mod float; +pub mod pl_num; mod signed; mod unsigned; diff --git a/crates/polars-compute/src/arithmetic/pl_num.rs b/crates/polars-compute/src/arithmetic/pl_num.rs new file mode 100644 index 000000000000..c792deacfd52 --- /dev/null +++ b/crates/polars-compute/src/arithmetic/pl_num.rs @@ -0,0 +1,229 @@ +use core::any::TypeId; + +use arrow::types::NativeType; +use polars_utils::floor_divmod::FloorDivMod; + +/// Implements basic arithmetic between scalars with the same behavior as `ArithmeticKernel`. +/// +/// Note, however, that the user is responsible for setting the validity of +/// results for e.g. div/mod operations with 0 in the denominator. +/// +/// This is intended as a low-level utility for custom arithmetic loops +/// (e.g. in list arithmetic). In most cases prefer using `ArithmeticKernel` or +/// `ArithmeticChunked` instead. +pub trait PlNumArithmetic: Sized + Copy + 'static { + type TrueDivT: NativeType; + + fn wrapping_abs(self) -> Self; + fn wrapping_neg(self) -> Self; + fn wrapping_add(self, rhs: Self) -> Self; + fn wrapping_sub(self, rhs: Self) -> Self; + fn wrapping_mul(self, rhs: Self) -> Self; + fn wrapping_floor_div(self, rhs: Self) -> Self; + fn wrapping_trunc_div(self, rhs: Self) -> Self; + fn wrapping_mod(self, rhs: Self) -> Self; + + fn true_div(self, rhs: Self) -> Self::TrueDivT; + + #[inline(always)] + fn legacy_div(self, rhs: Self) -> Self { + if TypeId::of::() == TypeId::of::() { + let ret = self.true_div(rhs); + unsafe { core::mem::transmute_copy(&ret) } + } else { + self.wrapping_floor_div(rhs) + } + } +} + +macro_rules! impl_signed_pl_num_arith { + ($T:ty) => { + impl PlNumArithmetic for $T { + type TrueDivT = f64; + + #[inline(always)] + fn wrapping_abs(self) -> Self { + self.wrapping_abs() + } + + #[inline(always)] + fn wrapping_neg(self) -> Self { + self.wrapping_neg() + } + + #[inline(always)] + fn wrapping_add(self, rhs: Self) -> Self { + self.wrapping_add(rhs) + } + + #[inline(always)] + fn wrapping_sub(self, rhs: Self) -> Self { + self.wrapping_sub(rhs) + } + + #[inline(always)] + fn wrapping_mul(self, rhs: Self) -> Self { + self.wrapping_mul(rhs) + } + + #[inline(always)] + fn wrapping_floor_div(self, rhs: Self) -> Self { + self.wrapping_floor_div_mod(rhs).0 + } + + #[inline(always)] + fn wrapping_trunc_div(self, rhs: Self) -> Self { + if rhs != 0 { + self.wrapping_div(rhs) + } else { + 0 + } + } + + #[inline(always)] + fn wrapping_mod(self, rhs: Self) -> Self { + self.wrapping_floor_div_mod(rhs).1 + } + + #[inline(always)] + fn true_div(self, rhs: Self) -> Self::TrueDivT { + self as f64 / rhs as f64 + } + } + }; +} + +impl_signed_pl_num_arith!(i8); +impl_signed_pl_num_arith!(i16); +impl_signed_pl_num_arith!(i32); +impl_signed_pl_num_arith!(i64); +impl_signed_pl_num_arith!(i128); + +macro_rules! impl_unsigned_pl_num_arith { + ($T:ty) => { + impl PlNumArithmetic for $T { + type TrueDivT = f64; + + #[inline(always)] + fn wrapping_abs(self) -> Self { + self + } + + #[inline(always)] + fn wrapping_neg(self) -> Self { + self.wrapping_neg() + } + + #[inline(always)] + fn wrapping_add(self, rhs: Self) -> Self { + self.wrapping_add(rhs) + } + + #[inline(always)] + fn wrapping_sub(self, rhs: Self) -> Self { + self.wrapping_sub(rhs) + } + + #[inline(always)] + fn wrapping_mul(self, rhs: Self) -> Self { + self.wrapping_mul(rhs) + } + + #[inline(always)] + fn wrapping_floor_div(self, rhs: Self) -> Self { + if rhs != 0 { + self / rhs + } else { + 0 + } + } + + #[inline(always)] + fn wrapping_trunc_div(self, rhs: Self) -> Self { + self.wrapping_floor_div(rhs) + } + + #[inline(always)] + fn wrapping_mod(self, rhs: Self) -> Self { + if rhs != 0 { + self % rhs + } else { + 0 + } + } + + #[inline(always)] + fn true_div(self, rhs: Self) -> Self::TrueDivT { + self as f64 / rhs as f64 + } + } + }; +} + +impl_unsigned_pl_num_arith!(u8); +impl_unsigned_pl_num_arith!(u16); +impl_unsigned_pl_num_arith!(u32); +impl_unsigned_pl_num_arith!(u64); +impl_unsigned_pl_num_arith!(u128); + +macro_rules! impl_float_pl_num_arith { + ($T:ty) => { + impl PlNumArithmetic for $T { + type TrueDivT = $T; + + #[inline(always)] + fn wrapping_abs(self) -> Self { + self.abs() + } + + #[inline(always)] + fn wrapping_neg(self) -> Self { + -self + } + + #[inline(always)] + fn wrapping_add(self, rhs: Self) -> Self { + self + rhs + } + + #[inline(always)] + fn wrapping_sub(self, rhs: Self) -> Self { + self - rhs + } + + #[inline(always)] + fn wrapping_mul(self, rhs: Self) -> Self { + self * rhs + } + + #[inline(always)] + fn wrapping_floor_div(self, rhs: Self) -> Self { + let l = self; + let r = rhs; + (l / r).floor() + } + + #[inline(always)] + fn wrapping_trunc_div(self, rhs: Self) -> Self { + let l = self; + let r = rhs; + (l / r).trunc() + } + + #[inline(always)] + fn wrapping_mod(self, rhs: Self) -> Self { + let l = self; + let r = rhs; + l - r * (l / r).floor() + } + + #[inline(always)] + fn true_div(self, rhs: Self) -> Self::TrueDivT { + self / rhs + } + } + }; +} + +impl_float_pl_num_arith!(f32); +impl_float_pl_num_arith!(f64); diff --git a/crates/polars-core/src/series/arithmetic/borrowed.rs b/crates/polars-core/src/series/arithmetic/borrowed.rs index 2e613ea7e1a0..f9e5ff42139b 100644 --- a/crates/polars-core/src/series/arithmetic/borrowed.rs +++ b/crates/polars-core/src/series/arithmetic/borrowed.rs @@ -146,7 +146,11 @@ fn broadcast_array(lhs: &ArrayChunked, rhs: &Series) -> PolarsResult<(ArrayChunk }, (a, b) if a == b => (lhs.clone(), rhs.clone()), _ => { - polars_bail!(InvalidOperation: "can only do arithmetic of arrays of the same type and shape; got {} and {}", lhs.dtype(), rhs.dtype()) + polars_bail!( + InvalidOperation: + "can only do arithmetic of arrays of the same type and shape; got {} and {}", + lhs.dtype(), rhs.dtype() + ) }, }; Ok(out) @@ -392,23 +396,35 @@ pub(crate) fn coerce_lhs_rhs<'a>( if let Some(result) = coerce_time_units(lhs, rhs) { return Ok(result); } - let dtype = match (lhs.dtype(), rhs.dtype()) { + let (left_dtype, right_dtype) = (lhs.dtype(), rhs.dtype()); + let leaf_super_dtype = match (left_dtype, right_dtype) { #[cfg(feature = "dtype-struct")] (DataType::Struct(_), DataType::Struct(_)) => { return Ok((Cow::Borrowed(lhs), Cow::Borrowed(rhs))) }, - _ => try_get_supertype(lhs.dtype(), rhs.dtype())?, + _ => try_get_supertype(left_dtype.leaf_dtype(), right_dtype.leaf_dtype())?, }; - let left = if lhs.dtype() == &dtype { + let mut new_left_dtype = left_dtype.cast_leaf(leaf_super_dtype.clone()); + let mut new_right_dtype = right_dtype.cast_leaf(leaf_super_dtype); + + // Cast List<->Array to List + if (left_dtype.is_list() && right_dtype.is_array()) + || (left_dtype.is_array() && right_dtype.is_list()) + { + new_left_dtype = try_get_supertype(&new_left_dtype, &new_right_dtype)?; + new_right_dtype = new_left_dtype.clone(); + } + + let left = if lhs.dtype() == &new_left_dtype { Cow::Borrowed(lhs) } else { - Cow::Owned(lhs.cast(&dtype)?) + Cow::Owned(lhs.cast(&new_left_dtype)?) }; - let right = if rhs.dtype() == &dtype { + let right = if rhs.dtype() == &new_right_dtype { Cow::Borrowed(rhs) } else { - Cow::Owned(rhs.cast(&dtype)?) + Cow::Owned(rhs.cast(&new_right_dtype)?) }; Ok((left, right)) } @@ -522,6 +538,9 @@ impl Add for &Series { (DataType::Struct(_), DataType::Struct(_)) => { _struct_arithmetic(self, rhs, |a, b| a.add(b)) }, + (DataType::List(_), _) | (_, DataType::List(_)) => { + list_borrowed::NumericListOp::Add.execute(self, rhs) + }, _ => { let (lhs, rhs) = coerce_lhs_rhs(self, rhs)?; lhs.add_to(rhs.as_ref()) @@ -540,6 +559,9 @@ impl Sub for &Series { (DataType::Struct(_), DataType::Struct(_)) => { _struct_arithmetic(self, rhs, |a, b| a.sub(b)) }, + (DataType::List(_), _) | (_, DataType::List(_)) => { + list_borrowed::NumericListOp::Sub.execute(self, rhs) + }, _ => { let (lhs, rhs) = coerce_lhs_rhs(self, rhs)?; lhs.subtract(rhs.as_ref()) @@ -574,6 +596,9 @@ impl Mul for &Series { let out = rhs.multiply(self)?; Ok(out.with_name(self.name().clone())) }, + (DataType::List(_), _) | (_, DataType::List(_)) => { + list_borrowed::NumericListOp::Mul.execute(self, rhs) + }, _ => { let (lhs, rhs) = coerce_lhs_rhs(self, rhs)?; lhs.multiply(rhs.as_ref()) @@ -595,19 +620,18 @@ impl Div for &Series { use DataType::*; match (self.dtype(), rhs.dtype()) { #[cfg(feature = "dtype-struct")] - (Struct(_), Struct(_)) => { - _struct_arithmetic(self, rhs, |a, b| a.div(b)) - }, + (Struct(_), Struct(_)) => _struct_arithmetic(self, rhs, |a, b| a.div(b)), (Duration(_), _) => self.divide(rhs), - | (Date, _) + (Date, _) | (Datetime(_, _), _) | (Time, _) - // temporal rhs - | (_ , Duration(_)) - | (_ , Time) - | (_ , Date) - | (_ , Datetime(_, _)) - => polars_bail!(opq = div, self.dtype(), rhs.dtype()), + | (_, Duration(_)) + | (_, Time) + | (_, Date) + | (_, Datetime(_, _)) => polars_bail!(opq = div, self.dtype(), rhs.dtype()), + (DataType::List(_), _) | (_, DataType::List(_)) => { + list_borrowed::NumericListOp::Div.execute(self, rhs) + }, _ => { let (lhs, rhs) = coerce_lhs_rhs(self, rhs)?; lhs.divide(rhs.as_ref()) @@ -631,6 +655,9 @@ impl Rem for &Series { (DataType::Struct(_), DataType::Struct(_)) => { _struct_arithmetic(self, rhs, |a, b| a.rem(b)) }, + (DataType::List(_), _) | (_, DataType::List(_)) => { + list_borrowed::NumericListOp::Rem.execute(self, rhs) + }, _ => { let (lhs, rhs) = coerce_lhs_rhs(self, rhs)?; lhs.remainder(rhs.as_ref()) diff --git a/crates/polars-core/src/series/arithmetic/list_borrowed.rs b/crates/polars-core/src/series/arithmetic/list_borrowed.rs index 1628780d7b0e..f59492f71cb9 100644 --- a/crates/polars-core/src/series/arithmetic/list_borrowed.rs +++ b/crates/polars-core/src/series/arithmetic/list_borrowed.rs @@ -1,177 +1,1026 @@ //! Allow arithmetic operations for ListChunked. +use arrow::bitmap::Bitmap; +use arrow::compute::utils::combine_validities_and; +use arrow::offset::OffsetsBuffer; +use either::Either; +use num_traits::Zero; +use polars_compute::arithmetic::pl_num::PlNumArithmetic; +use polars_compute::arithmetic::ArithmeticKernel; +use polars_compute::comparisons::TotalEqKernel; +use polars_utils::float::IsFloat; + use super::*; -use crate::chunked_array::builder::AnonymousListBuilder; - -/// Given an ArrayRef with some primitive values, wrap it in list(s) until it -/// matches the requested shape. -fn reshape_list_based_on(data: &ArrayRef, shape: &ArrayRef) -> ArrayRef { - if let Some(list_chunk) = shape.as_any().downcast_ref::() { - let result = LargeListArray::new( - list_chunk.dtype().clone(), - list_chunk.offsets().clone(), - reshape_list_based_on(data, list_chunk.values()), - list_chunk.validity().cloned(), - ); - Box::new(result) - } else { - data.clone() + +impl NumOpsDispatchInner for ListType { + fn add_to(lhs: &ListChunked, rhs: &Series) -> PolarsResult { + NumericListOp::Add.execute(&lhs.clone().into_series(), rhs) + } + + fn subtract(lhs: &ListChunked, rhs: &Series) -> PolarsResult { + NumericListOp::Sub.execute(&lhs.clone().into_series(), rhs) + } + + fn multiply(lhs: &ListChunked, rhs: &Series) -> PolarsResult { + NumericListOp::Mul.execute(&lhs.clone().into_series(), rhs) + } + + fn divide(lhs: &ListChunked, rhs: &Series) -> PolarsResult { + NumericListOp::Div.execute(&lhs.clone().into_series(), rhs) + } + + fn remainder(lhs: &ListChunked, rhs: &Series) -> PolarsResult { + NumericListOp::Rem.execute(&lhs.clone().into_series(), rhs) } } -/// Given an ArrayRef, return true if it's a LargeListArrays and it has one or -/// more nulls. -fn does_list_have_nulls(data: &ArrayRef) -> bool { - if let Some(list_chunk) = data.as_any().downcast_ref::() { - if list_chunk - .validity() - .map(|bitmap| bitmap.unset_bits() > 0) - .unwrap_or(false) - { - true - } else { - does_list_have_nulls(list_chunk.values()) +#[derive(Debug, Clone)] +pub enum NumericListOp { + Add, + Sub, + Mul, + Div, + Rem, + FloorDiv, +} + +impl NumericListOp { + fn name(&self) -> &'static str { + match self { + Self::Add => "add", + Self::Sub => "sub", + Self::Mul => "mul", + Self::Div => "div", + Self::Rem => "rem", + Self::FloorDiv => "floor_div", } - } else { - false } -} -/// Return whether the left and right have the same shape. We assume neither has -/// any nulls, recursively. -fn lists_same_shapes(left: &ArrayRef, right: &ArrayRef) -> bool { - debug_assert!(!does_list_have_nulls(left)); - debug_assert!(!does_list_have_nulls(right)); - let left_as_list = left.as_any().downcast_ref::(); - let right_as_list = right.as_any().downcast_ref::(); - match (left_as_list, right_as_list) { - (Some(left), Some(right)) => { - left.offsets() == right.offsets() && lists_same_shapes(left.values(), right.values()) - }, - (None, None) => left.len() == right.len(), - _ => false, + pub fn try_get_leaf_supertype( + &self, + prim_dtype_lhs: &DataType, + prim_dtype_rhs: &DataType, + ) -> PolarsResult { + let dtype = try_get_supertype(prim_dtype_lhs, prim_dtype_rhs)?; + + Ok(if matches!(self, Self::Div) { + if dtype.is_float() { + dtype + } else { + DataType::Float64 + } + } else { + dtype + }) + } + + pub fn execute(&self, lhs: &Series, rhs: &Series) -> PolarsResult { + // Ideally we only need to rechunk the leaf array, but getting the + // list offsets of a ListChunked triggers a rechunk anyway, so we just + // do it here. + let lhs = lhs.rechunk(); + let rhs = rhs.rechunk(); + + let binary_op_exec = match BinaryListNumericOpHelper::try_new( + self.clone(), + lhs.name().clone(), + lhs.dtype(), + rhs.dtype(), + lhs.len(), + rhs.len(), + { + let (a, b) = lhs.list_offsets_and_validities_recursive(); + (a, b, lhs.clone()) + }, + { + let (a, b) = rhs.list_offsets_and_validities_recursive(); + (a, b, rhs.clone()) + }, + lhs.rechunk_validity(), + rhs.rechunk_validity(), + )? { + Either::Left(v) => v, + Either::Right(ca) => return Ok(ca.into_series()), + }; + + Ok(binary_op_exec.finish()?.into_series()) } -} -impl ListChunked { - /// Helper function for NumOpsDispatchInner implementation for ListChunked. - /// - /// Run the given `op` on `self` and `rhs`. - fn arithm_helper( + /// For operations that perform divisions on integers, sets the validity to NULL on rows where + /// the denominator is 0. + fn prepare_numeric_op_side_validities( &self, - rhs: &Series, - op: &dyn Fn(&Series, &Series) -> PolarsResult, - has_nulls: Option, - ) -> PolarsResult { - polars_ensure!( - self.len() == rhs.len(), - InvalidOperation: "can only do arithmetic operations on Series of the same size; got {} and {}", - self.len(), - rhs.len() - ); - - let mut has_nulls = has_nulls.unwrap_or(false); - if !has_nulls { - for chunk in self.chunks().iter() { - if does_list_have_nulls(chunk) { - has_nulls = true; - break; - } + lhs: &mut PrimitiveArray, + rhs: &mut PrimitiveArray, + swapped: bool, + ) where + PrimitiveArray: polars_compute::comparisons::TotalEqKernel, + T::Native: Zero + IsFloat, + { + if !T::Native::is_float() { + match self { + Self::Div | Self::Rem | Self::FloorDiv => { + let target = if swapped { lhs } else { rhs }; + let ne_0 = target.tot_ne_kernel_broadcast(&T::Native::zero()); + let validity = combine_validities_and(target.validity(), Some(&ne_0)); + target.set_validity(validity); + }, + _ => {}, } } - if !has_nulls { - for chunk in rhs.chunks().iter() { - if does_list_have_nulls(chunk) { - has_nulls = true; - break; + } + + /// For list<->primitive where the primitive is broadcasted, we can dispatch to + /// `ArithmeticKernel`, which can have optimized codepaths for when one side is + /// a scalar. + fn apply_array_to_scalar( + &self, + arr_lhs: PrimitiveArray, + r: T::Native, + swapped: bool, + ) -> PrimitiveArray { + match self { + Self::Add => ArithmeticKernel::wrapping_add_scalar(arr_lhs, r), + Self::Sub => { + if swapped { + ArithmeticKernel::wrapping_sub_scalar_lhs(r, arr_lhs) + } else { + ArithmeticKernel::wrapping_sub_scalar(arr_lhs, r) } - } + }, + Self::Mul => ArithmeticKernel::wrapping_mul_scalar(arr_lhs, r), + Self::Div => { + if swapped { + ArithmeticKernel::legacy_div_scalar_lhs(r, arr_lhs) + } else { + ArithmeticKernel::legacy_div_scalar(arr_lhs, r) + } + }, + Self::Rem => { + if swapped { + ArithmeticKernel::wrapping_mod_scalar_lhs(r, arr_lhs) + } else { + ArithmeticKernel::wrapping_mod_scalar(arr_lhs, r) + } + }, + Self::FloorDiv => { + if swapped { + ArithmeticKernel::wrapping_floor_div_scalar_lhs(r, arr_lhs) + } else { + ArithmeticKernel::wrapping_floor_div_scalar(arr_lhs, r) + } + }, } - if has_nulls { - // A slower implementation since we can't just add the underlying - // values Arrow arrays. Given nulls, the two values arrays might not - // line up the way we expect. - let mut result = AnonymousListBuilder::new( - self.name().clone(), - self.len(), - Some(self.inner_dtype().clone()), - ); - let combined = self.amortized_iter().zip(rhs.list()?.amortized_iter()).map(|(a, b)| { - let (Some(a_owner), Some(b_owner)) = (a, b) else { - // Operations with nulls always result in nulls: - return Ok(None); - }; - let a = a_owner.as_ref(); - let b = b_owner.as_ref(); - polars_ensure!( - a.len() == b.len(), - InvalidOperation: "can only do arithmetic operations on lists of the same size; got {} and {}", - a.len(), - b.len() - ); - let chunk_result = if let Ok(a_listchunked) = a.list() { - // If `a` contains more lists, we're going to reach this - // function recursively, and again have to decide whether to - // use the fast path (no nulls) or slow path (there were - // nulls). Since we know there were nulls, that means we - // have to stick to the slow path, so pass that information - // along. - a_listchunked.arithm_helper(b, op, Some(true)) - } else { - op(a, b) - }; - chunk_result.map(Some) - }).collect::>>>()?; - for s in combined.iter() { - if let Some(s) = s { - result.append_series(s)?; + } +} + +macro_rules! with_match_numeric_list_op { + ($op:expr, $swapped:expr, | $_:tt $OP:tt | $($body:tt)* ) => ({ + macro_rules! __with_func__ {( $_ $OP:tt ) => ( $($body)* )} + + match $op { + NumericListOp::Add => __with_func__! { (PlNumArithmetic::wrapping_add) }, + NumericListOp::Sub => { + if $swapped { + __with_func__! { (|b, a| PlNumArithmetic::wrapping_sub(a, b)) } } else { - result.append_null(); + __with_func__! { (PlNumArithmetic::wrapping_sub) } } - } - return Ok(result.finish().into()); + }, + NumericListOp::Mul => __with_func__! { (PlNumArithmetic::wrapping_mul) }, + NumericListOp::Div => { + if $swapped { + __with_func__! { (|b, a| PlNumArithmetic::legacy_div(a, b)) } + } else { + __with_func__! { (PlNumArithmetic::legacy_div) } + } + }, + NumericListOp::Rem => { + if $swapped { + __with_func__! { (|b, a| PlNumArithmetic::wrapping_mod(a, b)) } + } else { + __with_func__! { (PlNumArithmetic::wrapping_mod) } + } + }, + NumericListOp::FloorDiv => { + if $swapped { + __with_func__! { (|b, a| PlNumArithmetic::wrapping_floor_div(a, b)) } + } else { + __with_func__! { (PlNumArithmetic::wrapping_floor_div) } + } + }, } - let l_rechunked = self.clone().rechunk().into_series(); - let l_leaf_array = l_rechunked.get_leaf_array(); - let r_leaf_array = rhs.rechunk().get_leaf_array(); - polars_ensure!( - lists_same_shapes(&l_leaf_array.chunks()[0], &r_leaf_array.chunks()[0]), - InvalidOperation: "can only do arithmetic operations on lists of the same size" - ); - - let result = op(&l_leaf_array, &r_leaf_array)?; - - // We now need to wrap the Arrow arrays with the metadata that turns - // them into lists: - // TODO is there a way to do this without cloning the underlying data? - let result_chunks = result.chunks(); - assert_eq!(result_chunks.len(), 1); - let left_chunk = &l_rechunked.chunks()[0]; - let result_chunk = reshape_list_based_on(&result_chunks[0], left_chunk); - - unsafe { - let mut result = - ListChunked::new_with_dims(self.field.clone(), vec![result_chunk], 0, 0); - result.compute_len(); - Ok(result.into()) + }) +} + +#[derive(Debug)] +enum BinaryOpApplyType { + ListToList, + ListToPrimitive, + PrimitiveToList, +} + +#[derive(Debug)] +enum Broadcast { + Left, + Right, + #[allow(clippy::enum_variant_names)] + NoBroadcast, +} + +/// Utility to perform a binary operation between the primitive values of +/// 2 columns, where at least one of the columns is a `ListChunked` type. +struct BinaryListNumericOpHelper { + op: NumericListOp, + output_name: PlSmallStr, + op_apply_type: BinaryOpApplyType, + broadcast: Broadcast, + output_dtype: DataType, + output_primitive_dtype: DataType, + output_len: usize, + /// Outer validity of the result, we always materialize this to reduce the + /// amount of code paths we need. + outer_validity: Bitmap, + // The series are stored as they are used for list broadcasting. + data_lhs: (Vec>, Vec>, Series), + data_rhs: (Vec>, Vec>, Series), + list_to_prim_lhs: Option<(Box, usize)>, + swapped: bool, +} + +/// This lets us separate some logic into `new()` to reduce the amount of +/// monomorphized code. +impl BinaryListNumericOpHelper { + /// Checks that: + /// * Dtypes are compatible: + /// * list<->primitive | primitive<->list + /// * list<->list both contain primitives (e.g. List) + /// * Primitive dtypes match + /// * Lengths are compatible: + /// * 1<->n | n<->1 + /// * n<->n + /// * Both sides have at least 1 non-NULL outer row. + /// + /// Does not check: + /// * Whether the offsets are aligned for list<->list, this will be checked during execution. + /// + /// This returns an `Either` which may contain the final result to simplify + /// the implementation. + #[allow(clippy::too_many_arguments)] + fn try_new( + op: NumericListOp, + output_name: PlSmallStr, + dtype_lhs: &DataType, + dtype_rhs: &DataType, + len_lhs: usize, + len_rhs: usize, + data_lhs: (Vec>, Vec>, Series), + data_rhs: (Vec>, Vec>, Series), + validity_lhs: Option, + validity_rhs: Option, + ) -> PolarsResult> { + let prim_dtype_lhs = dtype_lhs.leaf_dtype(); + let prim_dtype_rhs = dtype_rhs.leaf_dtype(); + + let output_primitive_dtype = op.try_get_leaf_supertype(prim_dtype_lhs, prim_dtype_rhs)?; + + let (op_apply_type, output_dtype) = match (dtype_lhs, dtype_rhs) { + (l @ DataType::List(a), r @ DataType::List(b)) => { + // `get_arithmetic_field()` in the DSL checks this, but we also have to check here because if a user + // directly adds 2 series together it bypasses the DSL. + // This is currently duplicated code and should be replaced one day with an assert after Series ops get + // checked properly. + if ![a, b] + .into_iter() + .all(|x| x.is_numeric() || x.is_bool() || x.is_null()) + { + polars_bail!( + InvalidOperation: + "cannot {} two list columns with non-numeric inner types: (left: {}, right: {})", + op.name(), l, r, + ); + } + (BinaryOpApplyType::ListToList, l) + }, + (list_dtype @ DataType::List(_), x) if x.is_numeric() || x.is_bool() || x.is_null() => { + (BinaryOpApplyType::ListToPrimitive, list_dtype) + }, + (x, list_dtype @ DataType::List(_)) if x.is_numeric() || x.is_bool() || x.is_null() => { + (BinaryOpApplyType::PrimitiveToList, list_dtype) + }, + (l, r) => polars_bail!( + InvalidOperation: + "{} operation not supported for dtypes: {} != {}", + op.name(), l, r, + ), + }; + + let output_dtype = output_dtype.cast_leaf(output_primitive_dtype.clone()); + + let (broadcast, output_len) = match (len_lhs, len_rhs) { + (l, r) if l == r => (Broadcast::NoBroadcast, l), + (1, v) => (Broadcast::Left, v), + (v, 1) => (Broadcast::Right, v), + (l, r) => polars_bail!( + ShapeMismatch: + "cannot {} two columns of differing lengths: {} != {}", + op.name(), l, r + ), + }; + + let DataType::List(output_inner_dtype) = &output_dtype else { + unreachable!() + }; + + // # NULL semantics + // * [[1, 2]] (List[List[Int64]]) + NULL (Int64) => [[NULL, NULL]] + // * Essentially as if the NULL primitive was added to every primitive in the row of the list column. + // * NULL (List[Int64]) + 1 (Int64) => NULL + // * NULL (List[Int64]) + [1] (List[Int64]) => NULL + + if output_len == 0 + || (len_lhs == 1 + && matches!( + &op_apply_type, + BinaryOpApplyType::ListToList | BinaryOpApplyType::ListToPrimitive + ) + && validity_lhs.as_ref().map_or(false, |x| { + !x.get_bit(0) // is not valid + })) + || (len_rhs == 1 + && matches!( + &op_apply_type, + BinaryOpApplyType::ListToList | BinaryOpApplyType::PrimitiveToList + ) + && validity_rhs.as_ref().map_or(false, |x| { + !x.get_bit(0) // is not valid + })) + { + return Ok(Either::Right(ListChunked::full_null_with_dtype( + output_name.clone(), + output_len, + output_inner_dtype.as_ref(), + ))); } + + // At this point: + // * All unit length list columns have a valid outer value. + + // The outer validity is just the validity of any non-broadcasting lists. + let outer_validity = match (&op_apply_type, &broadcast, validity_lhs, validity_rhs) { + // Both lists with same length, we combine the validity. + (BinaryOpApplyType::ListToList, Broadcast::NoBroadcast, l, r) => { + combine_validities_and(l.as_ref(), r.as_ref()) + }, + // Match all other combinations that have non-broadcasting lists. + ( + BinaryOpApplyType::ListToList | BinaryOpApplyType::ListToPrimitive, + Broadcast::NoBroadcast | Broadcast::Right, + v, + _, + ) + | ( + BinaryOpApplyType::ListToList | BinaryOpApplyType::PrimitiveToList, + Broadcast::NoBroadcast | Broadcast::Left, + _, + v, + ) => v, + _ => None, + } + .unwrap_or_else(|| Bitmap::new_with_value(true, output_len)); + + Ok(Either::Left(Self { + op, + output_name, + op_apply_type, + broadcast, + output_dtype: output_dtype.clone(), + output_primitive_dtype, + output_len, + outer_validity, + data_lhs, + data_rhs, + list_to_prim_lhs: None, + swapped: false, + })) } -} -impl NumOpsDispatchInner for ListType { - fn add_to(lhs: &ListChunked, rhs: &Series) -> PolarsResult { - lhs.arithm_helper(rhs, &|l, r| l.add_to(r), None) + fn finish(mut self) -> PolarsResult { + // We have physical codepaths for a subset of the possible combinations of broadcasting and + // column types. The remaining combinations are handled by dispatching to the physical + // codepaths after operand swapping and/or materialized broadcasting. + // + // # Physical impl table + // Legend + // * | N | // impl "N" + // * | [N] | // dispatches to impl "N" + // + // | L | N | R | // Broadcast (L)eft, (N)oBroadcast, (R)ight + // ListToList | [1] | 0 | 1 | + // ListToPrimitive | [2] | 2 | 3 | // list broadcasting just materializes and dispatches to NoBroadcast + // PrimitiveToList | [3] | [2] | [2] | + + self.swapped = true; + + match (&self.op_apply_type, &self.broadcast) { + (BinaryOpApplyType::ListToList, Broadcast::NoBroadcast) + | (BinaryOpApplyType::ListToList, Broadcast::Right) + | (BinaryOpApplyType::ListToPrimitive, Broadcast::NoBroadcast) + | (BinaryOpApplyType::ListToPrimitive, Broadcast::Right) => { + self.swapped = false; + self._finish_impl_dispatch() + }, + (BinaryOpApplyType::PrimitiveToList, Broadcast::Right) => { + // We materialize the list columns with `new_from_index`, as otherwise we'd have to + // implement logic that broadcasts the offsets and validities across multiple levels + // of nesting. But we will re-use the materialized memory to store the result. + + self.list_to_prim_lhs + .replace(Self::materialize_broadcasted_list( + &mut self.data_rhs, + self.output_len, + &self.output_primitive_dtype, + )); + + self.op_apply_type = BinaryOpApplyType::ListToPrimitive; + self.broadcast = Broadcast::NoBroadcast; + core::mem::swap(&mut self.data_lhs, &mut self.data_rhs); + + self._finish_impl_dispatch() + }, + (BinaryOpApplyType::ListToList, Broadcast::Left) => { + self.broadcast = Broadcast::Right; + + core::mem::swap(&mut self.data_lhs, &mut self.data_rhs); + + self._finish_impl_dispatch() + }, + (BinaryOpApplyType::ListToPrimitive, Broadcast::Left) => { + self.list_to_prim_lhs + .replace(Self::materialize_broadcasted_list( + &mut self.data_lhs, + self.output_len, + &self.output_primitive_dtype, + )); + + self.broadcast = Broadcast::NoBroadcast; + + // This does not swap! We are just dispatching to `NoBroadcast` + // after materializing the broadcasted list array. + self.swapped = false; + self._finish_impl_dispatch() + }, + (BinaryOpApplyType::PrimitiveToList, Broadcast::Left) => { + self.op_apply_type = BinaryOpApplyType::ListToPrimitive; + self.broadcast = Broadcast::Right; + + core::mem::swap(&mut self.data_lhs, &mut self.data_rhs); + + self._finish_impl_dispatch() + }, + (BinaryOpApplyType::PrimitiveToList, Broadcast::NoBroadcast) => { + self.op_apply_type = BinaryOpApplyType::ListToPrimitive; + + core::mem::swap(&mut self.data_lhs, &mut self.data_rhs); + + self._finish_impl_dispatch() + }, + } } - fn subtract(lhs: &ListChunked, rhs: &Series) -> PolarsResult { - lhs.arithm_helper(rhs, &|l, r| l.subtract(r), None) + + fn _finish_impl_dispatch(&mut self) -> PolarsResult { + let output_dtype = self.output_dtype.clone(); + let output_len = self.output_len; + + let prim_lhs = self + .data_lhs + .2 + .get_leaf_array() + .cast(&self.output_primitive_dtype)? + .rechunk(); + let prim_rhs = self + .data_rhs + .2 + .get_leaf_array() + .cast(&self.output_primitive_dtype)? + .rechunk(); + + debug_assert_eq!(prim_lhs.dtype(), prim_rhs.dtype()); + let prim_dtype = prim_lhs.dtype(); + debug_assert_eq!(prim_dtype, &self.output_primitive_dtype); + + // Safety: Leaf dtypes have been checked to be numeric by `try_new()` + let out = with_match_physical_numeric_polars_type!(&prim_dtype, |$T| { + self._finish_impl::<$T>(prim_lhs, prim_rhs) + })?; + + debug_assert_eq!(out.dtype(), &output_dtype); + assert_eq!(out.len(), output_len); + + Ok(out) } - fn multiply(lhs: &ListChunked, rhs: &Series) -> PolarsResult { - lhs.arithm_helper(rhs, &|l, r| l.multiply(r), None) + + /// Internal use only - contains physical impls. + fn _finish_impl( + &mut self, + prim_s_lhs: Series, + prim_s_rhs: Series, + ) -> PolarsResult + where + T::Native: PlNumArithmetic, + PrimitiveArray: polars_compute::comparisons::TotalEqKernel, + T::Native: Zero + IsFloat, + { + #[inline(never)] + fn check_mismatch_pos( + mismatch_pos: usize, + offsets_lhs: &OffsetsBuffer, + offsets_rhs: &OffsetsBuffer, + ) -> PolarsResult<()> { + if mismatch_pos < offsets_lhs.len_proxy() { + // RHS could be broadcasted + let len_r = offsets_rhs.length_at(if offsets_rhs.len_proxy() == 1 { + 0 + } else { + mismatch_pos + }); + polars_bail!( + ShapeMismatch: + "list lengths differed at index {}: {} != {}", + mismatch_pos, + offsets_lhs.length_at(mismatch_pos), len_r + ) + } + Ok(()) + } + + let mut arr_lhs = { + let ca: &ChunkedArray = prim_s_lhs.as_ref().as_ref(); + assert_eq!(ca.chunks().len(), 1); + ca.downcast_get(0).unwrap().clone() + }; + + let mut arr_rhs = { + let ca: &ChunkedArray = prim_s_rhs.as_ref().as_ref(); + assert_eq!(ca.chunks().len(), 1); + ca.downcast_get(0).unwrap().clone() + }; + + match (&self.op_apply_type, &self.broadcast) { + // We skip for this because it dispatches to `ArithmeticKernel`, which handles the + // validities for us. + (BinaryOpApplyType::ListToPrimitive, Broadcast::Right) => {}, + _ if self.list_to_prim_lhs.is_none() => self + .op + .prepare_numeric_op_side_validities::(&mut arr_lhs, &mut arr_rhs, self.swapped), + (BinaryOpApplyType::ListToPrimitive, Broadcast::NoBroadcast) => {}, + _ => unreachable!(), + } + + // + // General notes + // * Lists can be: + // * Sliced, in which case the primitive/leaf array needs to be indexed starting from an + // offset instead of 0. + // * Masked, in which case the masked rows are permitted to have non-matching widths. + // + + let out = match (&self.op_apply_type, &self.broadcast) { + (BinaryOpApplyType::ListToList, Broadcast::NoBroadcast) => { + let offsets_lhs = &self.data_lhs.0[0]; + let offsets_rhs = &self.data_rhs.0[0]; + + assert_eq!(offsets_lhs.len_proxy(), offsets_rhs.len_proxy()); + + // Output primitive (and optional validity) are aligned to the LHS input. + let n_values = arr_lhs.len(); + let mut out_vec: Vec = Vec::with_capacity(n_values); + let out_ptr: *mut T::Native = out_vec.as_mut_ptr(); + + // Counter that stops being incremented at the first row position with mismatching + // list lengths. + let mut mismatch_pos = 0; + + with_match_numeric_list_op!(&self.op, self.swapped, |$OP| { + for (i, ((lhs_start, lhs_len), (rhs_start, rhs_len))) in offsets_lhs + .offset_and_length_iter() + .zip(offsets_rhs.offset_and_length_iter()) + .enumerate() + { + if + (mismatch_pos == i) + & ( + (lhs_len == rhs_len) + | unsafe { !self.outer_validity.get_bit_unchecked(i) } + ) + { + mismatch_pos += 1; + } + + // Both sides are lists, we restrict the index to the min length to avoid + // OOB memory access. + let len: usize = lhs_len.min(rhs_len); + + for i in 0..len { + let l_idx = i + lhs_start; + let r_idx = i + rhs_start; + + let l = unsafe { arr_lhs.value_unchecked(l_idx) }; + let r = unsafe { arr_rhs.value_unchecked(r_idx) }; + let v = $OP(l, r); + + unsafe { out_ptr.add(l_idx).write(v) }; + } + } + }); + + check_mismatch_pos(mismatch_pos, offsets_lhs, offsets_rhs)?; + + unsafe { out_vec.set_len(n_values) }; + + /// Reduce monomorphization + #[inline(never)] + fn combine_validities_list_to_list_no_broadcast( + offsets_lhs: &OffsetsBuffer, + offsets_rhs: &OffsetsBuffer, + validity_lhs: Option<&Bitmap>, + validity_rhs: Option<&Bitmap>, + len_lhs: usize, + ) -> Option { + match (validity_lhs, validity_rhs) { + (Some(l), Some(r)) => Some((l.clone().make_mut(), r)), + (Some(v), None) => return Some(v.clone()), + (None, Some(v)) => { + Some((Bitmap::new_with_value(true, len_lhs).make_mut(), v)) + }, + (None, None) => None, + } + .map(|(mut validity_out, validity_rhs)| { + for ((lhs_start, lhs_len), (rhs_start, rhs_len)) in offsets_lhs + .offset_and_length_iter() + .zip(offsets_rhs.offset_and_length_iter()) + { + let len: usize = lhs_len.min(rhs_len); + + for i in 0..len { + let l_idx = i + lhs_start; + let r_idx = i + rhs_start; + + let l_valid = unsafe { validity_out.get_unchecked(l_idx) }; + let r_valid = unsafe { validity_rhs.get_bit_unchecked(r_idx) }; + let is_valid = l_valid & r_valid; + + // Size and alignment of validity vec are based on LHS. + unsafe { validity_out.set_unchecked(l_idx, is_valid) }; + } + } + + validity_out.freeze() + }) + } + + let leaf_validity = combine_validities_list_to_list_no_broadcast( + offsets_lhs, + offsets_rhs, + arr_lhs.validity(), + arr_rhs.validity(), + arr_lhs.len(), + ); + + let arr = + PrimitiveArray::::from_vec(out_vec).with_validity(leaf_validity); + + let (offsets, validities, _) = core::mem::take(&mut self.data_lhs); + assert_eq!(offsets.len(), 1); + + self.finish_offsets_and_validities(Box::new(arr), offsets, validities) + }, + (BinaryOpApplyType::ListToList, Broadcast::Right) => { + let offsets_lhs = &self.data_lhs.0[0]; + let offsets_rhs = &self.data_rhs.0[0]; + + // Output primitive (and optional validity) are aligned to the LHS input. + let n_values = arr_lhs.len(); + let mut out_vec: Vec = Vec::with_capacity(n_values); + let out_ptr: *mut T::Native = out_vec.as_mut_ptr(); + + assert_eq!(offsets_rhs.len_proxy(), 1); + let rhs_start = *offsets_rhs.first() as usize; + let width = offsets_rhs.range() as usize; + + let mut mismatch_pos = 0; + + with_match_numeric_list_op!(&self.op, self.swapped, |$OP| { + for (i, (lhs_start, lhs_len)) in offsets_lhs.offset_and_length_iter().enumerate() { + if ((lhs_len == width) & (mismatch_pos == i)) + | unsafe { !self.outer_validity.get_bit_unchecked(i) } + { + mismatch_pos += 1; + } + + let len: usize = lhs_len.min(width); + + for i in 0..len { + let l_idx = i + lhs_start; + let r_idx = i + rhs_start; + + let l = unsafe { arr_lhs.value_unchecked(l_idx) }; + let r = unsafe { arr_rhs.value_unchecked(r_idx) }; + let v = $OP(l, r); + + unsafe { + out_ptr.add(l_idx).write(v); + } + } + } + }); + + check_mismatch_pos(mismatch_pos, offsets_lhs, offsets_rhs)?; + + unsafe { out_vec.set_len(n_values) }; + + #[inline(never)] + fn combine_validities_list_to_list_broadcast_right( + offsets_lhs: &OffsetsBuffer, + validity_lhs: Option<&Bitmap>, + validity_rhs: Option<&Bitmap>, + len_lhs: usize, + width: usize, + rhs_start: usize, + ) -> Option { + match (validity_lhs, validity_rhs) { + (Some(l), Some(r)) => Some((l.clone().make_mut(), r)), + (Some(v), None) => return Some(v.clone()), + (None, Some(v)) => { + Some((Bitmap::new_with_value(true, len_lhs).make_mut(), v)) + }, + (None, None) => None, + } + .map(|(mut validity_out, validity_rhs)| { + for (lhs_start, lhs_len) in offsets_lhs.offset_and_length_iter() { + let len: usize = lhs_len.min(width); + + for i in 0..len { + let l_idx = i + lhs_start; + let r_idx = i + rhs_start; + + let l_valid = unsafe { validity_out.get_unchecked(l_idx) }; + let r_valid = unsafe { validity_rhs.get_bit_unchecked(r_idx) }; + let is_valid = l_valid & r_valid; + + // Size and alignment of validity vec are based on LHS. + unsafe { validity_out.set_unchecked(l_idx, is_valid) }; + } + } + + validity_out.freeze() + }) + } + + let leaf_validity = combine_validities_list_to_list_broadcast_right( + offsets_lhs, + arr_lhs.validity(), + arr_rhs.validity(), + arr_lhs.len(), + width, + rhs_start, + ); + + let arr = + PrimitiveArray::::from_vec(out_vec).with_validity(leaf_validity); + + let (offsets, validities, _) = core::mem::take(&mut self.data_lhs); + assert_eq!(offsets.len(), 1); + + self.finish_offsets_and_validities(Box::new(arr), offsets, validities) + }, + (BinaryOpApplyType::ListToPrimitive, Broadcast::NoBroadcast) + if self.list_to_prim_lhs.is_none() => + { + let offsets_lhs = self.data_lhs.0.as_slice(); + + // Notes + // * Primitive indexing starts from 0 + // * Output is aligned to LHS array + + let n_values = arr_lhs.len(); + let mut out_vec = Vec::::with_capacity(n_values); + let out_ptr = out_vec.as_mut_ptr(); + + with_match_numeric_list_op!(&self.op, self.swapped, |$OP| { + for (i, l_range) in OffsetsBuffer::::leaf_ranges_iter(offsets_lhs).enumerate() + { + let r = unsafe { arr_rhs.value_unchecked(i) }; + for l_idx in l_range { + unsafe { + let l = arr_lhs.value_unchecked(l_idx); + let v = $OP(l, r); + out_ptr.add(l_idx).write(v); + } + } + } + }); + + unsafe { out_vec.set_len(n_values) } + + let leaf_validity = combine_validities_list_to_primitive_no_broadcast( + offsets_lhs, + arr_lhs.validity(), + arr_rhs.validity(), + arr_lhs.len(), + ); + + let arr = + PrimitiveArray::::from_vec(out_vec).with_validity(leaf_validity); + + let (offsets, validities, _) = core::mem::take(&mut self.data_lhs); + self.finish_offsets_and_validities(Box::new(arr), offsets, validities) + }, + // If we are dispatched here, it means that the LHS array is a unique allocation created + // after a unit-length list column was broadcasted, so this codepath mutably stores the + // results back into the LHS array to save memory. + (BinaryOpApplyType::ListToPrimitive, Broadcast::NoBroadcast) => { + let offsets_lhs = self.data_lhs.0.as_slice(); + + let (mut arr, n_values) = Option::take(&mut self.list_to_prim_lhs).unwrap(); + let arr = arr + .as_any_mut() + .downcast_mut::>() + .unwrap(); + let mut arr_lhs = core::mem::take(arr); + + self.op.prepare_numeric_op_side_validities::( + &mut arr_lhs, + &mut arr_rhs, + self.swapped, + ); + + let arr_lhs_mut_slice = arr_lhs.get_mut_values().unwrap(); + assert_eq!(arr_lhs_mut_slice.len(), n_values); + + with_match_numeric_list_op!(&self.op, self.swapped, |$OP| { + for (i, l_range) in OffsetsBuffer::::leaf_ranges_iter(offsets_lhs).enumerate() + { + let r = unsafe { arr_rhs.value_unchecked(i) }; + for l_idx in l_range { + unsafe { + let l = arr_lhs_mut_slice.get_unchecked_mut(l_idx); + *l = $OP(*l, r); + } + } + } + }); + + let leaf_validity = combine_validities_list_to_primitive_no_broadcast( + offsets_lhs, + arr_lhs.validity(), + arr_rhs.validity(), + arr_lhs.len(), + ); + + let arr = arr_lhs.with_validity(leaf_validity); + + let (offsets, validities, _) = core::mem::take(&mut self.data_lhs); + self.finish_offsets_and_validities(Box::new(arr), offsets, validities) + }, + (BinaryOpApplyType::ListToPrimitive, Broadcast::Right) => { + assert_eq!(arr_rhs.len(), 1); + + let Some(r) = (unsafe { arr_rhs.get_unchecked(0) }) else { + // RHS is single primitive NULL, create the result by setting the leaf validity to all-NULL. + let (offsets, validities, _) = core::mem::take(&mut self.data_lhs); + return self.finish_offsets_and_validities( + Box::new( + arr_lhs + .clone() + .with_validity(Some(Bitmap::new_with_value(false, arr_lhs.len()))), + ), + offsets, + validities, + ); + }; + + let arr = self.op.apply_array_to_scalar::(arr_lhs, r, self.swapped); + let (offsets, validities, _) = core::mem::take(&mut self.data_lhs); + + self.finish_offsets_and_validities(Box::new(arr), offsets, validities) + }, + v @ (BinaryOpApplyType::PrimitiveToList, Broadcast::Right) + | v @ (BinaryOpApplyType::ListToList, Broadcast::Left) + | v @ (BinaryOpApplyType::ListToPrimitive, Broadcast::Left) + | v @ (BinaryOpApplyType::PrimitiveToList, Broadcast::Left) + | v @ (BinaryOpApplyType::PrimitiveToList, Broadcast::NoBroadcast) => { + if cfg!(debug_assertions) { + panic!("operation was not re-written: {:?}", v) + } else { + unreachable!() + } + }, + }?; + + Ok(out) } - fn divide(lhs: &ListChunked, rhs: &Series) -> PolarsResult { - lhs.arithm_helper(rhs, &|l, r| l.divide(r), None) + + /// Construct the result `ListChunked` from the leaf array and the offsets/validities of every + /// level. + fn finish_offsets_and_validities( + &mut self, + leaf_array: Box, + offsets: Vec>, + validities: Vec>, + ) -> PolarsResult { + assert!(!offsets.is_empty()); + assert_eq!(offsets.len(), validities.len()); + let mut results = leaf_array; + + let mut iter = offsets.into_iter().zip(validities).rev(); + + while iter.len() > 1 { + let (offsets, validity) = iter.next().unwrap(); + let dtype = LargeListArray::default_datatype(results.dtype().clone()); + results = Box::new(LargeListArray::new(dtype, offsets, results, validity)); + } + + // The combined outer validity is pre-computed during `try_new()` + let (offsets, _) = iter.next().unwrap(); + let validity = core::mem::take(&mut self.outer_validity); + let dtype = LargeListArray::default_datatype(results.dtype().clone()); + let results = LargeListArray::new(dtype, offsets, results, Some(validity)); + + Ok(ListChunked::with_chunk( + core::mem::take(&mut self.output_name), + results, + )) } - fn remainder(lhs: &ListChunked, rhs: &Series) -> PolarsResult { - lhs.arithm_helper(rhs, &|l, r| l.remainder(r), None) + + fn materialize_broadcasted_list( + side_data: &mut (Vec>, Vec>, Series), + output_len: usize, + output_primitive_dtype: &DataType, + ) -> (Box, usize) { + let s = &side_data.2; + assert_eq!(s.len(), 1); + + let expected_n_values = { + let offsets = s.list_offsets_and_validities_recursive().0; + output_len * OffsetsBuffer::::leaf_full_start_end(&offsets).len() + }; + + let ca = s.list().unwrap(); + // Remember to cast the leaf primitives to the supertype. + let ca = ca + .cast(&ca.dtype().cast_leaf(output_primitive_dtype.clone())) + .unwrap(); + assert!(output_len > 1); // In case there is a fast-path that doesn't give us owned data. + let ca = ca.new_from_index(0, output_len).rechunk(); + + let s = ca.into_series(); + + *side_data = { + let (a, b) = s.list_offsets_and_validities_recursive(); + // `Series::default()`: This field in the tuple is no longer used. + (a, b, Series::default()) + }; + + let n_values = OffsetsBuffer::::leaf_full_start_end(&side_data.0).len(); + assert_eq!(n_values, expected_n_values); + + let mut s = s.get_leaf_array(); + let v = unsafe { s.chunks_mut() }; + + assert_eq!(v.len(), 1); + (v.swap_remove(0), n_values) + } +} + +/// Used in 2 places, so it's outside here. +#[inline(never)] +fn combine_validities_list_to_primitive_no_broadcast( + offsets_lhs: &[OffsetsBuffer], + validity_lhs: Option<&Bitmap>, + validity_rhs: Option<&Bitmap>, + len_lhs: usize, +) -> Option { + match (validity_lhs, validity_rhs) { + (Some(l), Some(r)) => Some((l.clone().make_mut(), r)), + (Some(v), None) => return Some(v.clone()), + // Materialize a full-true validity to re-use the codepath, as we still + // need to spread the bits from the RHS to the correct positions. + (None, Some(v)) => Some((Bitmap::new_with_value(true, len_lhs).make_mut(), v)), + (None, None) => None, } + .map(|(mut validity_out, validity_rhs)| { + for (i, l_range) in OffsetsBuffer::::leaf_ranges_iter(offsets_lhs).enumerate() { + let r_valid = unsafe { validity_rhs.get_bit_unchecked(i) }; + for l_idx in l_range { + let l_valid = unsafe { validity_out.get_unchecked(l_idx) }; + let is_valid = l_valid & r_valid; + + // Size and alignment of validity vec are based on LHS. + unsafe { validity_out.set_unchecked(l_idx, is_valid) }; + } + } + + validity_out.freeze() + }) } diff --git a/crates/polars-core/src/series/arithmetic/mod.rs b/crates/polars-core/src/series/arithmetic/mod.rs index d7d7dbdb8a0e..0a5550b7b0f3 100644 --- a/crates/polars-core/src/series/arithmetic/mod.rs +++ b/crates/polars-core/src/series/arithmetic/mod.rs @@ -6,6 +6,7 @@ use std::borrow::Cow; use std::ops::{Add, Div, Mul, Rem, Sub}; pub use borrowed::*; +pub use list_borrowed::NumericListOp; use num_traits::{Num, NumCast}; use crate::prelude::*; diff --git a/crates/polars-core/src/series/ops/reshape.rs b/crates/polars-core/src/series/ops/reshape.rs index c0a57dbf3db7..642faafbfdf8 100644 --- a/crates/polars-core/src/series/ops/reshape.rs +++ b/crates/polars-core/src/series/ops/reshape.rs @@ -1,8 +1,9 @@ use std::borrow::Cow; use arrow::array::*; +use arrow::bitmap::Bitmap; use arrow::legacy::kernels::list::array_to_unit_list; -use arrow::offset::Offsets; +use arrow::offset::{Offsets, OffsetsBuffer}; use polars_error::{polars_bail, polars_ensure, PolarsResult}; use polars_utils::format_tuple; @@ -56,6 +57,25 @@ impl Series { } } + /// TODO: Move this somewhere else? + pub fn list_offsets_and_validities_recursive( + &self, + ) -> (Vec>, Vec>) { + let mut offsets = vec![]; + let mut validities = vec![]; + + let mut s = self.rechunk(); + + while let DataType::List(_) = s.dtype() { + let ca = s.list().unwrap(); + offsets.push(ca.offsets().unwrap()); + validities.push(ca.rechunk_validity()); + s = ca.get_inner(); + } + + (offsets, validities) + } + /// Convert the values of this Series to a ListChunked with a length of 1, /// so a Series of `[1, 2, 3]` becomes `[[1, 2, 3]]`. pub fn implode(&self) -> PolarsResult { diff --git a/crates/polars-core/src/series/series_trait.rs b/crates/polars-core/src/series/series_trait.rs index 2dc8de00dcd7..1ee69300fa92 100644 --- a/crates/polars-core/src/series/series_trait.rs +++ b/crates/polars-core/src/series/series_trait.rs @@ -2,6 +2,7 @@ use std::any::Any; use std::borrow::Cow; use std::sync::RwLockReadGuard; +use arrow::bitmap::{Bitmap, MutableBitmap}; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; @@ -327,6 +328,26 @@ pub trait SeriesTrait: /// Aggregate all chunks to a contiguous array of memory. fn rechunk(&self) -> Series; + fn rechunk_validity(&self) -> Option { + if self.chunks().len() == 1 { + return self.chunks()[0].validity().cloned(); + } + + if !self.has_nulls() || self.is_empty() { + return None; + } + + let mut bm = MutableBitmap::with_capacity(self.len()); + for arr in self.chunks() { + if let Some(v) = arr.validity() { + bm.extend_from_bitmap(v); + } else { + bm.extend_constant(arr.len(), true); + } + } + Some(bm.into()) + } + /// Drop all null values and return a new Series. fn drop_nulls(&self) -> Series { if self.null_count() == 0 { diff --git a/crates/polars-expr/src/expressions/binary.rs b/crates/polars-expr/src/expressions/binary.rs index d0b00bf2ddac..c1cb286e7104 100644 --- a/crates/polars-expr/src/expressions/binary.rs +++ b/crates/polars-expr/src/expressions/binary.rs @@ -75,11 +75,8 @@ pub fn apply_operator(left: &Series, right: &Series, op: Operator) -> PolarsResu let right_dt = right.dtype().cast_leaf(Float64); left.cast(&left_dt)? / right.cast(&right_dt)? }, - dt @ List(_) => { - let left_dt = dt.cast_leaf(Float64); - let right_dt = right.dtype().cast_leaf(Float64); - left.cast(&left_dt)? / right.cast(&right_dt)? - }, + List(_) => left / right, + _ if right.dtype().is_list() => left / right, _ => { if right.dtype().is_temporal() { return left / right; diff --git a/crates/polars-ops/src/series/ops/floor_divide.rs b/crates/polars-ops/src/series/ops/floor_divide.rs index 4c5075ecad42..b8aa78c4ec01 100644 --- a/crates/polars-ops/src/series/ops/floor_divide.rs +++ b/crates/polars-ops/src/series/ops/floor_divide.rs @@ -1,6 +1,7 @@ use polars_compute::arithmetic::ArithmeticKernel; use polars_core::chunked_array::ops::arity::apply_binary_kernel_broadcast; use polars_core::prelude::*; +use polars_core::series::arithmetic::NumericListOp; #[cfg(feature = "dtype-struct")] use polars_core::series::arithmetic::_struct_arithmetic; use polars_core::with_match_physical_numeric_polars_type; @@ -24,6 +25,9 @@ pub fn floor_div_series(a: &Series, b: &Series) -> PolarsResult { (DataType::Struct(_), DataType::Struct(_)) => { return _struct_arithmetic(a, b, floor_div_series); }, + (DataType::List(_), _) | (_, DataType::List(_)) => { + return NumericListOp::FloorDiv.execute(a, b); + }, _ => {}, } diff --git a/crates/polars-plan/src/plans/aexpr/schema.rs b/crates/polars-plan/src/plans/aexpr/schema.rs index d74fd054a644..a290321f4cf8 100644 --- a/crates/polars-plan/src/plans/aexpr/schema.rs +++ b/crates/polars-plan/src/plans/aexpr/schema.rs @@ -370,6 +370,28 @@ fn get_arithmetic_field( (_, Time) | (Time, _) => { polars_bail!(InvalidOperation: "{} not allowed on {} and {}", op, left_field.dtype, right_type) }, + (l @ List(a), r @ List(b)) + if ![a, b] + .into_iter() + .all(|x| x.is_numeric() || x.is_bool() || x.is_null()) => + { + polars_bail!( + InvalidOperation: + "cannot {} two list columns with non-numeric inner types: (left: {}, right: {})", + "sub", l, r, + ) + }, + (list_dtype @ List(_), other_dtype) | (other_dtype, list_dtype @ List(_)) => { + // FIXME: This should not use `try_get_supertype()`! It should instead recursively use the enclosing match block. + // Otherwise we will silently permit addition operations between logical types (see above). + // This currently doesn't cause any problems because the list arithmetic implementation checks and raises errors + // if the leaf types aren't numeric, but it means we don't raise an error until execution and the DSL schema + // may be incorrect. + list_dtype.cast_leaf(try_get_supertype( + list_dtype.leaf_dtype(), + other_dtype.leaf_dtype(), + )?) + }, (left, right) => try_get_supertype(left, right)?, } }, @@ -395,6 +417,23 @@ fn get_arithmetic_field( polars_bail!(InvalidOperation: "{} not allowed on {} and {}", op, left_field.dtype, right_type) }, (Boolean, Boolean) => IDX_DTYPE, + (l @ List(a), r @ List(b)) + if ![a, b] + .into_iter() + .all(|x| x.is_numeric() || x.is_bool() || x.is_null()) => + { + polars_bail!( + InvalidOperation: + "cannot {} two list columns with non-numeric inner types: (left: {}, right: {})", + "add", l, r, + ) + }, + (list_dtype @ List(_), other_dtype) | (other_dtype, list_dtype @ List(_)) => { + list_dtype.cast_leaf(try_get_supertype( + list_dtype.leaf_dtype(), + other_dtype.leaf_dtype(), + )?) + }, (left, right) => try_get_supertype(left, right)?, } }, @@ -427,6 +466,27 @@ fn get_arithmetic_field( polars_bail!(InvalidOperation: "{} not allowed on {} and {}", op, left_field.dtype, right_type) }, }, + (l @ List(a), r @ List(b)) + if ![a, b] + .into_iter() + .all(|x| x.is_numeric() || x.is_bool() || x.is_null()) => + { + polars_bail!( + InvalidOperation: + "cannot {} two list columns with non-numeric inner types: (left: {}, right: {})", + op, l, r, + ) + }, + // List<->primitive operations can be done directly after casting the to the primitive + // supertype for the primitive values on both sides. + (list_dtype @ List(_), other_dtype) | (other_dtype, list_dtype @ List(_)) => { + let dtype = list_dtype.cast_leaf(try_get_supertype( + list_dtype.leaf_dtype(), + other_dtype.leaf_dtype(), + )?); + left_field.coerce(dtype); + return Ok(left_field); + }, _ => { // Avoid needlessly type casting numeric columns during arithmetic // with literals. @@ -466,32 +526,51 @@ fn get_truediv_field( nested: &mut u8, ) -> PolarsResult { let mut left_field = arena.get(left).to_field_impl(schema, arena, nested)?; + let right_field = arena.get(right).to_field_impl(schema, arena, nested)?; use DataType::*; - let out_type = match left_field.dtype() { - Float32 => Float32, - dt if dt.is_numeric() => Float64, - #[cfg(feature = "dtype-duration")] - Duration(_) => match arena - .get(right) - .to_field_impl(schema, arena, nested)? - .dtype() + + // TODO: Re-investigate this. A lot of "_" is being used on the RHS match because this code + // originally (mostly) only looked at the LHS dtype. + let out_type = match (left_field.dtype(), right_field.dtype()) { + (l @ List(a), r @ List(b)) + if ![a, b] + .into_iter() + .all(|x| x.is_numeric() || x.is_bool() || x.is_null()) => { - Duration(_) => Float64, - dt if dt.is_numeric() => return Ok(left_field), - dt => { - polars_bail!(InvalidOperation: "true division of {} with {} is not allowed", left_field.dtype(), dt) - }, + polars_bail!( + InvalidOperation: + "cannot {} two list columns with non-numeric inner types: (left: {}, right: {})", + "div", l, r, + ) + }, + (list_dtype @ List(_), other_dtype) | (other_dtype, list_dtype @ List(_)) => { + list_dtype.cast_leaf(match (list_dtype.leaf_dtype(), other_dtype.leaf_dtype()) { + (Float32, Float32) => Float32, + (Float32, Float64) | (Float64, Float32) => Float64, + // FIXME: We should properly recurse on the enclosing match block here. + (dt, _) => dt.clone(), + }) + }, + (Float32, _) => Float32, + (dt, _) if dt.is_numeric() => Float64, + #[cfg(feature = "dtype-duration")] + (Duration(_), Duration(_)) => Float64, + #[cfg(feature = "dtype-duration")] + (Duration(_), dt) if dt.is_numeric() => return Ok(left_field), + #[cfg(feature = "dtype-duration")] + (Duration(_), dt) => { + polars_bail!(InvalidOperation: "true division of {} with {} is not allowed", left_field.dtype(), dt) }, #[cfg(feature = "dtype-datetime")] - Datetime(_, _) => { + (Datetime(_, _), _) => { polars_bail!(InvalidOperation: "division of 'Datetime' datatype is not allowed") }, #[cfg(feature = "dtype-time")] - Time => polars_bail!(InvalidOperation: "division of 'Time' datatype is not allowed"), + (Time, _) => polars_bail!(InvalidOperation: "division of 'Time' datatype is not allowed"), #[cfg(feature = "dtype-date")] - Date => polars_bail!(InvalidOperation: "division of 'Date' datatype is not allowed"), + (Date, _) => polars_bail!(InvalidOperation: "division of 'Date' datatype is not allowed"), // we don't know what to do here, best return the dtype - dt => dt.clone(), + (dt, _) => dt.clone(), }; left_field.coerce(out_type); diff --git a/crates/polars-plan/src/plans/conversion/type_coercion/binary.rs b/crates/polars-plan/src/plans/conversion/type_coercion/binary.rs index 37d58e004ab1..24f65b3465f5 100644 --- a/crates/polars-plan/src/plans/conversion/type_coercion/binary.rs +++ b/crates/polars-plan/src/plans/conversion/type_coercion/binary.rs @@ -47,55 +47,6 @@ fn is_cat_str_binary(type_left: &DataType, type_right: &DataType) -> bool { } } -fn process_list_arithmetic( - type_left: DataType, - type_right: DataType, - node_left: Node, - node_right: Node, - op: Operator, - expr_arena: &mut Arena, -) -> PolarsResult> { - match (&type_left, &type_right) { - (DataType::List(_), _) => { - let leaf = type_left.leaf_dtype(); - if type_right != *leaf { - let new_node_right = expr_arena.add(AExpr::Cast { - expr: node_right, - dtype: type_left.cast_leaf(leaf.clone()), - options: CastOptions::NonStrict, - }); - - Ok(Some(AExpr::BinaryExpr { - left: node_left, - op, - right: new_node_right, - })) - } else { - Ok(None) - } - }, - (_, DataType::List(_)) => { - let leaf = type_right.leaf_dtype(); - if type_left != *leaf { - let new_node_left = expr_arena.add(AExpr::Cast { - expr: node_left, - dtype: type_right.cast_leaf(leaf.clone()), - options: CastOptions::NonStrict, - }); - - Ok(Some(AExpr::BinaryExpr { - left: new_node_left, - op, - right: node_right, - })) - } else { - Ok(None) - } - }, - _ => unreachable!(), - } -} - #[cfg(feature = "dtype-struct")] // Ensure we don't cast to supertype // otherwise we will fill a struct with null fields @@ -265,11 +216,6 @@ pub(super) fn process_binary( (String, a) | (a, String) if a.is_numeric() => { polars_bail!(InvalidOperation: "arithmetic on string and numeric not allowed, try an explicit cast first") }, - (List(_), _) | (_, List(_)) => { - return process_list_arithmetic( - type_left, type_right, node_left, node_right, op, expr_arena, - ) - }, (Datetime(_, _), _) | (_, Datetime(_, _)) | (Date, _) @@ -277,7 +223,9 @@ pub(super) fn process_binary( | (Duration(_), _) | (_, Duration(_)) | (Time, _) - | (_, Time) => return Ok(None), + | (_, Time) + | (List(_), _) + | (_, List(_)) => return Ok(None), #[cfg(feature = "dtype-struct")] (Struct(_), a) | (a, Struct(_)) if a.is_numeric() => { return process_struct_numeric_arithmetic( diff --git a/py-polars/polars/dataframe/frame.py b/py-polars/polars/dataframe/frame.py index 4bc24262b7c0..ae32d3b454e4 100644 --- a/py-polars/polars/dataframe/frame.py +++ b/py-polars/polars/dataframe/frame.py @@ -73,6 +73,7 @@ Float64, Int32, Int64, + Null, Object, String, Struct, @@ -1073,6 +1074,7 @@ def _div(self, other: Any, *, floordiv: bool) -> DataFrame: other = DataFrame([s.alias(f"n{i}") for i in range(len(self.columns))]) orig_dtypes = other.dtypes + # TODO: Dispatch to a native floordiv other = self._cast_all_from_to(other, INTEGER_DTYPES, Float64) df = self._from_pydf(self._df.div_df(other._df)) @@ -1085,7 +1087,8 @@ def _div(self, other: Any, *, floordiv: bool) -> DataFrame: int_casts = [ col(column).cast(tp) for i, (column, tp) in enumerate(self.schema.items()) - if tp.is_integer() and orig_dtypes[i].is_integer() + if tp.is_integer() + and (orig_dtypes[i].is_integer() or orig_dtypes[i] == Null) ] if int_casts: return df.with_columns(int_casts) diff --git a/py-polars/polars/series/series.py b/py-polars/polars/series/series.py index 6eeadb5dc080..ea37a64aa778 100644 --- a/py-polars/polars/series/series.py +++ b/py-polars/polars/series/series.py @@ -1084,13 +1084,18 @@ def __truediv__(self, other: Any) -> Series | Expr: msg = "first cast to integer before dividing datelike dtypes" raise TypeError(msg) - # this branch is exactly the floordiv function without rounding the floats - if self.dtype.is_float() or self.dtype == Decimal: - as_float = self - else: - as_float = self._recursive_cast_to_dtype(Float64()) + self = ( + self._recursive_cast_to_dtype(Float64()) + if not ( + self.dtype.is_float() + or self.dtype.is_decimal() + or isinstance(self.dtype, List) + or (isinstance(other, Series) and isinstance(other.dtype, List)) + ) + else self + ) - return as_float._arithmetic(other, "div", "div_<>") + return self._arithmetic(other, "div", "div_<>") @overload def __floordiv__(self, other: Expr) -> Expr: ... diff --git a/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py b/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py index a4c8e7786f1b..b04cb92b8889 100644 --- a/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py +++ b/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py @@ -21,7 +21,7 @@ UInt32, UInt64, ) -from polars.exceptions import ColumnNotFoundError, InvalidOperationError, SchemaError +from polars.exceptions import ColumnNotFoundError, InvalidOperationError, ShapeError from polars.testing import assert_frame_equal, assert_series_equal from tests.unit.conftest import INTEGER_DTYPES, NUMERIC_DTYPES @@ -284,8 +284,8 @@ def test_operator_arithmetic_with_nulls(op: Any, dtype: pl.DataType) -> None: df_expected, df.select(getattr(pl.col("n"), op_name)(null_expr)) ) - assert_frame_equal(df_expected, op(df, None)) - assert_series_equal(s_expected, op(s, None)) + assert_frame_equal(op(df, None), df_expected) + assert_series_equal(op(s, None), s_expected) @pytest.mark.parametrize( @@ -598,7 +598,6 @@ def test_array_arithmetic_same_size( pl.Series("nested", np.array([[[1, 2]], [[3, 4]]], dtype=np.int64)), ] ) - print(df.select(expr(pl.col(column_names[0]), pl.col(column_names[1])))) # Expr-based arithmetic: assert_frame_equal( df.select(expr(pl.col(column_names[0]), pl.col(column_names[1]))), @@ -624,16 +623,6 @@ def test_array_arithmetic_same_size( lambda a, b: a + b, ("a", "uint8"), ), - ( - [[[2, 4]], [[6]]], - lambda a, b: a + b, - ("nested", "nested"), - ), - ( - [[[2, 4]], [[6]]], - lambda a, b: a + b, - ("nested", "nested_uint8"), - ), ], ) def test_list_arithmetic_same_size( @@ -668,7 +657,6 @@ def test_list_arithmetic_same_size( [ ([[1, 2, 3]], [[1, None, 5]], [[2, None, 8]]), ([[2], None, [5]], [None, [3], [2]], [None, None, [7]]), - ([[[2]], [None], [[4]]], [[[3]], [[6]], [[8]]], [[[5]], [None], [[12]]]), ], ) def test_list_arithmetic_nulls(a: list[Any], b: list[Any], expected: list[Any]) -> None: @@ -689,29 +677,29 @@ def test_list_arithmetic_nulls(a: list[Any], b: list[Any], expected: list[Any]) def test_list_arithmetic_error_cases() -> None: # Different series length: - with pytest.raises( - InvalidOperationError, match="Series of the same size; got 1 and 2" - ): - _ = pl.Series("a", [[1, 2]]) / pl.Series("b", [[1, 2], [3, 4]]) - with pytest.raises( - InvalidOperationError, match="Series of the same size; got 1 and 2" - ): - _ = pl.Series("a", [[1, 2]]) / pl.Series("b", [[1, 2], None]) + with pytest.raises(InvalidOperationError, match="different lengths"): + _ = pl.Series("a", [[1, 2], [1, 2], [1, 2]]) / pl.Series("b", [[1, 2], [3, 4]]) + with pytest.raises(InvalidOperationError, match="different lengths"): + _ = pl.Series("a", [[1, 2], [1, 2], [1, 2]]) / pl.Series("b", [[1, 2], None]) # Different list length: - with pytest.raises(InvalidOperationError, match="lists of the same size"): - _ = pl.Series("a", [[1, 2]]) / pl.Series("b", [[1]]) - with pytest.raises( - InvalidOperationError, match="lists of the same size; got 2 and 1" - ): + with pytest.raises(ShapeError, match="lengths differed at index 0: 2 != 1"): + _ = pl.Series("a", [[1, 2], [1, 2], [1, 2]]) / pl.Series("b", [[1]]) + + with pytest.raises(ShapeError, match="lengths differed at index 0: 2 != 1"): _ = pl.Series("a", [[1, 2], [2, 3]]) / pl.Series("b", [[1], None]) # Wrong types: - with pytest.raises(InvalidOperationError, match="cannot cast List type"): + with pytest.raises( + InvalidOperationError, match="add operation not supported for dtypes" + ): _ = pl.Series("a", [[1, 2]]) + pl.Series("b", ["hello"]) # Different nesting: - with pytest.raises(SchemaError, match="failed to determine supertype"): + with pytest.raises( + InvalidOperationError, + match="cannot add two list columns with non-numeric inner types", + ): _ = pl.Series("a", [[1]]) + pl.Series("b", [[[1]]]) @@ -891,7 +879,7 @@ def test_date_datetime_sub() -> None: def test_raise_invalid_shape() -> None: - with pytest.raises(pl.exceptions.InvalidOperationError): + with pytest.raises(InvalidOperationError): pl.DataFrame([[1, 2], [3, 4]]) * pl.DataFrame([1, 2, 3]) diff --git a/py-polars/tests/unit/operations/arithmetic/test_list_arithmetic.py b/py-polars/tests/unit/operations/arithmetic/test_list_arithmetic.py new file mode 100644 index 000000000000..c2a9a9186d31 --- /dev/null +++ b/py-polars/tests/unit/operations/arithmetic/test_list_arithmetic.py @@ -0,0 +1,530 @@ +from __future__ import annotations + +import operator +from typing import Any, Callable + +import pytest + +import polars as pl +from polars.exceptions import InvalidOperationError, ShapeError +from polars.testing import assert_series_equal + + +def exec_op_with_series(lhs: pl.Series, rhs: pl.Series, op: Any) -> pl.Series: + v: pl.Series = op(lhs, rhs) + return v + + +def build_expr_op_exec( + type_coercion: bool, +) -> Callable[[pl.Series, pl.Series, Any], pl.Series]: + def func(lhs: pl.Series, rhs: pl.Series, op: Any) -> pl.Series: + return ( + pl.select(lhs) + .lazy() + .select(op(pl.first(), rhs)) + .collect(type_coercion=type_coercion) + .to_series() + ) + + return func + + +def build_series_broadcaster( + side: str, +) -> Callable[ + [pl.Series, pl.Series, pl.Series], tuple[pl.Series, pl.Series, pl.Series] +]: + length = 3 + + if side == "left": + + def func( + l: pl.Series, # noqa: E741 + r: pl.Series, + o: pl.Series, + ) -> tuple[pl.Series, pl.Series, pl.Series]: + return l.new_from_index(0, length), r, o.new_from_index(0, length) + elif side == "right": + + def func( + l: pl.Series, # noqa: E741 + r: pl.Series, + o: pl.Series, + ) -> tuple[pl.Series, pl.Series, pl.Series]: + return l, r.new_from_index(0, length), o.new_from_index(0, length) + elif side == "both": + + def func( + l: pl.Series, # noqa: E741 + r: pl.Series, + o: pl.Series, + ) -> tuple[pl.Series, pl.Series, pl.Series]: + return ( + l.new_from_index(0, length), + r.new_from_index(0, length), + o.new_from_index(0, length), + ) + elif side == "none": + + def func( + l: pl.Series, # noqa: E741 + r: pl.Series, + o: pl.Series, + ) -> tuple[pl.Series, pl.Series, pl.Series]: + return l, r, o + else: + raise ValueError(side) + + return func + + +BROADCAST_SERIES_COMBINATIONS = [ + build_series_broadcaster("left"), + build_series_broadcaster("right"), + build_series_broadcaster("both"), + build_series_broadcaster("none"), +] + +EXEC_OP_COMBINATIONS = [ + exec_op_with_series, + build_expr_op_exec(True), + build_expr_op_exec(False), +] + + +@pytest.mark.parametrize( + "list_side", ["left", "left3", "both", "right3", "right", "none"] +) +@pytest.mark.parametrize( + "broadcast_series", + BROADCAST_SERIES_COMBINATIONS, +) +@pytest.mark.parametrize("exec_op", EXEC_OP_COMBINATIONS) +def test_list_arithmetic_values( + list_side: str, + broadcast_series: Callable[ + [pl.Series, pl.Series, pl.Series], tuple[pl.Series, pl.Series, pl.Series] + ], + exec_op: Callable[[pl.Series, pl.Series, Any], pl.Series], +) -> None: + """ + Tests value correctness. + + This test checks for output value correctness (a + b == c) across different + codepaths, by wrapping the values (a, b, c) in different combinations of + list / primitive columns. + """ + import operator as op + + dtypes: list[Any] = [pl.Null, pl.Null, pl.Null] + dtype: Any = pl.Null + + def materialize_list(v: Any) -> pl.Series: + return pl.Series( + [[None, v, None]], + dtype=pl.List(dtype), + ) + + def materialize_list3(v: Any) -> pl.Series: + return pl.Series( + [[[[None, v], None], None]], + dtype=pl.List(pl.List(pl.List(dtype))), + ) + + def materialize_primitive(v: Any) -> pl.Series: + return pl.Series([v], dtype=dtype) + + def materialize_series( + l: Any, # noqa: E741 + r: Any, + o: Any, + ) -> tuple[pl.Series, pl.Series, pl.Series]: + nonlocal dtype + + dtype = dtypes[0] + l = { # noqa: E741 + "left": materialize_list, + "left3": materialize_list3, + "both": materialize_list, + "right": materialize_primitive, + "right3": materialize_primitive, + "none": materialize_primitive, + }[list_side](l) # fmt: skip + + dtype = dtypes[1] + r = { + "left": materialize_primitive, + "left3": materialize_primitive, + "both": materialize_list, + "right": materialize_list, + "right3": materialize_list3, + "none": materialize_primitive, + }[list_side](r) # fmt: skip + + dtype = dtypes[2] + o = { + "left": materialize_list, + "left3": materialize_list3, + "both": materialize_list, + "right": materialize_list, + "right3": materialize_list3, + "none": materialize_primitive, + }[list_side](o) # fmt: skip + + assert l.len() == 1 + assert r.len() == 1 + assert o.len() == 1 + + return broadcast_series(l, r, o) + + # Signed + dtypes = [pl.Int8, pl.Int8, pl.Int8] + + l, r, o = materialize_series(2, 3, 5) # noqa: E741 + assert_series_equal(exec_op(l, r, op.add), o) + + l, r, o = materialize_series(-5, 127, 124) # noqa: E741 + assert_series_equal(exec_op(l, r, op.sub), o) + + l, r, o = materialize_series(-5, 127, -123) # noqa: E741 + assert_series_equal(exec_op(l, r, op.mul), o) + + l, r, o = materialize_series(-5, 3, -2) # noqa: E741 + assert_series_equal(exec_op(l, r, op.floordiv), o) + + l, r, o = materialize_series(-5, 3, 1) # noqa: E741 + assert_series_equal(exec_op(l, r, op.mod), o) + + dtypes = [pl.UInt8, pl.UInt8, pl.Float64] + l, r, o = materialize_series(2, 128, 0.015625) # noqa: E741 + assert_series_equal(exec_op(l, r, op.truediv), o) + + # Unsigned + dtypes = [pl.UInt8, pl.UInt8, pl.UInt8] + + l, r, o = materialize_series(2, 3, 5) # noqa: E741 + assert_series_equal(exec_op(l, r, op.add), o) + + l, r, o = materialize_series(2, 3, 255) # noqa: E741 + assert_series_equal(exec_op(l, r, op.sub), o) + + l, r, o = materialize_series(2, 128, 0) # noqa: E741 + assert_series_equal(exec_op(l, r, op.mul), o) + + l, r, o = materialize_series(5, 2, 2) # noqa: E741 + assert_series_equal(exec_op(l, r, op.floordiv), o) + + l, r, o = materialize_series(5, 2, 1) # noqa: E741 + assert_series_equal(exec_op(l, r, op.mod), o) + + dtypes = [pl.UInt8, pl.UInt8, pl.Float64] + l, r, o = materialize_series(2, 128, 0.015625) # noqa: E741 + assert_series_equal(exec_op(l, r, op.truediv), o) + + # Floats. Note we pick Float32 to ensure there is no accidental upcasting + # to Float64. + dtypes = [pl.Float32, pl.Float32, pl.Float32] + l, r, o = materialize_series(1.7, 2.3, 4.0) # noqa: E741 + assert_series_equal(exec_op(l, r, op.add), o) + + l, r, o = materialize_series(1.7, 2.3, -0.5999999999999999) # noqa: E741 + assert_series_equal(exec_op(l, r, op.sub), o) + + l, r, o = materialize_series(1.7, 2.3, 3.9099999999999997) # noqa: E741 + assert_series_equal(exec_op(l, r, op.mul), o) + + l, r, o = materialize_series(7.0, 3.0, 2.0) # noqa: E741 + assert_series_equal(exec_op(l, r, op.floordiv), o) + + l, r, o = materialize_series(-5.0, 3.0, 1.0) # noqa: E741 + assert_series_equal(exec_op(l, r, op.mod), o) + + l, r, o = materialize_series(2.0, 128.0, 0.015625) # noqa: E741 + assert_series_equal(exec_op(l, r, op.truediv), o) + + # + # Tests for zero behavior + # + + # Integer + + dtypes = [pl.UInt8, pl.UInt8, pl.UInt8] + + l, r, o = materialize_series(1, 0, None) # noqa: E741 + assert_series_equal(exec_op(l, r, op.floordiv), o) + assert_series_equal(exec_op(l, r, op.mod), o) + + l, r, o = materialize_series(0, 0, None) # noqa: E741 + assert_series_equal(exec_op(l, r, op.floordiv), o) + assert_series_equal(exec_op(l, r, op.mod), o) + + dtypes = [pl.UInt8, pl.UInt8, pl.Float64] + + l, r, o = materialize_series(1, 0, float("inf")) # noqa: E741 + assert_series_equal(exec_op(l, r, op.truediv), o) + + l, r, o = materialize_series(0, 0, float("nan")) # noqa: E741 + assert_series_equal(exec_op(l, r, op.truediv), o) + + # Float + + dtypes = [pl.Float32, pl.Float32, pl.Float32] + + l, r, o = materialize_series(1, 0, float("inf")) # noqa: E741 + assert_series_equal(exec_op(l, r, op.floordiv), o) + + l, r, o = materialize_series(1, 0, float("nan")) # noqa: E741 + assert_series_equal(exec_op(l, r, op.mod), o) + + l, r, o = materialize_series(1, 0, float("inf")) # noqa: E741 + assert_series_equal(exec_op(l, r, op.truediv), o) + + l, r, o = materialize_series(0, 0, float("nan")) # noqa: E741 + assert_series_equal(exec_op(l, r, op.floordiv), o) + + l, r, o = materialize_series(0, 0, float("nan")) # noqa: E741 + assert_series_equal(exec_op(l, r, op.mod), o) + + l, r, o = materialize_series(0, 0, float("nan")) # noqa: E741 + assert_series_equal(exec_op(l, r, op.truediv), o) + + # + # Tests for NULL behavior + # + + for dtype, truediv_dtype in [ # type: ignore[misc] + [pl.Int8, pl.Float64], + [pl.Float32, pl.Float32], + ]: + for vals in [ + [None, None, None], + [0, None, None], + [None, 0, None], + [0, None, None], + [None, 0, None], + [3, None, None], + [None, 3, None], + ]: + dtypes = 3 * [dtype] + + l, r, o = materialize_series(*vals) # type: ignore[misc] # noqa: E741 + assert_series_equal(exec_op(l, r, op.add), o) + assert_series_equal(exec_op(l, r, op.sub), o) + assert_series_equal(exec_op(l, r, op.mul), o) + assert_series_equal(exec_op(l, r, op.floordiv), o) + assert_series_equal(exec_op(l, r, op.mod), o) + dtypes[2] = truediv_dtype # type: ignore[has-type] + l, r, o = materialize_series(*vals) # type: ignore[misc] # noqa: E741 + assert_series_equal(exec_op(l, r, op.truediv), o) + + # Type upcasting for Boolean and Null + + # Check boolean upcasting + dtypes = [pl.Boolean, pl.UInt8, pl.UInt8] + + l, r, o = materialize_series(True, 3, 4) # noqa: E741 + assert_series_equal(exec_op(l, r, op.add), o) + + l, r, o = materialize_series(True, 3, 254) # noqa: E741 + assert_series_equal(exec_op(l, r, op.sub), o) + + l, r, o = materialize_series(True, 3, 3) # noqa: E741 + assert_series_equal(exec_op(l, r, op.mul), o) + + l, r, o = materialize_series(True, 3, 0) # noqa: E741 + if list_side != "none": + # TODO: FIXME: We get an error on non-lists with this: + # "floor_div operation not supported for dtype `bool`" + assert_series_equal(exec_op(l, r, op.floordiv), o) + + l, r, o = materialize_series(True, 3, 1) # noqa: E741 + assert_series_equal(exec_op(l, r, op.mod), o) + + dtypes = [pl.Boolean, pl.UInt8, pl.Float64] + l, r, o = materialize_series(True, 128, 0.0078125) # noqa: E741 + assert_series_equal(exec_op(l, r, op.truediv), o) + + # Check Null upcasting + dtypes = [pl.Null, pl.UInt8, pl.UInt8] + l, r, o = materialize_series(None, 3, None) # noqa: E741 + assert_series_equal(exec_op(l, r, op.add), o) + assert_series_equal(exec_op(l, r, op.sub), o) + assert_series_equal(exec_op(l, r, op.mul), o) + if list_side != "none": + assert_series_equal(exec_op(l, r, op.floordiv), o) + assert_series_equal(exec_op(l, r, op.mod), o) + + dtypes = [pl.Null, pl.UInt8, pl.Float64] + l, r, o = materialize_series(None, 3, None) # noqa: E741 + assert_series_equal(exec_op(l, r, op.truediv), o) + + +@pytest.mark.parametrize("exec_op", EXEC_OP_COMBINATIONS) +def test_list_add_supertype( + exec_op: Callable[[pl.Series, pl.Series, Any], pl.Series], +) -> None: + import operator as op + + a = pl.Series("a", [[1], [2]], dtype=pl.List(pl.Int8)) + b = pl.Series("b", [[1], [999]], dtype=pl.List(pl.Int64)) + + assert_series_equal( + exec_op(a, b, op.add), + pl.Series("a", [[2], [1001]], dtype=pl.List(pl.Int64)), + ) + + +@pytest.mark.parametrize("exec_op", EXEC_OP_COMBINATIONS) +@pytest.mark.parametrize( + "broadcast_series", + BROADCAST_SERIES_COMBINATIONS, +) +def test_list_numeric_op_validity_combination( + broadcast_series: Callable[ + [pl.Series, pl.Series, pl.Series], tuple[pl.Series, pl.Series, pl.Series] + ], + exec_op: Callable[[pl.Series, pl.Series, Any], pl.Series], +) -> None: + import operator as op + + a = pl.Series("a", [[1], [2], None, [None], [11], [1111]], dtype=pl.List(pl.Int32)) + b = pl.Series("b", [[1], [3], [11], [1111], None, [None]], dtype=pl.List(pl.Int64)) + # expected result + e = pl.Series("a", [[2], [5], None, [None], None, [None]], dtype=pl.List(pl.Int64)) + + assert_series_equal( + exec_op(a, b, op.add), + e, + ) + + a = pl.Series("a", [[1]], dtype=pl.List(pl.Int32)) + b = pl.Series("b", [None], dtype=pl.Int64) + e = pl.Series("a", [[None]], dtype=pl.List(pl.Int64)) + + a, b, e = broadcast_series(a, b, e) + assert_series_equal(exec_op(a, b, op.add), e) + + a = pl.Series("a", [None], dtype=pl.List(pl.Int32)) + b = pl.Series("b", [1], dtype=pl.Int64) + e = pl.Series("a", [None], dtype=pl.List(pl.Int64)) + + a, b, e = broadcast_series(a, b, e) + assert_series_equal(exec_op(a, b, op.add), e) + + a = pl.Series("a", [None], dtype=pl.List(pl.Int32)) + b = pl.Series("b", [0], dtype=pl.Int64) + e = pl.Series("a", [None], dtype=pl.List(pl.Int64)) + + a, b, e = broadcast_series(a, b, e) + assert_series_equal(exec_op(a, b, op.floordiv), e) + + +def test_list_add_alignment() -> None: + a = pl.Series("a", [[1, 1], [1, 1, 1]]) + b = pl.Series("b", [[1, 1, 1], [1, 1]]) + + df = pl.DataFrame([a, b]) + + with pytest.raises(ShapeError): + df.select(x=pl.col("a") + pl.col("b")) + + # Test masking and slicing + a = pl.Series("a", [[1, 1, 1], [1], [1, 1], [1, 1, 1]]) + b = pl.Series("b", [[1, 1], [1], [1, 1, 1], [1]]) + c = pl.Series("c", [1, 1, 1, 1]) + p = pl.Series("p", [True, True, False, False]) + + df = pl.DataFrame([a, b, c, p]).filter("p").slice(1) + + for rhs in [pl.col("b"), pl.lit(1), pl.col("c"), pl.lit([1])]: + assert_series_equal( + df.select(x=pl.col("a") + rhs).to_series(), pl.Series("x", [[2]]) + ) + + df = df.vstack(df) + + for rhs in [pl.col("b"), pl.lit(1), pl.col("c"), pl.lit([1])]: + assert_series_equal( + df.select(x=pl.col("a") + rhs).to_series(), pl.Series("x", [[2], [2]]) + ) + + +@pytest.mark.parametrize("exec_op", EXEC_OP_COMBINATIONS) +def test_list_add_empty_lists( + exec_op: Callable[[pl.Series, pl.Series, Any], pl.Series], +) -> None: + l = pl.Series( # noqa: E741 + "x", + [[[[]], []], []], + ) + r = pl.Series([1]) + + assert_series_equal( + exec_op(l, r, operator.add), + pl.Series("x", [[[[]], []], []], dtype=pl.List(pl.List(pl.List(pl.Int64)))), + ) + + l = pl.Series( # noqa: E741 + "x", + [[[[]], None], []], + ) + r = pl.Series([1]) + + assert_series_equal( + exec_op(l, r, operator.add), + pl.Series("x", [[[[]], None], []], dtype=pl.List(pl.List(pl.List(pl.Int64)))), + ) + + +@pytest.mark.parametrize("exec_op", EXEC_OP_COMBINATIONS) +def test_list_to_list_arithmetic_double_nesting_raises_error( + exec_op: Callable[[pl.Series, pl.Series, Any], pl.Series], +) -> None: + s = pl.Series(dtype=pl.List(pl.List(pl.Int32))) + + with pytest.raises( + InvalidOperationError, + match="cannot add two list columns with non-numeric inner types", + ): + exec_op(s, s, operator.add) + + +@pytest.mark.parametrize("exec_op", EXEC_OP_COMBINATIONS) +def test_list_add_height_mismatch( + exec_op: Callable[[pl.Series, pl.Series, Any], pl.Series], +) -> None: + s = pl.Series([[1], [2], [3]], dtype=pl.List(pl.Int32)) + + # TODO: Make the error type consistently a ShapeError + with pytest.raises( + (ShapeError, InvalidOperationError), + match="length", + ): + exec_op(s, pl.Series([1, 1]), operator.add) + + +@pytest.mark.parametrize( + "op", + [ + operator.add, + operator.sub, + operator.mul, + operator.floordiv, + operator.mod, + operator.truediv, + ], +) +@pytest.mark.parametrize("exec_op", EXEC_OP_COMBINATIONS) +def test_list_date_to_numeric_arithmetic_raises_error( + op: Callable[[Any], Any], exec_op: Callable[[pl.Series, pl.Series, Any], pl.Series] +) -> None: + l = pl.Series([1], dtype=pl.Date) # noqa: E741 + r = pl.Series([[1]], dtype=pl.List(pl.Int32)) + + exec_op(l.to_physical(), r, op) + + # TODO(_): Ideally this always raises InvalidOperationError. The TypeError + # is being raised by checks on the Python side that should be moved to Rust. + with pytest.raises((InvalidOperationError, TypeError)): + exec_op(l, r, op)