Skip to content

Commit

Permalink
feat!: Preserve nulls in ewm_mean, ewm_std, and ewm_var (#15503)
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli authored Jun 4, 2024
1 parent 34d6fa3 commit 6ea587d
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 88 deletions.
15 changes: 8 additions & 7 deletions crates/polars-arrow/src/legacy/kernels/ewm/average.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,10 @@ where
}
},
}
match non_null_cnt < min_periods {
true => None,
false => weighted_avg,
match (non_null_cnt < min_periods, opt_x.is_some()) {
(_, false) => None,
(true, true) => None,
(false, true) => weighted_avg,
}
})
.collect_trusted()
Expand Down Expand Up @@ -111,7 +112,7 @@ mod test {
None,
Some(5.0),
Some(6.333_333_333_333_333),
Some(6.333_333_333_333_333),
None,
Some(3.857_142_857_142_857),
Some(2.333_333_333_333_333_5),
Some(3.193_548_387_096_774),
Expand All @@ -125,7 +126,7 @@ mod test {
None,
Some(5.0),
Some(6.333_333_333_333_333),
Some(6.333_333_333_333_333),
None,
Some(3.181_818_181_818_181_7),
Some(1.888_888_888_888_888_8),
Some(3.033_898_305_084_745_7),
Expand All @@ -139,7 +140,7 @@ mod test {
None,
Some(5.0),
Some(6.0),
Some(6.0),
None,
Some(4.0),
Some(2.5),
Some(3.25),
Expand All @@ -153,7 +154,7 @@ mod test {
None,
Some(5.0),
Some(6.0),
Some(6.0),
None,
Some(3.333_333_333_333_333_5),
Some(2.166_666_666_666_667),
Some(3.083_333_333_333_333_5),
Expand Down
129 changes: 57 additions & 72 deletions crates/polars-arrow/src/legacy/kernels/ewm/variance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,10 @@ where
}
},
}
match (non_na_cnt >= min_periods_fixed, bias) {
(false, _) => None,
(true, false) => {
match (non_na_cnt >= min_periods_fixed, bias, is_observation) {
(_, _, false) => None,
(false, _, true) => None,
(true, false, true) => {
if non_na_cnt == 1 {
Some(cov)
} else {
Expand All @@ -115,7 +116,7 @@ where
}
}
},
(true, true) => Some(cov),
(true, true, true) => Some(cov),
}
});

Expand Down Expand Up @@ -346,8 +347,8 @@ mod test {
None,
Some(0.0),
Some(0.888_888_888_888_889),
Some(0.888_888_888_888_889),
Some(0.888_888_888_888_889),
None,
None,
Some(7.346_938_775_510_203),
Some(3.555_555_555_555_555_4),
]),
Expand All @@ -359,8 +360,8 @@ mod test {
None,
Some(0.0),
Some(0.888_888_888_888_889),
Some(0.888_888_888_888_889),
Some(0.888_888_888_888_889),
None,
None,
Some(3.922_437_673_130_193_3),
Some(2.549_788_542_868_127_3),
]),
Expand All @@ -372,8 +373,8 @@ mod test {
None,
Some(0.0),
Some(2.0),
Some(2.0),
Some(2.0),
None,
None,
Some(12.857_142_857_142_856),
Some(5.714_285_714_285_714),
]),
Expand All @@ -385,8 +386,8 @@ mod test {
None,
Some(0.0),
Some(2.0),
Some(2.0),
Some(2.0),
None,
None,
Some(14.159_999_999_999_997),
Some(5.039_513_677_811_549_5),
]),
Expand All @@ -398,24 +399,16 @@ mod test {
None,
Some(0.0),
Some(1.0),
Some(1.0),
Some(1.0),
None,
None,
Some(6.75),
Some(3.437_5),
]),
EPS
);
assert_allclose!(
ewm_var(YS.to_vec(), ALPHA, false, true, 0, false),
PrimitiveArray::from([
None,
Some(0.0),
Some(1.0),
Some(1.0),
Some(1.0),
Some(4.2),
Some(3.1),
]),
PrimitiveArray::from([None, Some(0.0), Some(1.0), None, None, Some(4.2), Some(3.1),]),
EPS
);
assert_allclose!(
Expand All @@ -424,8 +417,8 @@ mod test {
None,
Some(0.0),
Some(2.0),
Some(2.0),
Some(2.0),
None,
None,
Some(10.8),
Some(5.238_095_238_095_238),
]),
Expand All @@ -437,8 +430,8 @@ mod test {
None,
Some(0.0),
Some(2.0),
Some(2.0),
Some(2.0),
None,
None,
Some(12.352_941_176_470_589),
Some(5.299_145_299_145_3),
]),
Expand All @@ -454,8 +447,8 @@ mod test {
None,
Some(0.0),
Some(0.888_888_888_888_889),
Some(0.888_888_888_888_889),
Some(0.888_888_888_888_889),
None,
None,
Some(7.346_938_775_510_203),
Some(3.555_555_555_555_555_4)
]),
Expand All @@ -467,8 +460,8 @@ mod test {
None,
Some(0.0),
Some(0.888_888_888_888_889),
Some(0.888_888_888_888_889),
Some(0.888_888_888_888_889),
None,
None,
Some(3.922_437_673_130_193_3),
Some(2.549_788_542_868_127_3)
]),
Expand All @@ -480,8 +473,8 @@ mod test {
None,
Some(0.0),
Some(2.0),
Some(2.0),
Some(2.0),
None,
None,
Some(12.857_142_857_142_856),
Some(5.714_285_714_285_714)
]),
Expand All @@ -493,8 +486,8 @@ mod test {
None,
Some(0.0),
Some(2.0),
Some(2.0),
Some(2.0),
None,
None,
Some(14.159_999_999_999_997),
Some(5.039_513_677_811_549_5)
]),
Expand All @@ -506,24 +499,16 @@ mod test {
None,
Some(0.0),
Some(1.0),
Some(1.0),
Some(1.0),
None,
None,
Some(6.75),
Some(3.437_5)
]),
EPS
);
assert_allclose!(
ewm_cov(XS.to_vec(), YS.to_vec(), ALPHA, false, true, 0, false),
PrimitiveArray::from([
None,
Some(0.0),
Some(1.0),
Some(1.0),
Some(1.0),
Some(4.2),
Some(3.1)
]),
PrimitiveArray::from([None, Some(0.0), Some(1.0), None, None, Some(4.2), Some(3.1)]),
EPS
);
assert_allclose!(
Expand All @@ -532,8 +517,8 @@ mod test {
None,
Some(0.0),
Some(2.0),
Some(2.0),
Some(2.0),
None,
None,
Some(10.8),
Some(5.238_095_238_095_238)
]),
Expand All @@ -545,8 +530,8 @@ mod test {
None,
Some(0.0),
Some(2.0),
Some(2.0),
Some(2.0),
None,
None,
Some(12.352_941_176_470_589),
Some(5.299_145_299_145_3)
]),
Expand Down Expand Up @@ -666,8 +651,8 @@ mod test {
None,
Some(0.0),
Some(0.942_809_041_582_063_4),
Some(0.942_809_041_582_063_4),
Some(0.942_809_041_582_063_4),
None,
None,
Some(2.710_523_708_715_753_4),
Some(1.885_618_083_164_126_7),
]),
Expand All @@ -679,8 +664,8 @@ mod test {
None,
Some(0.0),
Some(0.942_809_041_582_063_4),
Some(0.942_809_041_582_063_4),
Some(0.942_809_041_582_063_4),
None,
None,
Some(1.980_514_497_076_503),
Some(1.596_805_731_098_222),
]),
Expand All @@ -692,8 +677,8 @@ mod test {
None,
Some(0.0),
Some(SQRT_2),
Some(SQRT_2),
Some(SQRT_2),
None,
None,
Some(3.585_685_828_003_181),
Some(2.390_457_218_668_787),
]),
Expand All @@ -705,8 +690,8 @@ mod test {
None,
Some(0.0),
Some(SQRT_2),
Some(SQRT_2),
Some(SQRT_2),
None,
None,
Some(3.762_977_544_445_355_3),
Some(2.244_886_116_891_356),
]),
Expand All @@ -718,8 +703,8 @@ mod test {
None,
Some(0.0),
Some(1.0),
Some(1.0),
Some(1.0),
None,
None,
Some(2.598_076_211_353_316),
Some(1.854_049_621_773_915_7),
]),
Expand All @@ -731,8 +716,8 @@ mod test {
None,
Some(0.0),
Some(1.0),
Some(1.0),
Some(1.0),
None,
None,
Some(2.049_390_153_191_92),
Some(1.760_681_686_165_901),
]),
Expand All @@ -744,8 +729,8 @@ mod test {
None,
Some(0.0),
Some(SQRT_2),
Some(SQRT_2),
Some(SQRT_2),
None,
None,
Some(3.286_335_345_030_997),
Some(2.288_688_541_085_317_5),
]),
Expand All @@ -757,8 +742,8 @@ mod test {
None,
Some(0.0),
Some(SQRT_2),
Some(SQRT_2),
Some(SQRT_2),
None,
None,
Some(3.514_675_116_774_036_7),
Some(2.301_987_249_996_250_4),
]),
Expand All @@ -774,8 +759,8 @@ mod test {
None,
Some(0.0),
Some(0.888_888_888_888_889),
Some(0.888_888_888_888_889),
Some(0.888_888_888_888_889),
None,
None,
Some(3.922_437_673_130_193_3),
Some(2.549_788_542_868_127_3),
]),
Expand All @@ -787,8 +772,8 @@ mod test {
None,
Some(0.0),
Some(0.888_888_888_888_889),
Some(0.888_888_888_888_889),
Some(0.888_888_888_888_889),
None,
None,
Some(3.922_437_673_130_193_3),
Some(2.549_788_542_868_127_3),
]),
Expand All @@ -800,8 +785,8 @@ mod test {
None,
None,
Some(0.888_888_888_888_889),
Some(0.888_888_888_888_889),
Some(0.888_888_888_888_889),
None,
None,
Some(3.922_437_673_130_193_3),
Some(2.549_788_542_868_127_3),
]),
Expand Down
Loading

0 comments on commit 6ea587d

Please sign in to comment.