diff --git a/crates/polars-plan/src/plans/optimizer/predicate_pushdown/mod.rs b/crates/polars-plan/src/plans/optimizer/predicate_pushdown/mod.rs index ff5f2f89ff0d..f9644b1aef0c 100644 --- a/crates/polars-plan/src/plans/optimizer/predicate_pushdown/mod.rs +++ b/crates/polars-plan/src/plans/optimizer/predicate_pushdown/mod.rs @@ -576,6 +576,28 @@ impl<'a> PredicatePushDown<'a> { expr_arena, )) }, + FunctionIR::Unnest { columns } => { + let exclude = columns.iter().cloned().collect::>(); + + let local_predicates = + transfer_to_local_by_name(expr_arena, &mut acc_predicates, |x| { + exclude.contains(x) + }); + + let lp = self.pushdown_and_continue( + lp, + acc_predicates, + lp_arena, + expr_arena, + false, + )?; + Ok(self.optional_apply_predicate( + lp, + local_predicates, + lp_arena, + expr_arena, + )) + }, _ => self.pushdown_and_continue( lp, acc_predicates, diff --git a/py-polars/tests/unit/test_predicates.py b/py-polars/tests/unit/test_predicates.py index 49141fd282de..e752bacdf81d 100644 --- a/py-polars/tests/unit/test_predicates.py +++ b/py-polars/tests/unit/test_predicates.py @@ -511,3 +511,45 @@ def test_predicate_push_down_list_gather_17492() -> None: .filter(pl.col("val").list.get(1, null_on_oob=True) == 1) .explain() ) + + +def test_predicate_pushdown_struct_unnest_19632() -> None: + lf = pl.LazyFrame({"a": [{"a": 1, "b": 2}]}).unnest("a") + + q = lf.filter(pl.col("a") == 1) + plan = q.explain() + + assert "FILTER" in plan + assert plan.index("FILTER") < plan.index("UNNEST") + + assert_frame_equal( + q.collect(), + pl.DataFrame({"a": 1, "b": 2}), + ) + + # With `pl.struct()` + lf = pl.LazyFrame({"a": 1, "b": 2}).select(pl.struct(pl.all())).unnest("a") + + q = lf.filter(pl.col("a") == 1) + plan = q.explain() + + assert "FILTER" in plan + assert plan.index("FILTER") < plan.index("UNNEST") + + assert_frame_equal( + q.collect(), + pl.DataFrame({"a": 1, "b": 2}), + ) + + # With `value_counts()` + lf = pl.LazyFrame({"a": [1]}).select(pl.col("a").value_counts()).unnest("a") + + q = lf.filter(pl.col("a") == 1) + plan = q.explain() + + assert plan.index("FILTER") < plan.index("UNNEST") + + assert_frame_equal( + q.collect(), + pl.DataFrame({"a": 1, "count": 1}, schema={"a": pl.Int64, "count": pl.UInt32}), + )