From a5724d6e45039c04d52a3923c30eabb9c9738495 Mon Sep 17 00:00:00 2001 From: nameexhaustion Date: Tue, 9 Jul 2024 17:11:43 +1000 Subject: [PATCH] fix: Fix predicate pushdown for `.list.(get|gather)` (#17511) --- .../polars-plan/src/dsl/function_expr/mod.rs | 4 +- .../optimizer/predicate_pushdown/join.rs | 2 - .../plans/optimizer/predicate_pushdown/mod.rs | 44 +++++++------- .../optimizer/predicate_pushdown/utils.rs | 57 +++++++------------ py-polars/tests/unit/test_predicates.py | 16 ++++++ 5 files changed, 62 insertions(+), 61 deletions(-) diff --git a/crates/polars-plan/src/dsl/function_expr/mod.rs b/crates/polars-plan/src/dsl/function_expr/mod.rs index ed41bcf60ff3..edd6ead39c53 100644 --- a/crates/polars-plan/src/dsl/function_expr/mod.rs +++ b/crates/polars-plan/src/dsl/function_expr/mod.rs @@ -72,12 +72,12 @@ use std::fmt::{Display, Formatter}; use std::hash::{Hash, Hasher}; #[cfg(feature = "dtype-array")] -pub(super) use array::ArrayFunction; +pub(crate) use array::ArrayFunction; #[cfg(feature = "cov")] pub(crate) use correlation::CorrelationMethod; #[cfg(feature = "fused")] pub(crate) use fused::FusedOperator; -pub(super) use list::ListFunction; +pub(crate) use list::ListFunction; use polars_core::prelude::*; #[cfg(feature = "random")] pub(crate) use random::RandomMethod; diff --git a/crates/polars-plan/src/plans/optimizer/predicate_pushdown/join.rs b/crates/polars-plan/src/plans/optimizer/predicate_pushdown/join.rs index c9aee790f22c..c787336af375 100644 --- a/crates/polars-plan/src/plans/optimizer/predicate_pushdown/join.rs +++ b/crates/polars-plan/src/plans/optimizer/predicate_pushdown/join.rs @@ -193,8 +193,6 @@ pub(super) fn process_join( let mut filter_left = false; let mut filter_right = false; - debug_assert_aexpr_allows_predicate_pushdown(predicate.node(), expr_arena); - if !block_pushdown_left && check_input_node(predicate.node(), &schema_left, expr_arena) { insert_and_combine_predicate(&mut pushdown_left, &predicate, expr_arena); filter_left = true; diff --git a/crates/polars-plan/src/plans/optimizer/predicate_pushdown/mod.rs b/crates/polars-plan/src/plans/optimizer/predicate_pushdown/mod.rs index c9bdccee1e53..03eb08213c37 100644 --- a/crates/polars-plan/src/plans/optimizer/predicate_pushdown/mod.rs +++ b/crates/polars-plan/src/plans/optimizer/predicate_pushdown/mod.rs @@ -90,7 +90,7 @@ impl<'a> PredicatePushDown<'a> { let input = inputs[inputs.len() - 1]; let (eligibility, alias_rename_map) = - pushdown_eligibility(&exprs, &acc_predicates, expr_arena)?; + pushdown_eligibility(&exprs, &[], &acc_predicates, expr_arena)?; let local_predicates = match eligibility { PushdownEligibility::Full => vec![], @@ -265,22 +265,28 @@ impl<'a> PredicatePushDown<'a> { let tmp_key = Arc::::from(&*temporary_unique_key(&acc_predicates)); acc_predicates.insert(tmp_key.clone(), predicate.clone()); - let local_predicates = - match pushdown_eligibility(&[], &acc_predicates, expr_arena)?.0 { - PushdownEligibility::Full => vec![], - PushdownEligibility::Partial { to_local } => { - let mut out = Vec::with_capacity(to_local.len()); - for key in to_local { - out.push(acc_predicates.remove(&key).unwrap()); - } - out - }, - PushdownEligibility::NoPushdown => { - let out = acc_predicates.drain().map(|t| t.1).collect(); - acc_predicates.clear(); - out - }, - }; + let local_predicates = match pushdown_eligibility( + &[], + &[(tmp_key.clone(), predicate.clone())], + &acc_predicates, + expr_arena, + )? + .0 + { + PushdownEligibility::Full => vec![], + PushdownEligibility::Partial { to_local } => { + let mut out = Vec::with_capacity(to_local.len()); + for key in to_local { + out.push(acc_predicates.remove(&key).unwrap()); + } + out + }, + PushdownEligibility::NoPushdown => { + let out = acc_predicates.drain().map(|t| t.1).collect(); + acc_predicates.clear(); + out + }, + }; if let Some(predicate) = acc_predicates.remove(&tmp_key) { insert_and_combine_predicate(&mut acc_predicates, &predicate, expr_arena); @@ -327,10 +333,6 @@ impl<'a> PredicatePushDown<'a> { file_options: options, output_schema, } => { - for e in acc_predicates.values() { - debug_assert_aexpr_allows_predicate_pushdown(e.node(), expr_arena); - } - let local_predicates = match &scan_type { #[cfg(feature = "parquet")] FileScan::Parquet { .. } => vec![], diff --git a/crates/polars-plan/src/plans/optimizer/predicate_pushdown/utils.rs b/crates/polars-plan/src/plans/optimizer/predicate_pushdown/utils.rs index 00b312486f66..c5fcadc5ce80 100644 --- a/crates/polars-plan/src/plans/optimizer/predicate_pushdown/utils.rs +++ b/crates/polars-plan/src/plans/optimizer/predicate_pushdown/utils.rs @@ -137,6 +137,9 @@ where local_predicates } +/// Extends a stack of nodes with new nodes from `ae` (with some filtering), to support traversing +/// an expression tree to check predicate PD eligibility. Generally called repeatedly with the same +/// stack until all nodes are exhausted. fn check_and_extend_predicate_pd_nodes( stack: &mut Vec, ae: &AExpr, @@ -148,6 +151,22 @@ fn check_and_extend_predicate_pd_nodes( // rely on the height of the dataframe at this level and thus need // to block pushdown. AExpr::Literal(lit) => !lit.projects_as_scalar(), + // Rows that go OOB on get/gather may be filtered out in earlier operations, + // so we don't push these down. + AExpr::Function { + function: FunctionExpr::ListExpr(ListFunction::Get(false)), + .. + } => true, + #[cfg(feature = "list_gather")] + AExpr::Function { + function: FunctionExpr::ListExpr(ListFunction::Gather(false)), + .. + } => true, + #[cfg(feature = "dtype-array")] + AExpr::Function { + function: FunctionExpr::ArrayExpr(ArrayFunction::Get(false)), + .. + } => true, ae => ae.groups_sensitive(), } { false @@ -185,31 +204,6 @@ fn check_and_extend_predicate_pd_nodes( } } -/// An expression blocks predicates from being pushed past it if its results for -/// the subset where the predicate evaluates as true becomes different compared -/// to if it was performed before the predicate was applied. This is in general -/// any expression that produces outputs based on groups of values -/// (i.e. groups-wise) rather than individual values (i.e. element-wise). -/// -/// Examples of expressions whose results would change, and thus block push-down: -/// - any aggregation - sum, mean, first, last, min, max etc. -/// - sorting - as the sort keys would change between filters -pub(super) fn aexpr_blocks_predicate_pushdown(node: Node, expr_arena: &Arena) -> bool { - let mut stack = Vec::::with_capacity(4); - stack.push(node); - - // Cannot use `has_aexpr` because we need to ignore any literals in the RHS - // of an `is_in` operation. - while let Some(node) = stack.pop() { - let ae = expr_arena.get(node); - - if !check_and_extend_predicate_pd_nodes(&mut stack, ae, expr_arena) { - return true; - } - } - false -} - /// * `col(A).alias(B).alias(C) => (C, A)` /// * `col(A) => (A, A)` /// * `col(A).sum().alias(B) => None` @@ -240,6 +234,7 @@ pub enum PushdownEligibility { #[allow(clippy::type_complexity)] pub fn pushdown_eligibility( projection_nodes: &[ExprIR], + new_predicates: &[(Arc, ExprIR)], acc_predicates: &PlHashMap, ExprIR>, expr_arena: &mut Arena, ) -> PolarsResult<(PushdownEligibility, PlHashMap, Arc>)> { @@ -376,7 +371,7 @@ pub fn pushdown_eligibility( common_window_inputs = new; } - for e in acc_predicates.values() { + for (_, e) in new_predicates.iter() { debug_assert!(ae_nodes_stack.is_empty()); ae_nodes_stack.push(e.node()); @@ -447,13 +442,3 @@ pub fn pushdown_eligibility( _ => Ok((PushdownEligibility::Partial { to_local }, alias_to_col_map)), } } - -/// Used in places that previously handled blocking exprs before refactoring. -/// Can probably be eventually removed if it isn't catching anything. -#[inline(always)] -pub(super) fn debug_assert_aexpr_allows_predicate_pushdown(node: Node, expr_arena: &Arena) { - debug_assert!( - !aexpr_blocks_predicate_pushdown(node, expr_arena), - "Predicate pushdown: Did not expect blocking exprs at this point, please open an issue." - ); -} diff --git a/py-polars/tests/unit/test_predicates.py b/py-polars/tests/unit/test_predicates.py index 7009532ad304..49141fd282de 100644 --- a/py-polars/tests/unit/test_predicates.py +++ b/py-polars/tests/unit/test_predicates.py @@ -495,3 +495,19 @@ def test_predicate_push_down_with_alias_15442() -> None: .collect(predicate_pushdown=True) ) assert output.to_dict(as_series=False) == {"a": [1]} + + +def test_predicate_push_down_list_gather_17492() -> None: + lf = pl.LazyFrame({"val": [[1], [1, 1]], "len": [1, 2]}) + + assert_frame_equal( + lf.filter(pl.col("len") == 2).filter(pl.col("val").list.get(1) == 1), + lf.slice(1, 1), + ) + + # null_on_oob=True can pass + assert "FILTER" not in ( + lf.filter(pl.col("len") == 2) + .filter(pl.col("val").list.get(1, null_on_oob=True) == 1) + .explain() + )