Skip to content

Commit

Permalink
fix: Raise informative error instead of panicking for list arithmetic…
Browse files Browse the repository at this point in the history
… on some invalid dtypes (#19841)
  • Loading branch information
nameexhaustion authored Nov 19, 2024
1 parent 9527510 commit 6c34d59
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 14 deletions.
10 changes: 5 additions & 5 deletions crates/polars-core/src/series/arithmetic/borrowed.rs
Original file line number Diff line number Diff line change
Expand Up @@ -489,7 +489,7 @@ impl Add for &Series {
_struct_arithmetic(self, rhs, |a, b| a.add(b))
},
(DataType::List(_), _) | (_, DataType::List(_)) => {
list_borrowed::NumericListOp::add().execute(self, rhs)
list::NumericListOp::add().execute(self, rhs)
},
#[cfg(feature = "dtype-array")]
(DataType::Array(..), _) | (_, DataType::Array(..)) => {
Expand All @@ -514,7 +514,7 @@ impl Sub for &Series {
_struct_arithmetic(self, rhs, |a, b| a.sub(b))
},
(DataType::List(_), _) | (_, DataType::List(_)) => {
list_borrowed::NumericListOp::sub().execute(self, rhs)
list::NumericListOp::sub().execute(self, rhs)
},
#[cfg(feature = "dtype-array")]
(DataType::Array(..), _) | (_, DataType::Array(..)) => {
Expand Down Expand Up @@ -555,7 +555,7 @@ impl Mul for &Series {
Ok(out.with_name(self.name().clone()))
},
(DataType::List(_), _) | (_, DataType::List(_)) => {
list_borrowed::NumericListOp::mul().execute(self, rhs)
list::NumericListOp::mul().execute(self, rhs)
},
#[cfg(feature = "dtype-array")]
(DataType::Array(..), _) | (_, DataType::Array(..)) => {
Expand Down Expand Up @@ -592,7 +592,7 @@ impl Div for &Series {
| (_, Date)
| (_, Datetime(_, _)) => polars_bail!(opq = div, self.dtype(), rhs.dtype()),
(DataType::List(_), _) | (_, DataType::List(_)) => {
list_borrowed::NumericListOp::div().execute(self, rhs)
list::NumericListOp::div().execute(self, rhs)
},
#[cfg(feature = "dtype-array")]
(DataType::Array(..), _) | (_, DataType::Array(..)) => {
Expand Down Expand Up @@ -622,7 +622,7 @@ impl Rem for &Series {
_struct_arithmetic(self, rhs, |a, b| a.rem(b))
},
(DataType::List(_), _) | (_, DataType::List(_)) => {
list_borrowed::NumericListOp::rem().execute(self, rhs)
list::NumericListOp::rem().execute(self, rhs)
},
#[cfg(feature = "dtype-array")]
(DataType::Array(..), _) | (_, DataType::Array(..)) => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,30 @@ mod inner {
let output_primitive_dtype =
op.0.try_get_leaf_supertype(prim_dtype_lhs, prim_dtype_rhs)?;

fn is_list_type_at_all_levels(dtype: &DataType) -> bool {
match dtype {
DataType::List(inner) => is_list_type_at_all_levels(inner),
dt if dt.is_supported_list_arithmetic_input() => true,
_ => false,
}
}

let op_err_msg = |err_reason: &str| {
polars_err!(
InvalidOperation:
"cannot {} columns: {}: (left: {}, right: {})",
op.0.name(), err_reason, dtype_lhs, dtype_rhs,
)
};

let ensure_list_type_at_all_levels = |dtype: &DataType| {
if !is_list_type_at_all_levels(dtype) {
Err(op_err_msg("dtype was not list on all nesting levels"))
} else {
Ok(())
}
};

let (op_apply_type, output_dtype) = match (dtype_lhs, dtype_rhs) {
(l @ DataType::List(a), r @ DataType::List(b)) => {
// `get_arithmetic_field()` in the DSL checks this, but we also have to check here because if a user
Expand All @@ -191,9 +215,11 @@ mod inner {
(BinaryOpApplyType::ListToList, l)
},
(list_dtype @ DataType::List(_), x) if x.is_supported_list_arithmetic_input() => {
ensure_list_type_at_all_levels(list_dtype)?;
(BinaryOpApplyType::ListToPrimitive, list_dtype)
},
(x, list_dtype @ DataType::List(_)) if x.is_supported_list_arithmetic_input() => {
ensure_list_type_at_all_levels(list_dtype)?;
(BinaryOpApplyType::PrimitiveToList, list_dtype)
},
(l, r) => polars_bail!(
Expand Down
4 changes: 2 additions & 2 deletions crates/polars-core/src/series/arithmetic/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
mod bitops;
mod borrowed;
mod list_borrowed;
mod list;
mod owned;

use std::borrow::Cow;
Expand All @@ -9,7 +9,7 @@ use std::ops::{Add, Div, Mul, Rem, Sub};
pub use borrowed::*;
#[cfg(feature = "dtype-array")]
pub use fixed_size_list::NumericFixedSizeListOp;
pub use list_borrowed::NumericListOp;
pub use list::NumericListOp;
use num_traits::{Num, NumCast};
#[cfg(feature = "dtype-array")]
mod fixed_size_list;
Expand Down
13 changes: 9 additions & 4 deletions py-polars/tests/unit/operations/arithmetic/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,15 +422,20 @@ def test_array_arithmetic_dtype_mismatch(
with pytest.raises(InvalidOperationError, match="differing dtypes"):
exec_op(a, b, op.add)

s = pl.Series([[[1]], [[1]]], dtype=pl.Array(pl.List(pl.Int64), 1))
p = pl.Series([1], dtype=pl.Int64)
a = pl.Series([[[1]], [[1]]], dtype=pl.Array(pl.List(pl.Int64), 1))
b = pl.Series([1], dtype=pl.Int64)

with pytest.raises(
InvalidOperationError, match="dtype was not array on all nesting levels"
):
exec_op(s, s, op.add)
exec_op(a, a, op.add)

with pytest.raises(
InvalidOperationError, match="dtype was not array on all nesting levels"
):
exec_op(s, p, op.add)
exec_op(a, b, op.add)

with pytest.raises(
InvalidOperationError, match="dtype was not array on all nesting levels"
):
exec_op(b, a, op.add)
Original file line number Diff line number Diff line change
Expand Up @@ -534,18 +534,45 @@ def test_list_arithmetic_error_cases() -> None:
with pytest.raises(ShapeError, match="lengths differed at index 0: 2 != 1"):
_ = pl.Series("a", [[1, 2], [2, 3]]) / pl.Series("b", [[1], None])


@pytest.mark.parametrize("exec_op", EXEC_OP_COMBINATIONS)
def test_list_arithmetic_invalid_dtypes(
exec_op: Callable[[pl.Series, pl.Series, Any], pl.Series],
) -> None:
import operator as op

a = pl.Series([[1, 2]])
b = pl.Series(["hello"])

# Wrong types:
with pytest.raises(
InvalidOperationError, match="add operation not supported for dtypes"
):
_ = pl.Series("a", [[1, 2]]) + pl.Series("b", ["hello"])
exec_op(a, b, op.add)

a = pl.Series("a", [[1]])
b = pl.Series("b", [[[1]]])

# Different nesting:
# list<->list is restricted to 1 level of nesting
with pytest.raises(
InvalidOperationError,
match="cannot add two list columns with non-numeric inner types",
):
_ = pl.Series("a", [[1]]) + pl.Series("b", [[[1]]])
exec_op(a, b, op.add)

# Ensure dtype is validated to be `List` at all nesting levels instead of panicking.
a = pl.Series([[[1]], [[1]]], dtype=pl.List(pl.Array(pl.Int64, 1)))
b = pl.Series([1], dtype=pl.Int64)

with pytest.raises(
InvalidOperationError, match="dtype was not list on all nesting levels"
):
exec_op(a, b, op.add)

with pytest.raises(
InvalidOperationError, match="dtype was not list on all nesting levels"
):
exec_op(b, a, op.add)


@pytest.mark.parametrize(
Expand Down

0 comments on commit 6c34d59

Please sign in to comment.