Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Support arithmetic between Series with dtype list #17823

Merged
merged 34 commits into from
Sep 23, 2024
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
31c0e44
A more functional sketch.
pythonspeed Jul 19, 2024
559d9d2
Make division work for `list[int64]` (and arrays too)
pythonspeed Jul 23, 2024
f205d4f
More thorough testing of array math expressions
pythonspeed Jul 23, 2024
53dd233
Another test, commented out
pythonspeed Jul 23, 2024
b58f7f4
Success case test suite for list arithmetic
pythonspeed Jul 23, 2024
6e888ee
Include reference to Rust code
pythonspeed Jul 23, 2024
0ef2846
Support division in lists
pythonspeed Jul 23, 2024
bffb77f
Test error edge cases for List arithmetic
pythonspeed Jul 23, 2024
856a884
Run ruff to fix formatting
pythonspeed Jul 23, 2024
c830d33
Fix lints
pythonspeed Jul 23, 2024
5692b62
Fix lint
pythonspeed Jul 23, 2024
e194f97
Clean up
pythonspeed Jul 23, 2024
0601f5e
Specify dtype explicitly
pythonspeed Jul 23, 2024
2a86fbc
Rewrite to operate directly on underlying data in one chunk.
pythonspeed Aug 7, 2024
f0eea11
Handle nulls correctly
pythonspeed Aug 8, 2024
7ba7fd6
WIP improvements to null handling.
pythonspeed Aug 8, 2024
ef8b39d
Merge remote-tracking branch 'origin/main' into 9188-list-arithmetic
pythonspeed Aug 8, 2024
00ba975
Null handling now appears to work with latest tests.
pythonspeed Aug 8, 2024
254b37e
All tests pass.
pythonspeed Aug 8, 2024
d1d3950
Merge branch 'main' into 9188-list-arithmetic
itamarst Sep 9, 2024
cfd08f9
Merge remote-tracking branch 'origin/main' into 9188-list-arithmetic
pythonspeed Sep 11, 2024
cf4fa30
Update to compile with latest code.
pythonspeed Sep 11, 2024
03cdddd
Get rid of thread local, expand testing slightly.
pythonspeed Sep 12, 2024
19650ab
Drop scopeguard as explicit dependency.
pythonspeed Sep 12, 2024
4972d06
Simplify by getting rid of intermediate Series.
pythonspeed Sep 12, 2024
677f8d8
Merge remote-tracking branch 'origin/main' into 9188-list-arithmetic
pythonspeed Sep 16, 2024
ee74063
Simpler signature, better name.
pythonspeed Sep 16, 2024
b27b7ff
Use an AnonymousListBuilder.
pythonspeed Sep 16, 2024
b356683
Split list handling into its own module.
pythonspeed Sep 16, 2024
85cc6dd
Improve testing, and fix bug caught by the better test.
pythonspeed Sep 19, 2024
920fed2
There's an API for that.
pythonspeed Sep 20, 2024
0002d9a
Additional testing.
pythonspeed Sep 20, 2024
9e2e346
Remove a broken workaround I added, and replace it with actual fix fo…
pythonspeed Sep 20, 2024
ead35ac
Fix formatting
pythonspeed Sep 20, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions crates/polars-core/src/series/arithmetic/borrowed.rs
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,50 @@ impl NumOpsDispatchInner for FixedSizeListType {
}
}

impl ListChunked {
    /// Applies a binary arithmetic operation row by row to two list columns.
    ///
    /// `rhs` must be a list-typed `Series` with as many rows as `self`, and
    /// each pair of sub-lists must have the same length. `op` is invoked once
    /// per row with the two sub-lists; its output is imploded into a single
    /// list row and appended to the result, which starts from `self.clear()`.
    ///
    /// # Errors
    ///
    /// Returns an error when:
    /// * the two columns have a different number of rows,
    /// * `rhs.list()` fails,
    /// * either side holds a null row,
    /// * a pair of sub-lists differ in length,
    /// * `op`, `implode` or `append` fails for a row.
    fn arithm_helper(
        &self,
        rhs: &Series,
        op: &dyn Fn(&Series, &Series) -> PolarsResult<Series>,
    ) -> PolarsResult<Series> {
        polars_ensure!(self.len() == rhs.len(), InvalidOperation: "can only do arithmetic operations on Series of the same size; got {} and {}", self.len(), rhs.len());

        let mut result = self.clear();
        let rhs_lists = rhs.list()?;
        for (a, b) in self.amortized_iter().zip(rhs_lists.amortized_iter()) {
            // Equal column lengths do not rule out null rows: the iterator
            // items are `Option`s and a null row surfaces as `None`. Report
            // that as an error rather than panicking on `unwrap()`.
            polars_ensure!(a.is_some() && b.is_some(), InvalidOperation: "can only do arithmetic operations on lists without null rows");
            // Both sides were checked to be `Some` just above.
            let (a_owner, b_owner) = (a.unwrap(), b.unwrap());
            let a = a_owner.as_ref();
            let b = b_owner.as_ref();
            polars_ensure!(a.len() == b.len(), InvalidOperation: "can only do arithmetic operations on lists of the same size; got {} and {}", a.len(), b.len());

            // Wrap the per-row output in a one-row list column and append it
            // directly, without converting to a `Series` and back.
            let row = op(a, b)?.implode()?;
            result.append(&row)?;
        }
        Ok(result.into())
    }
}

/// Arithmetic dispatch for list columns.
///
/// Every operator hands `ListChunked::arithm_helper` a closure that applies
/// the same operator to a pair of sub-lists.
impl NumOpsDispatchInner for ListType {
    /// Row-wise `add_to` of the sub-lists of `lhs` and `rhs`.
    fn add_to(lhs: &ListChunked, rhs: &Series) -> PolarsResult<Series> {
        let per_row = |a: &Series, b: &Series| a.add_to(b);
        lhs.arithm_helper(rhs, &per_row)
    }

    /// Row-wise `subtract` of the sub-lists of `lhs` and `rhs`.
    fn subtract(lhs: &ListChunked, rhs: &Series) -> PolarsResult<Series> {
        let per_row = |a: &Series, b: &Series| a.subtract(b);
        lhs.arithm_helper(rhs, &per_row)
    }

    /// Row-wise `multiply` of the sub-lists of `lhs` and `rhs`.
    fn multiply(lhs: &ListChunked, rhs: &Series) -> PolarsResult<Series> {
        let per_row = |a: &Series, b: &Series| a.multiply(b);
        lhs.arithm_helper(rhs, &per_row)
    }

    /// Row-wise `divide` of the sub-lists of `lhs` and `rhs`.
    fn divide(lhs: &ListChunked, rhs: &Series) -> PolarsResult<Series> {
        let per_row = |a: &Series, b: &Series| a.divide(b);
        lhs.arithm_helper(rhs, &per_row)
    }

    /// Row-wise `remainder` of the sub-lists of `lhs` and `rhs`.
    fn remainder(lhs: &ListChunked, rhs: &Series) -> PolarsResult<Series> {
        let per_row = |a: &Series, b: &Series| a.remainder(b);
        lhs.arithm_helper(rhs, &per_row)
    }
}

#[cfg(feature = "checked_arithmetic")]
pub mod checked {
use num_traits::{CheckedDiv, One, ToPrimitive, Zero};
Expand Down
18 changes: 18 additions & 0 deletions crates/polars-core/src/series/implementations/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,24 @@ impl private::PrivateSeries for SeriesWrap<ListChunked> {
/// Builds the boxed `TotalEqInner` by delegating to the wrapped
/// `ListChunked` (`self.0`).
fn into_total_eq_inner<'a>(&'a self) -> Box<dyn TotalEqInner + 'a> {
(&self.0).into_total_eq_inner()
}

/// Forwards the `add_to` arithmetic hook to the wrapped `ListChunked`
/// (`self.0`), passing `rhs` through unchanged.
fn add_to(&self, rhs: &Series) -> PolarsResult<Series> {
self.0.add_to(rhs)
}

/// Forwards the `subtract` arithmetic hook to the wrapped `ListChunked`
/// (`self.0`), passing `rhs` through unchanged.
fn subtract(&self, rhs: &Series) -> PolarsResult<Series> {
self.0.subtract(rhs)
}

/// Forwards the `multiply` arithmetic hook to the wrapped `ListChunked`
/// (`self.0`), passing `rhs` through unchanged.
fn multiply(&self, rhs: &Series) -> PolarsResult<Series> {
self.0.multiply(rhs)
}
/// Forwards the `divide` arithmetic hook to the wrapped `ListChunked`
/// (`self.0`), passing `rhs` through unchanged.
fn divide(&self, rhs: &Series) -> PolarsResult<Series> {
self.0.divide(rhs)
}
/// Forwards the `remainder` arithmetic hook to the wrapped `ListChunked`
/// (`self.0`), passing `rhs` through unchanged.
fn remainder(&self, rhs: &Series) -> PolarsResult<Series> {
self.0.remainder(rhs)
}
}

impl SeriesTrait for SeriesWrap<ListChunked> {
Expand Down
5 changes: 5 additions & 0 deletions crates/polars-expr/src/expressions/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,11 @@ pub fn apply_operator(left: &Series, right: &Series, op: Operator) -> PolarsResu
let right_dt = right.dtype().cast_leaf(Float64);
left.cast(&left_dt)? / right.cast(&right_dt)?
},
dt @ List(_) => {
let left_dt = dt.cast_leaf(Float64);
let right_dt = right.dtype().cast_leaf(Float64);
left.cast(&left_dt)? / right.cast(&right_dt)?
},
_ => {
if right.dtype().is_temporal() {
return left / right;
Expand Down
22 changes: 20 additions & 2 deletions py-polars/polars/series/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1061,6 +1061,22 @@ def __sub__(self, other: Any) -> Self | Expr:
return F.lit(self) - other
return self._arithmetic(other, "sub", "sub_<>")

def _recursive_cast_to_float64(self) -> Series:
    """
    Cast the innermost (leaf) dtype of this Series to Float64.

    List and Array wrappers, including Array shapes, are rebuilt around the
    new leaf, so the nesting of the dtype is kept. This mirrors the logic of
    DataType::cast_leaf() on the Rust side.
    """

    def _leaf_to_f64(dt: PolarsDataType) -> PolarsDataType:
        # Peel one nesting level at a time; anything that is neither an
        # Array nor a List is treated as the leaf and becomes Float64.
        if isinstance(dt, Array):
            inner = _leaf_to_f64(dt.inner)
            return Array(inner, shape=dt.shape)
        if isinstance(dt, List):
            inner = _leaf_to_f64(dt.inner)
            return List(inner)
        return Float64()

    target_dtype = _leaf_to_f64(self.dtype)
    return self.cast(target_dtype)

@overload
def __truediv__(self, other: Expr) -> Expr: # type: ignore[overload-overlap]
...
Expand All @@ -1077,9 +1093,11 @@ def __truediv__(self, other: Any) -> Series | Expr:

# this branch is exactly the floordiv function without rounding the floats
if self.dtype.is_float() or self.dtype == Decimal:
return self._arithmetic(other, "div", "div_<>")
as_float = self
else:
as_float = self._recursive_cast_to_float64()

return self.cast(Float64) / other
return as_float._arithmetic(other, "div", "div_<>")

@overload
def __floordiv__(self, other: Expr) -> Expr: # type: ignore[overload-overlap]
Expand Down
119 changes: 104 additions & 15 deletions py-polars/tests/unit/operations/arithmetic/test_arithmetic.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from __future__ import annotations

import operator
from collections import OrderedDict
from datetime import date, datetime, timedelta
from typing import Any
from typing import Any, Callable

import numpy as np
import pytest
Expand Down Expand Up @@ -558,33 +560,120 @@ def test_power_series() -> None:


@pytest.mark.parametrize(
("expected", "expr"),
("expected", "expr", "column_names"),
[
(np.array([[2, 4], [6, 8]], dtype=np.int64), lambda a, b: a + b, ("a", "a")),
(np.array([[0, 0], [0, 0]], dtype=np.int64), lambda a, b: a - b, ("a", "a")),
(np.array([[1, 4], [9, 16]], dtype=np.int64), lambda a, b: a * b, ("a", "a")),
(
np.array([[2, 4], [6, 8]]),
pl.col("a") + pl.col("a"),
np.array([[1.0, 1.0], [1.0, 1.0]], dtype=np.float64),
lambda a, b: a / b,
("a", "a"),
),
(np.array([[0, 0], [0, 0]], dtype=np.int64), lambda a, b: a % b, ("a", "a")),
(
np.array([[0, 0], [0, 0]]),
pl.col("a") - pl.col("a"),
np.array([[3, 4], [7, 8]], dtype=np.int64),
lambda a, b: a + b,
("a", "uint8"),
),
# This fails because the code is buggy, see
# https://github.com/pola-rs/polars/issues/17820
#
# (
# np.array([[[2, 4]], [[6, 8]]], dtype=np.int64),
# lambda a, b: a + b,
# ("nested", "nested"),
# ),
],
)
def test_array_arithmetic_same_size(
    expected: Any,
    expr: Callable[[pl.Series | pl.Expr, pl.Series | pl.Expr], pl.Series],
    column_names: tuple[str, str],
) -> None:
    """
    Check arithmetic between two equally-shaped array columns.

    `expr` is applied to the two columns named in `column_names`, once as an
    expression inside `select` and once directly on the Series; both results
    must equal `expected`, named after the left-hand column.
    """
    df = pl.DataFrame(
        [
            pl.Series("a", np.array([[1, 2], [3, 4]], dtype=np.int64)),
            pl.Series("uint8", np.array([[2, 2], [4, 4]], dtype=np.uint8)),
            pl.Series("nested", np.array([[[1, 2]], [[3, 4]]], dtype=np.int64)),
        ]
    )
    # Expr-based arithmetic:
    assert_frame_equal(
        df.select(expr(pl.col(column_names[0]), pl.col(column_names[1]))),
        pl.Series(column_names[0], expected).to_frame(),
    )
    # Direct arithmetic on the Series:
    assert_series_equal(
        expr(df[column_names[0]], df[column_names[1]]),
        pl.Series(column_names[0], expected),
    )


@pytest.mark.parametrize(
("expected", "expr", "column_names"),
[
([[2, 4], [6]], lambda a, b: a + b, ("a", "a")),
([[0, 0], [0]], lambda a, b: a - b, ("a", "a")),
([[1, 4], [9]], lambda a, b: a * b, ("a", "a")),
([[1.0, 1.0], [1.0]], lambda a, b: a / b, ("a", "a")),
([[0, 0], [0]], lambda a, b: a % b, ("a", "a")),
(
np.array([[1, 4], [9, 16]]),
pl.col("a") * pl.col("a"),
[[3, 4], [7]],
lambda a, b: a + b,
("a", "uint8"),
),
(
np.array([[1.0, 1.0], [1.0, 1.0]]),
pl.col("a") / pl.col("a"),
[[[2, 4]], [[6]]],
lambda a, b: a + b,
("nested", "nested"),
),
],
)
def test_array_arithmetic_same_size(expected: Any, expr: pl.Expr) -> None:
df = pl.Series("a", np.array([[1, 2], [3, 4]])).to_frame()

def test_list_arithmetic_same_size(
    expected: Any,
    expr: Callable[[pl.Series | pl.Expr, pl.Series | pl.Expr], pl.Series],
    column_names: tuple[str, str],
) -> None:
    """
    Check arithmetic between two list columns whose rows have equal lengths.

    `expr` is applied to the two columns named in `column_names`, once as an
    expression inside `select` and once directly on the Series; both results
    must equal `expected`, named after the left-hand column.
    """
    df = pl.DataFrame(
        [
            pl.Series("a", [[1, 2], [3]]),
            pl.Series("uint8", [[2, 2], [4]]),
            pl.Series("nested", [[[1, 2]], [[3]]]),
        ]
    )
    # Expr-based arithmetic:
    assert_frame_equal(
        df.select(expr(pl.col(column_names[0]), pl.col(column_names[1]))),
        pl.Series(column_names[0], expected).to_frame(),
    )
    # Direct arithmetic on the Series:
    assert_series_equal(
        expr(df[column_names[0]], df[column_names[1]]),
        pl.Series(column_names[0], expected),
    )


def test_list_arithmetic_error_cases() -> None:
# Each case checks that invalid list arithmetic raises
# InvalidOperationError with a message matching the given pattern.

# Different series length:
with pytest.raises(
InvalidOperationError, match="Series of the same size; got 1 and 2"
):
_ = pl.Series("a", [[1, 2]]) / pl.Series("b", [[1, 2], [3, 4]])

# Different list length:
with pytest.raises(
InvalidOperationError, match="lists of the same size; got 2 and 1"
):
_ = pl.Series("a", [[1, 2]]) / pl.Series("b", [[1]])

# Wrong types:
with pytest.raises(InvalidOperationError, match="cannot cast List type"):
_ = pl.Series("a", [[1, 2]]) + pl.Series("b", ["hello"])


def test_schema_owned_arithmetic_5669() -> None:
Expand Down