From 1afd302ea885b2f5e08505d0054d08212c554be0 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sat, 30 Nov 2024 18:36:42 +0000 Subject: [PATCH] fix: Return null instead of 0. for rolling_std when window contains a single element and ddof=1 and there are nulls elsewhere in the Series --- .../src/legacy/kernels/rolling/nulls/mod.rs | 31 +++++++------------ .../legacy/kernels/rolling/nulls/variance.rs | 4 +-- .../unit/operations/rolling/test_rolling.py | 8 +++++ 3 files changed, 21 insertions(+), 22 deletions(-) diff --git a/crates/polars-arrow/src/legacy/kernels/rolling/nulls/mod.rs b/crates/polars-arrow/src/legacy/kernels/rolling/nulls/mod.rs index 38f037f63f96..84fffa5654a2 100644 --- a/crates/polars-arrow/src/legacy/kernels/rolling/nulls/mod.rs +++ b/crates/polars-arrow/src/legacy/kernels/rolling/nulls/mod.rs @@ -168,38 +168,29 @@ mod test { let out = rolling_var(arr, 3, 1, false, None, None); let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap(); - let out = out - .into_iter() - .map(|v| v.copied().unwrap()) - .collect::<Vec<_>>(); + let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>(); - assert_eq!(out, &[0.0, 0.0, 2.0, 12.5]); + assert_eq!(out, &[None, None, Some(2.0), Some(12.5)]); let testpars = Some(RollingFnParams::Var(RollingVarParams { ddof: 0 })); let out = rolling_var(arr, 3, 1, false, None, testpars.clone()); let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap(); - let out = out - .into_iter() - .map(|v| v.copied().unwrap()) - .collect::<Vec<_>>(); + let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>(); - assert_eq!(out, &[0.0, 0.0, 1.0, 6.25]); + assert_eq!(out, &[Some(0.0), Some(0.0), Some(1.0), Some(6.25)]); let out = rolling_var(arr, 4, 1, false, None, None); let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap(); - let out = out - .into_iter() - .map(|v| v.copied().unwrap()) - .collect::<Vec<_>>(); - assert_eq!(out, &[0.0, 0.0, 2.0, 6.333333333333334]); + let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>(); + 
assert_eq!(out, &[None, None, Some(2.0), Some(6.333333333333334)]); let out = rolling_var(arr, 4, 1, false, None, testpars.clone()); let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap(); - let out = out - .into_iter() - .map(|v| v.copied().unwrap()) - .collect::<Vec<_>>(); - assert_eq!(out, &[0.0, 0.0, 1.0, 4.222222222222222]); + let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>(); + assert_eq!( + out, + &[Some(0.), Some(0.0), Some(1.0), Some(4.222222222222222)] + ); } #[test] diff --git a/crates/polars-arrow/src/legacy/kernels/rolling/nulls/variance.rs b/crates/polars-arrow/src/legacy/kernels/rolling/nulls/variance.rs index 8252c8931c4f..5d303eab982c 100644 --- a/crates/polars-arrow/src/legacy/kernels/rolling/nulls/variance.rs +++ b/crates/polars-arrow/src/legacy/kernels/rolling/nulls/variance.rs @@ -178,10 +178,10 @@ impl< let denom = count - ddof; - if count == T::zero() { + if denom <= T::zero() { None } else if count == T::one() { - NumCast::from(0) + Some(T::zero()) } else if denom <= T::zero() { Some(T::infinity()) } else { diff --git a/py-polars/tests/unit/operations/rolling/test_rolling.py b/py-polars/tests/unit/operations/rolling/test_rolling.py index 8f0cbafe52fb..55f7d0f5a5cb 100644 --- a/py-polars/tests/unit/operations/rolling/test_rolling.py +++ b/py-polars/tests/unit/operations/rolling/test_rolling.py @@ -807,6 +807,14 @@ def test_rolling() -> None: ) +def test_rolling_std_nulls_min_periods_1_20076() -> None: + result = pl.Series([1, 2, None, 4]).rolling_std(3, min_periods=1) + expected = pl.Series( + [None, 0.7071067811865476, 0.7071067811865476, 1.4142135623730951] + ) + assert_series_equal(result, expected) + + def test_rolling_by_date() -> None: df = pl.DataFrame( {