From 6ea587d7361fb9889c9f7906c0c04f5484209a8c Mon Sep 17 00:00:00 2001 From: Marco Edward Gorelli Date: Tue, 4 Jun 2024 09:39:16 +0100 Subject: [PATCH] feat!: Preserve nulls in `ewm_mean`, `ewm_std`, and `ewm_var` (#15503) --- .../src/legacy/kernels/ewm/average.rs | 15 +- .../src/legacy/kernels/ewm/variance.rs | 129 ++++++++---------- py-polars/tests/unit/operations/test_ewm.py | 35 +++-- 3 files changed, 91 insertions(+), 88 deletions(-) diff --git a/crates/polars-arrow/src/legacy/kernels/ewm/average.rs b/crates/polars-arrow/src/legacy/kernels/ewm/average.rs index 04be94939cbe..6d5aed26061f 100644 --- a/crates/polars-arrow/src/legacy/kernels/ewm/average.rs +++ b/crates/polars-arrow/src/legacy/kernels/ewm/average.rs @@ -46,9 +46,10 @@ where } }, } - match non_null_cnt < min_periods { - true => None, - false => weighted_avg, + match (non_null_cnt < min_periods, opt_x.is_some()) { + (_, false) => None, + (true, true) => None, + (false, true) => weighted_avg, } }) .collect_trusted() @@ -111,7 +112,7 @@ mod test { None, Some(5.0), Some(6.333_333_333_333_333), - Some(6.333_333_333_333_333), + None, Some(3.857_142_857_142_857), Some(2.333_333_333_333_333_5), Some(3.193_548_387_096_774), @@ -125,7 +126,7 @@ mod test { None, Some(5.0), Some(6.333_333_333_333_333), - Some(6.333_333_333_333_333), + None, Some(3.181_818_181_818_181_7), Some(1.888_888_888_888_888_8), Some(3.033_898_305_084_745_7), @@ -139,7 +140,7 @@ mod test { None, Some(5.0), Some(6.0), - Some(6.0), + None, Some(4.0), Some(2.5), Some(3.25), @@ -153,7 +154,7 @@ mod test { None, Some(5.0), Some(6.0), - Some(6.0), + None, Some(3.333_333_333_333_333_5), Some(2.166_666_666_666_667), Some(3.083_333_333_333_333_5), diff --git a/crates/polars-arrow/src/legacy/kernels/ewm/variance.rs b/crates/polars-arrow/src/legacy/kernels/ewm/variance.rs index 0aabb72c10a3..ab4112697839 100644 --- a/crates/polars-arrow/src/legacy/kernels/ewm/variance.rs +++ b/crates/polars-arrow/src/legacy/kernels/ewm/variance.rs @@ -100,9 +100,10 @@ where } }, } - match (non_na_cnt >= min_periods_fixed, bias) { - (false, _) => None, - (true, false) => { + match (non_na_cnt >= min_periods_fixed, bias, is_observation) { + (_, _, false) => None, + (false, _, true) => None, + (true, false, true) => { if non_na_cnt == 1 { Some(cov) } else { @@ -115,7 +116,7 @@ where } } }, - (true, true) => Some(cov), + (true, true, true) => Some(cov), } }); @@ -346,8 +347,8 @@ mod test { None, Some(0.0), Some(0.888_888_888_888_889), - Some(0.888_888_888_888_889), - Some(0.888_888_888_888_889), + None, + None, Some(7.346_938_775_510_203), Some(3.555_555_555_555_555_4), ]), @@ -359,8 +360,8 @@ mod test { None, Some(0.0), Some(0.888_888_888_888_889), - Some(0.888_888_888_888_889), - Some(0.888_888_888_888_889), + None, + None, Some(3.922_437_673_130_193_3), Some(2.549_788_542_868_127_3), ]), @@ -372,8 +373,8 @@ mod test { None, Some(0.0), Some(2.0), - Some(2.0), - Some(2.0), + None, + None, Some(12.857_142_857_142_856), Some(5.714_285_714_285_714), ]), @@ -385,8 +386,8 @@ mod test { None, Some(0.0), Some(2.0), - Some(2.0), - Some(2.0), + None, + None, Some(14.159_999_999_999_997), Some(5.039_513_677_811_549_5), ]), @@ -398,8 +399,8 @@ mod test { None, Some(0.0), Some(1.0), - Some(1.0), - Some(1.0), + None, + None, Some(6.75), Some(3.437_5), ]), @@ -407,15 +408,7 @@ mod test { ); assert_allclose!( ewm_var(YS.to_vec(), ALPHA, false, true, 0, false), - PrimitiveArray::from([ - None, - Some(0.0), - Some(1.0), - Some(1.0), - Some(1.0), - Some(4.2), - Some(3.1), - ]), + PrimitiveArray::from([None, Some(0.0), Some(1.0), None, None, Some(4.2), Some(3.1),]), EPS ); assert_allclose!( @@ -424,8 +417,8 @@ mod test { None, Some(0.0), Some(2.0), - Some(2.0), - Some(2.0), + None, + None, Some(10.8), Some(5.238_095_238_095_238), ]), @@ -437,8 +430,8 @@ mod test { None, Some(0.0), Some(2.0), - Some(2.0), - Some(2.0), + None, + None, Some(12.352_941_176_470_589), Some(5.299_145_299_145_3), ]), @@ -454,8 +447,8 @@ mod test { None, Some(0.0), Some(0.888_888_888_888_889), - Some(0.888_888_888_888_889), - Some(0.888_888_888_888_889), + None, + None, Some(7.346_938_775_510_203), Some(3.555_555_555_555_555_4) ]), @@ -467,8 +460,8 @@ mod test { None, Some(0.0), Some(0.888_888_888_888_889), - Some(0.888_888_888_888_889), - Some(0.888_888_888_888_889), + None, + None, Some(3.922_437_673_130_193_3), Some(2.549_788_542_868_127_3) ]), @@ -480,8 +473,8 @@ mod test { None, Some(0.0), Some(2.0), - Some(2.0), - Some(2.0), + None, + None, Some(12.857_142_857_142_856), Some(5.714_285_714_285_714) ]), @@ -493,8 +486,8 @@ mod test { None, Some(0.0), Some(2.0), - Some(2.0), - Some(2.0), + None, + None, Some(14.159_999_999_999_997), Some(5.039_513_677_811_549_5) ]), @@ -506,8 +499,8 @@ mod test { None, Some(0.0), Some(1.0), - Some(1.0), - Some(1.0), + None, + None, Some(6.75), Some(3.437_5) ]), @@ -515,15 +508,7 @@ mod test { ); assert_allclose!( ewm_cov(XS.to_vec(), YS.to_vec(), ALPHA, false, true, 0, false), - PrimitiveArray::from([ - None, - Some(0.0), - Some(1.0), - Some(1.0), - Some(1.0), - Some(4.2), - Some(3.1) - ]), + PrimitiveArray::from([None, Some(0.0), Some(1.0), None, None, Some(4.2), Some(3.1)]), EPS ); assert_allclose!( @@ -532,8 +517,8 @@ mod test { None, Some(0.0), Some(2.0), - Some(2.0), - Some(2.0), + None, + None, Some(10.8), Some(5.238_095_238_095_238) ]), @@ -545,8 +530,8 @@ mod test { None, Some(0.0), Some(2.0), - Some(2.0), - Some(2.0), + None, + None, Some(12.352_941_176_470_589), Some(5.299_145_299_145_3) ]), @@ -666,8 +651,8 @@ mod test { None, Some(0.0), Some(0.942_809_041_582_063_4), - Some(0.942_809_041_582_063_4), - Some(0.942_809_041_582_063_4), + None, + None, Some(2.710_523_708_715_753_4), Some(1.885_618_083_164_126_7), ]), @@ -679,8 +664,8 @@ mod test { None, Some(0.0), Some(0.942_809_041_582_063_4), - Some(0.942_809_041_582_063_4), - Some(0.942_809_041_582_063_4), + None, + None, Some(1.980_514_497_076_503), Some(1.596_805_731_098_222), ]), @@ -692,8 +677,8 @@ mod test { None, Some(0.0), Some(SQRT_2), - Some(SQRT_2), - Some(SQRT_2), + None, + None, Some(3.585_685_828_003_181), Some(2.390_457_218_668_787), ]), @@ -705,8 +690,8 @@ mod test { None, Some(0.0), Some(SQRT_2), - Some(SQRT_2), - Some(SQRT_2), + None, + None, Some(3.762_977_544_445_355_3), Some(2.244_886_116_891_356), ]), @@ -718,8 +703,8 @@ mod test { None, Some(0.0), Some(1.0), - Some(1.0), - Some(1.0), + None, + None, Some(2.598_076_211_353_316), Some(1.854_049_621_773_915_7), ]), @@ -731,8 +716,8 @@ mod test { None, Some(0.0), Some(1.0), - Some(1.0), - Some(1.0), + None, + None, Some(2.049_390_153_191_92), Some(1.760_681_686_165_901), ]), @@ -744,8 +729,8 @@ mod test { None, Some(0.0), Some(SQRT_2), - Some(SQRT_2), - Some(SQRT_2), + None, + None, Some(3.286_335_345_030_997), Some(2.288_688_541_085_317_5), ]), @@ -757,8 +742,8 @@ mod test { None, Some(0.0), Some(SQRT_2), - Some(SQRT_2), - Some(SQRT_2), + None, + None, Some(3.514_675_116_774_036_7), Some(2.301_987_249_996_250_4), ]), @@ -774,8 +759,8 @@ mod test { None, Some(0.0), Some(0.888_888_888_888_889), - Some(0.888_888_888_888_889), - Some(0.888_888_888_888_889), + None, + None, Some(3.922_437_673_130_193_3), Some(2.549_788_542_868_127_3), ]), @@ -787,8 +772,8 @@ mod test { None, Some(0.0), Some(0.888_888_888_888_889), - Some(0.888_888_888_888_889), - Some(0.888_888_888_888_889), + None, + None, Some(3.922_437_673_130_193_3), Some(2.549_788_542_868_127_3), ]), @@ -800,8 +785,8 @@ mod test { None, None, Some(0.888_888_888_888_889), - Some(0.888_888_888_888_889), - Some(0.888_888_888_888_889), + None, + None, Some(3.922_437_673_130_193_3), Some(2.549_788_542_868_127_3), ]), diff --git a/py-polars/tests/unit/operations/test_ewm.py b/py-polars/tests/unit/operations/test_ewm.py index 57b2e32b10d3..715fb9cd0482 100644 --- a/py-polars/tests/unit/operations/test_ewm.py +++ b/py-polars/tests/unit/operations/test_ewm.py @@ -60,7 +60,7 @@ def test_ewm_mean() -> None: 1.0, 3.6666666666666665, 5.571428571428571, - 5.571428571428571, + None, 3.6666666666666665, 4.354838709677419, 4.174603174603175, @@ -73,7 +73,7 @@ def test_ewm_mean() -> None: 1.0, 3.666666666666667, 5.571428571428571, - 5.571428571428571, + None, 3.08695652173913, 4.2, 4.092436974789916, @@ -83,12 +83,12 @@ def test_ewm_mean() -> None: s.ewm_mean(alpha=0.5, adjust=True, ignore_nulls=False), expected ) - expected = pl.Series([None, 1.0, 3.0, 5.0, 5.0, 3.5, 4.25, 4.125]) + expected = pl.Series([None, 1.0, 3.0, 5.0, None, 3.5, 4.25, 4.125]) assert_series_equal( s.ewm_mean(alpha=0.5, adjust=False, ignore_nulls=True), expected ) - expected = pl.Series([None, 1.0, 3.0, 5.0, 5.0, 3.0, 4.0, 4.0]) + expected = pl.Series([None, 1.0, 3.0, 5.0, None, 3.0, 4.0, 4.0]) assert_series_equal( s.ewm_mean(alpha=0.5, adjust=False, ignore_nulls=False), expected ) @@ -114,7 +114,7 @@ def test_ewm_mean_min_periods() -> None: series = pl.Series([1.0, None, None, None]) ewm_mean = series.ewm_mean(alpha=0.5, min_periods=1, ignore_nulls=True) - assert ewm_mean.to_list() == [1.0, 1.0, 1.0, 1.0] + assert ewm_mean.to_list() == [1.0, None, None, None] ewm_mean = series.ewm_mean(alpha=0.5, min_periods=2, ignore_nulls=True) assert ewm_mean.to_list() == [None, None, None, None] @@ -126,9 +126,9 @@ def test_ewm_mean_min_periods() -> None: pl.Series( [ 1.0, - 1.0, - 1.6666666666666665, + None, 1.6666666666666665, + None, 2.4285714285714284, ] ), @@ -141,7 +141,7 @@ def test_ewm_mean_min_periods() -> None: None, None, 1.6666666666666665, - 1.6666666666666665, + None, 2.4285714285714284, ] ), @@ -153,8 +153,25 @@ def test_ewm_std_var() -> None: var = series.ewm_var(alpha=0.5, ignore_nulls=False) std = series.ewm_std(alpha=0.5, ignore_nulls=False) - + expected = pl.Series("a", [0, 4.5, 1.9285714285714288]) assert np.allclose(var, std**2, rtol=1e-16) + assert_series_equal(var, expected) + + +def test_ewm_std_var_with_nulls() -> None: + series = pl.Series("a", [2, 5, None, 3]) + + var = series.ewm_var(alpha=0.5, ignore_nulls=True) + std = series.ewm_std(alpha=0.5, ignore_nulls=True) + expected = pl.Series("a", [0, 4.5, None, 1.9285714285714288]) + assert_series_equal(var, expected) + assert_series_equal(std**2, expected) + + var = series.ewm_var(alpha=0.5, ignore_nulls=False) + std = series.ewm_std(alpha=0.5, ignore_nulls=False) + expected = pl.Series("a", [0, 4.5, None, 1.7307692307692308]) + assert_series_equal(var, expected) + assert_series_equal(std**2, expected) def test_ewm_param_validation() -> None: