Skip to content

Commit

Permalink
fix: Incorrectly gave list.len() for masked-out rows (#19999)
Browse files Browse the repository at this point in the history
  • Loading branch information
nameexhaustion authored Nov 26, 2024
1 parent 899881a commit e5f0b97
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 8 deletions.
11 changes: 10 additions & 1 deletion crates/polars-ops/src/chunked_array/list/namespace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,13 @@ pub trait ListNameSpaceImpl: AsList {

fn lst_lengths(&self) -> IdxCa {
let ca = self.as_list();

let ca_validity = ca.rechunk_validity();

if ca_validity.as_ref().map_or(false, |x| x.set_bits() == 0) {
return IdxCa::full_null(ca.name().clone(), ca.len());
}

let mut lengths = Vec::with_capacity(ca.len());
ca.downcast_iter().for_each(|arr| {
let offsets = arr.offsets().as_slice();
Expand All @@ -335,7 +342,9 @@ pub trait ListNameSpaceImpl: AsList {
last = *o;
}
});
IdxCa::from_vec(ca.name().clone(), lengths)

let arr = IdxArr::from_vec(lengths).with_validity(ca_validity);
IdxCa::with_chunk(ca.name().clone(), arr)
}

/// Get the value by index in the sublists.
Expand Down
30 changes: 23 additions & 7 deletions py-polars/tests/unit/operations/namespaces/list/test_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -754,13 +754,6 @@ def test_utf8_empty_series_arg_min_max_10703() -> None:
}


def test_list_len() -> None:
s = pl.Series([[1, 2, None], [5]])
result = s.list.len()
expected = pl.Series([3, 1], dtype=pl.UInt32)
assert_series_equal(result, expected)


def test_list_to_array() -> None:
data = [[1.0, 2.0], [3.0, 4.0]]
s = pl.Series(data, dtype=pl.List(pl.Float32))
Expand Down Expand Up @@ -804,13 +797,36 @@ def test_list_to_array_wrong_dtype() -> None:


def test_list_lengths() -> None:
s = pl.Series([[1, 2, None], [5]])
result = s.list.len()
expected = pl.Series([3, 1], dtype=pl.UInt32)
assert_series_equal(result, expected)

s = pl.Series("a", [[1, 2], [1, 2, 3]])
assert_series_equal(s.list.len(), pl.Series("a", [2, 3], dtype=pl.UInt32))
df = pl.DataFrame([s])
assert_series_equal(
df.select(pl.col("a").list.len())["a"], pl.Series("a", [2, 3], dtype=pl.UInt32)
)

assert_series_equal(
pl.select(
pl.when(pl.Series([True, False]))
.then(pl.Series([[1, 1], [1, 1]]))
.list.len()
).to_series(),
pl.Series([2, None], dtype=pl.UInt32),
)

assert_series_equal(
pl.select(
pl.when(pl.Series([False, False]))
.then(pl.Series([[1, 1], [1, 1]]))
.list.len()
).to_series(),
pl.Series([None, None], dtype=pl.UInt32),
)


def test_list_arithmetic() -> None:
s = pl.Series("a", [[1, 2], [1, 2, 3]])
Expand Down

0 comments on commit e5f0b97

Please sign in to comment.