From 0e5c03d0333c67dc413ef68a48827ca5d1850155 Mon Sep 17 00:00:00 2001 From: Ritchie Vink <ritchie46@gmail.com> Date: Wed, 30 Nov 2022 10:20:21 +0100 Subject: [PATCH 1/2] fix(rust, python): fix boolean schema in agg_max/min --- .../src/chunked_array/upstream_traits.rs | 15 +++ .../src/frame/groupby/aggregations/mod.rs | 97 ++++++++++++++++++- py-polars/tests/unit/test_schema.py | 18 ++++ 3 files changed, 128 insertions(+), 2 deletions(-) diff --git a/polars/polars-core/src/chunked_array/upstream_traits.rs b/polars/polars-core/src/chunked_array/upstream_traits.rs index 63ac8aead8bb..fb2d654b195a 100644 --- a/polars/polars-core/src/chunked_array/upstream_traits.rs +++ b/polars/polars-core/src/chunked_array/upstream_traits.rs @@ -435,6 +435,21 @@ impl FromParallelIterator<bool> for BooleanChunked { } } +impl FromParallelIterator<Option<bool>> for BooleanChunked { + fn from_par_iter<I: IntoParallelIterator<Item = Option<bool>>>(iter: I) -> Self { + let vectors = collect_into_linked_list(iter); + + let capacity: usize = get_capacity_from_par_results(&vectors); + + let arr = unsafe { + BooleanArray::from_trusted_len_iter( + vectors.into_iter().flatten().trust_my_length(capacity), + ) + }; + Self::from_chunks("", vec![Box::new(arr)]) + } +} + impl<Ptr> FromParallelIterator<Ptr> for Utf8Chunked where Ptr: PolarsAsRef<str> + Send + Sync, diff --git a/polars/polars-core/src/frame/groupby/aggregations/mod.rs b/polars/polars-core/src/frame/groupby/aggregations/mod.rs index bf1052807db9..d35c102f5faf 100644 --- a/polars/polars-core/src/frame/groupby/aggregations/mod.rs +++ b/polars/polars-core/src/frame/groupby/aggregations/mod.rs @@ -161,6 +161,14 @@ where ca.into_series() } +pub fn _agg_helper_idx_bool<F>(groups: &GroupsIdx, f: F) -> Series +where + F: Fn((IdxSize, &Vec<IdxSize>)) -> Option<bool> + Send + Sync, +{ + let ca: BooleanChunked = POOL.install(|| groups.into_par_iter().map(f).collect()); + ca.into_series() +} + // helper that iterates on the `all: Vec<Vec<IdxSize>>` collection // this doesn't have traverse the `first: Vec<IdxSize>` memory and is therefore faster fn agg_helper_idx_on_all<T, F>(groups: 
&GroupsIdx, f: F) -> Series @@ -183,12 +191,97 @@ where ca.into_series() } +pub fn _agg_helper_slice_bool<F>(groups: &[[IdxSize; 2]], f: F) -> Series +where + F: Fn([IdxSize; 2]) -> Option<bool> + Send + Sync, +{ + let ca: BooleanChunked = POOL.install(|| groups.par_iter().copied().map(f).collect()); + ca.into_series() +} + impl BooleanChunked { pub(crate) unsafe fn agg_min(&self, groups: &GroupsProxy) -> Series { - self.cast(&IDX_DTYPE).unwrap().agg_min(groups) + // faster paths + match (self.is_sorted2(), self.null_count()) { + (IsSorted::Ascending, 0) => { + return self.clone().into_series().agg_first(groups); + } + (IsSorted::Descending, 0) => { + return self.clone().into_series().agg_last(groups); + } + _ => {} + } + match groups { + GroupsProxy::Idx(groups) => _agg_helper_idx_bool(groups, |(first, idx)| { + debug_assert!(idx.len() <= self.len()); + if idx.is_empty() { + None + } else if idx.len() == 1 { + self.get(first as usize) + } else { + // TODO! optimize this + // can just check if any is false and early stop + let take = { self.take_unchecked(idx.into()) }; + take.min().map(|v| v == 1) + } + }), + GroupsProxy::Slice { + groups: groups_slice, + .. 
+ } => _agg_helper_slice_bool(groups_slice, |[first, len]| { + debug_assert!(len <= self.len() as IdxSize); + match len { + 0 => None, + 1 => self.get(first as usize), + _ => { + let arr_group = _slice_from_offsets(self, first, len); + arr_group.min().map(|v| v == 1) + } + } + }), + } } pub(crate) unsafe fn agg_max(&self, groups: &GroupsProxy) -> Series { - self.cast(&IDX_DTYPE).unwrap().agg_max(groups) + // faster paths + match (self.is_sorted2(), self.null_count()) { + (IsSorted::Ascending, 0) => { + return self.clone().into_series().agg_last(groups); + } + (IsSorted::Descending, 0) => { + return self.clone().into_series().agg_first(groups); + } + _ => {} + } + + match groups { + GroupsProxy::Idx(groups) => _agg_helper_idx_bool(groups, |(first, idx)| { + debug_assert!(idx.len() <= self.len()); + if idx.is_empty() { + None + } else if idx.len() == 1 { + self.get(first as usize) + } else { + // TODO! optimize this + // can just check if any is true and early stop + let take = { self.take_unchecked(idx.into()) }; + take.max().map(|v| v == 1) + } + }), + GroupsProxy::Slice { + groups: groups_slice, + .. 
+ } => _agg_helper_slice_bool(groups_slice, |[first, len]| { + debug_assert!(len <= self.len() as IdxSize); + match len { + 0 => None, + 1 => self.get(first as usize), + _ => { + let arr_group = _slice_from_offsets(self, first, len); + arr_group.max().map(|v| v == 1) + } + } + }), + } } pub(crate) unsafe fn agg_sum(&self, groups: &GroupsProxy) -> Series { self.cast(&IDX_DTYPE).unwrap().agg_sum(groups) diff --git a/py-polars/tests/unit/test_schema.py b/py-polars/tests/unit/test_schema.py index 8221de45f924..7123741710cc 100644 --- a/py-polars/tests/unit/test_schema.py +++ b/py-polars/tests/unit/test_schema.py @@ -249,3 +249,21 @@ def test_diff_duration_dtype() -> None: False, True, ] + + +def test_boolean_agg_schema() -> None: + df = pl.DataFrame( + { + "x": [1, 1, 1], + "y": [False, True, False], + } + ).lazy() + + agg_df = df.groupby("x").agg(pl.col("y").max().alias("max_y")) + + for streaming in [True, False]: + assert ( + agg_df.collect(streaming=streaming).schema + == agg_df.schema + == {"x": pl.Int64, "max_y": pl.Boolean} + ) From 53f86c737169b0b76b2b37b37f7dec2bc2db1e7b Mon Sep 17 00:00:00 2001 From: Ritchie Vink <ritchie46@gmail.com> Date: Wed, 30 Nov 2022 11:00:36 +0100 Subject: [PATCH 2/2] fix doctest --- .../polars/internals/dataframe/groupby.py | 44 +++++++++---------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/py-polars/polars/internals/dataframe/groupby.py b/py-polars/polars/internals/dataframe/groupby.py index 8b732dc4a0f5..79f1bf63aa44 100644 --- a/py-polars/polars/internals/dataframe/groupby.py +++ b/py-polars/polars/internals/dataframe/groupby.py @@ -600,17 +600,17 @@ def min(self) -> pli.DataFrame: ... 
) >>> df.groupby("d", maintain_order=True).min() shape: (3, 4) - ┌────────┬─────┬──────┬─────┐ - │ d ┆ a ┆ b ┆ c │ - │ --- ┆ --- ┆ --- ┆ --- │ - │ str ┆ i64 ┆ f64 ┆ u32 │ - ╞════════╪═════╪══════╪═════╡ - │ Apple ┆ 1 ┆ 0.5 ┆ 0 │ - ├╌╌╌╌╌╌╌╌┼╌╌╌╌╌┼╌╌╌╌╌╌┼╌╌╌╌╌┤ - │ Orange ┆ 2 ┆ 0.5 ┆ 1 │ - ├╌╌╌╌╌╌╌╌┼╌╌╌╌╌┼╌╌╌╌╌╌┼╌╌╌╌╌┤ - │ Banana ┆ 4 ┆ 13.0 ┆ 0 │ - └────────┴─────┴──────┴─────┘ + ┌────────┬─────┬──────┬───────┐ + │ d ┆ a ┆ b ┆ c │ + │ --- ┆ --- ┆ --- ┆ --- │ + │ str ┆ i64 ┆ f64 ┆ bool │ + ╞════════╪═════╪══════╪═══════╡ + │ Apple ┆ 1 ┆ 0.5 ┆ false │ + ├╌╌╌╌╌╌╌╌┼╌╌╌╌╌┼╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤ + │ Orange ┆ 2 ┆ 0.5 ┆ true │ + ├╌╌╌╌╌╌╌╌┼╌╌╌╌╌┼╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤ + │ Banana ┆ 4 ┆ 13.0 ┆ false │ + └────────┴─────┴──────┴───────┘ """ return self.agg(pli.all().min()) @@ -631,17 +631,17 @@ def max(self) -> pli.DataFrame: ... ) >>> df.groupby("d", maintain_order=True).max() shape: (3, 4) - ┌────────┬─────┬──────┬─────┐ - │ d ┆ a ┆ b ┆ c │ - │ --- ┆ --- ┆ --- ┆ --- │ - │ str ┆ i64 ┆ f64 ┆ u32 │ - ╞════════╪═════╪══════╪═════╡ - │ Apple ┆ 3 ┆ 10.0 ┆ 1 │ - ├╌╌╌╌╌╌╌╌┼╌╌╌╌╌┼╌╌╌╌╌╌┼╌╌╌╌╌┤ - │ Orange ┆ 2 ┆ 0.5 ┆ 1 │ - ├╌╌╌╌╌╌╌╌┼╌╌╌╌╌┼╌╌╌╌╌╌┼╌╌╌╌╌┤ - │ Banana ┆ 5 ┆ 14.0 ┆ 1 │ - └────────┴─────┴──────┴─────┘ + ┌────────┬─────┬──────┬──────┐ + │ d ┆ a ┆ b ┆ c │ + │ --- ┆ --- ┆ --- ┆ --- │ + │ str ┆ i64 ┆ f64 ┆ bool │ + ╞════════╪═════╪══════╪══════╡ + │ Apple ┆ 3 ┆ 10.0 ┆ true │ + ├╌╌╌╌╌╌╌╌┼╌╌╌╌╌┼╌╌╌╌╌╌┼╌╌╌╌╌╌┤ + │ Orange ┆ 2 ┆ 0.5 ┆ true │ + ├╌╌╌╌╌╌╌╌┼╌╌╌╌╌┼╌╌╌╌╌╌┼╌╌╌╌╌╌┤ + │ Banana ┆ 5 ┆ 14.0 ┆ true │ + └────────┴─────┴──────┴──────┘ """ return self.agg(pli.all().max())