Skip to content

Commit

Permalink
Fix cudf.polars sum of empty not equalling zero
Browse files Browse the repository at this point in the history
  • Loading branch information
mroeschke committed Jan 6, 2025
1 parent c4f2e8e commit 19a10ba
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 2 deletions.
14 changes: 13 additions & 1 deletion python/cudf_polars/cudf_polars/dsl/expressions/aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def __init__(
op = partial(self._reduce, request=req)
elif name in {"min", "max"}:
op = partial(op, propagate_nans=options)
elif name in {"count", "first", "last"}:
elif name in {"count", "sum", "first", "last"}:
pass
else:
raise NotImplementedError(
Expand Down Expand Up @@ -180,6 +180,18 @@ def _count(self, column: Column) -> Column:
)
)

def _sum(self, column: Column) -> Column:
if column.obj.size() == 0:
return Column(
plc.Column.from_scalar(
plc.interop.from_arrow(
pa.scalar(0, type=plc.interop.to_arrow(self.dtype))
),
1,
)
)
return self._reduce(column, request=plc.aggregation.sum())

def _min(self, column: Column, *, propagate_nans: bool) -> Column:
if propagate_nans and column.nan_count > 0:
return Column(
Expand Down
8 changes: 7 additions & 1 deletion python/cudf_polars/tests/expressions/test_agg.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations

Expand Down Expand Up @@ -148,3 +148,9 @@ def test_agg_singleton(op):
q = df.select(op(pl.col("a")))

assert_gpu_result_equal(q)


def test_sum_empty_zero():
df = pl.LazyFrame({"a": pl.Series(values=[], dtype=pl.Int32())})
q = df.select(pl.col("a").sum())
assert_gpu_result_equal(q)

0 comments on commit 19a10ba

Please sign in to comment.