diff --git a/polars/polars-core/src/series/mod.rs b/polars/polars-core/src/series/mod.rs index 4c8eeed45133..9cb53fc6ce69 100644 --- a/polars/polars-core/src/series/mod.rs +++ b/polars/polars-core/src/series/mod.rs @@ -232,7 +232,8 @@ impl Series { } /// Compute the sum of all values in this Series. - /// Returns `None` if the array is empty or only contains null values. + /// Returns `Some(0)` if the array is empty, and `None` if the array only + /// contains null values. /// /// If the [`DataType`] is one of `{Int8, UInt8, Int16, UInt16}` the `Series` is /// first cast to `Int64` to prevent overflow issues. @@ -511,11 +512,18 @@ impl Series { } /// Get the sum of the Series as a new Series of length 1. + /// Returns a Series with a single zeroed entry if self is an empty numeric series. /// /// If the [`DataType`] is one of `{Int8, UInt8, Int16, UInt16}` the `Series` is /// first cast to `Int64` to prevent overflow issues. pub fn sum_as_series(&self) -> Series { use DataType::*; + if self.is_empty() && self.dtype().is_numeric() { + return Series::new("", [0]) + .cast(self.dtype()) + .unwrap() + .sum_as_series(); + } match self.dtype() { Int8 | UInt8 | Int16 | UInt16 => self.cast(&Int64).unwrap().sum_as_series(), _ => self._sum_as_series(), diff --git a/py-polars/tests/unit/test_series.py b/py-polars/tests/unit/test_series.py index 761a23380650..740d883215a7 100644 --- a/py-polars/tests/unit/test_series.py +++ b/py-polars/tests/unit/test_series.py @@ -317,6 +317,16 @@ def test_arithmetic(s: pl.Series) -> None: 2**a +def test_arithmetic_empty() -> None: + series = pl.Series("a", []) + assert series.sum() == 0 + + +def test_arithmetic_null() -> None: + series = pl.Series("a", [None]) + assert series.sum() is None + + def test_power() -> None: a = pl.Series([1, 2], dtype=Int64) b = pl.Series([None, 2.0], dtype=Float64)