diff --git a/crates/polars-core/src/frame/group_by/mod.rs b/crates/polars-core/src/frame/group_by/mod.rs
index 7d1d9e761504..2a5aa2c40e45 100644
--- a/crates/polars-core/src/frame/group_by/mod.rs
+++ b/crates/polars-core/src/frame/group_by/mod.rs
@@ -841,6 +841,28 @@ impl<'df> GroupBy<'df> {
         df.as_single_chunk_par();
         Ok(df)
     }
+
+    /// Restricts this `GroupBy` to a window of its groups.
+    ///
+    /// Returns `self` unchanged when `slice` is `None`. For
+    /// `Some((offset, length))`, `groups` is replaced by a clone of
+    /// `groups.slice(offset, length)` and `selected_keys` by the result of
+    /// `keys_sliced(slice)`.
+    pub fn sliced(mut self, slice: Option<(i64, usize)>) -> Self {
+        match slice {
+            None => self,
+            Some((offset, length)) => {
+                // Compute the keys while `self.groups` is still unsliced so the
+                // window is applied only once.
+                // NOTE(review): assumes `keys_sliced` slices `self.groups` itself
+                // with the given `slice` -- confirm against its implementation.
+                let keys = self.keys_sliced(slice);
+                self.groups = (*self.groups.slice(offset, length)).clone();
+                self.selected_keys = keys;
+                self
+            },
+        }
+    }
 }
 
 unsafe fn take_df(df: &DataFrame, g: GroupsIndicator) -> DataFrame {
diff --git a/crates/polars-mem-engine/src/executors/group_by.rs b/crates/polars-mem-engine/src/executors/group_by.rs
index 437b7fb574aa..f7a501424ed9 100644
--- a/crates/polars-mem-engine/src/executors/group_by.rs
+++ b/crates/polars-mem-engine/src/executors/group_by.rs
@@ -67,7 +67,7 @@ pub(super) fn group_by_helper(
     let gb = df.group_by_with_series(keys, true, maintain_order)?;
 
     if let Some(f) = apply {
-        return gb.apply(move |df| f.call_udf(df));
+        return gb.sliced(slice).apply(move |df| f.call_udf(df));
     }
 
     let mut groups = gb.get_groups();
diff --git a/py-polars/tests/unit/operations/test_group_by.py b/py-polars/tests/unit/operations/test_group_by.py
index cff43b43274c..645b978214f0 100644
--- a/py-polars/tests/unit/operations/test_group_by.py
+++ b/py-polars/tests/unit/operations/test_group_by.py
@@ -1153,3 +1153,32 @@ def test_group_by_agg_19173() -> None:
     out = df.head(0).group_by("g").agg((pl.col.x - pl.col.x.sum() * pl.col.x) ** 2)
     assert out.to_dict(as_series=False) == {"g": [], "x": []}
     assert out.schema == pl.Schema([("g", pl.Int64), ("x", pl.List(pl.Float64))])
+
+
+def test_group_by_map_groups_slice_pushdown_20002() -> None:
+    # `head(3)` after `map_groups` must yield only the first three groups.
+    schema = {
+        "a": pl.Int8,
+        "b": pl.UInt8,
+    }
+
+    df = (
+        pl.LazyFrame(
+            data={"a": [1, 2, 3, 4, 5], "b": [90, 80, 70, 60, 50]},
+            schema=schema,
+        )
+        .group_by("a", maintain_order=True)
+        .map_groups(lambda df: df * 2.0, schema=schema)
+        .head(3)
+        .collect()
+    )
+
+    assert_frame_equal(
+        df,
+        pl.DataFrame(
+            {
+                "a": [2.0, 4.0, 6.0],
+                "b": [180.0, 160.0, 140.0],
+            }
+        ),
+    )