Skip to content

Commit

Permalink
fix: Fix predicate pushdown for .list.(get|gather) (pola-rs#17511)
Browse files Browse the repository at this point in the history
  • Loading branch information
nameexhaustion authored Jul 9, 2024
1 parent 3629ea2 commit a5724d6
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 61 deletions.
4 changes: 2 additions & 2 deletions crates/polars-plan/src/dsl/function_expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,12 +72,12 @@ use std::fmt::{Display, Formatter};
use std::hash::{Hash, Hasher};

#[cfg(feature = "dtype-array")]
pub(super) use array::ArrayFunction;
pub(crate) use array::ArrayFunction;
#[cfg(feature = "cov")]
pub(crate) use correlation::CorrelationMethod;
#[cfg(feature = "fused")]
pub(crate) use fused::FusedOperator;
pub(super) use list::ListFunction;
pub(crate) use list::ListFunction;
use polars_core::prelude::*;
#[cfg(feature = "random")]
pub(crate) use random::RandomMethod;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -193,8 +193,6 @@ pub(super) fn process_join(
let mut filter_left = false;
let mut filter_right = false;

debug_assert_aexpr_allows_predicate_pushdown(predicate.node(), expr_arena);

if !block_pushdown_left && check_input_node(predicate.node(), &schema_left, expr_arena) {
insert_and_combine_predicate(&mut pushdown_left, &predicate, expr_arena);
filter_left = true;
Expand Down
44 changes: 23 additions & 21 deletions crates/polars-plan/src/plans/optimizer/predicate_pushdown/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ impl<'a> PredicatePushDown<'a> {
let input = inputs[inputs.len() - 1];

let (eligibility, alias_rename_map) =
pushdown_eligibility(&exprs, &acc_predicates, expr_arena)?;
pushdown_eligibility(&exprs, &[], &acc_predicates, expr_arena)?;

let local_predicates = match eligibility {
PushdownEligibility::Full => vec![],
Expand Down Expand Up @@ -265,22 +265,28 @@ impl<'a> PredicatePushDown<'a> {
let tmp_key = Arc::<str>::from(&*temporary_unique_key(&acc_predicates));
acc_predicates.insert(tmp_key.clone(), predicate.clone());

let local_predicates =
match pushdown_eligibility(&[], &acc_predicates, expr_arena)?.0 {
PushdownEligibility::Full => vec![],
PushdownEligibility::Partial { to_local } => {
let mut out = Vec::with_capacity(to_local.len());
for key in to_local {
out.push(acc_predicates.remove(&key).unwrap());
}
out
},
PushdownEligibility::NoPushdown => {
let out = acc_predicates.drain().map(|t| t.1).collect();
acc_predicates.clear();
out
},
};
let local_predicates = match pushdown_eligibility(
&[],
&[(tmp_key.clone(), predicate.clone())],
&acc_predicates,
expr_arena,
)?
.0
{
PushdownEligibility::Full => vec![],
PushdownEligibility::Partial { to_local } => {
let mut out = Vec::with_capacity(to_local.len());
for key in to_local {
out.push(acc_predicates.remove(&key).unwrap());
}
out
},
PushdownEligibility::NoPushdown => {
let out = acc_predicates.drain().map(|t| t.1).collect();
acc_predicates.clear();
out
},
};

if let Some(predicate) = acc_predicates.remove(&tmp_key) {
insert_and_combine_predicate(&mut acc_predicates, &predicate, expr_arena);
Expand Down Expand Up @@ -327,10 +333,6 @@ impl<'a> PredicatePushDown<'a> {
file_options: options,
output_schema,
} => {
for e in acc_predicates.values() {
debug_assert_aexpr_allows_predicate_pushdown(e.node(), expr_arena);
}

let local_predicates = match &scan_type {
#[cfg(feature = "parquet")]
FileScan::Parquet { .. } => vec![],
Expand Down
57 changes: 21 additions & 36 deletions crates/polars-plan/src/plans/optimizer/predicate_pushdown/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,9 @@ where
local_predicates
}

/// Extends a stack of nodes with new nodes from `ae` (with some filtering), to support traversing
/// an expression tree to check predicate PD eligibility. Generally called repeatedly with the same
/// stack until all nodes are exhausted.
fn check_and_extend_predicate_pd_nodes(
stack: &mut Vec<Node>,
ae: &AExpr,
Expand All @@ -148,6 +151,22 @@ fn check_and_extend_predicate_pd_nodes(
// rely on the height of the dataframe at this level and thus need
// to block pushdown.
AExpr::Literal(lit) => !lit.projects_as_scalar(),
// Rows that go OOB on get/gather may be filtered out in earlier operations,
// so we don't push these down.
AExpr::Function {
function: FunctionExpr::ListExpr(ListFunction::Get(false)),
..
} => true,
#[cfg(feature = "list_gather")]
AExpr::Function {
function: FunctionExpr::ListExpr(ListFunction::Gather(false)),
..
} => true,
#[cfg(feature = "dtype-array")]
AExpr::Function {
function: FunctionExpr::ArrayExpr(ArrayFunction::Get(false)),
..
} => true,
ae => ae.groups_sensitive(),
} {
false
Expand Down Expand Up @@ -185,31 +204,6 @@ fn check_and_extend_predicate_pd_nodes(
}
}

/// An expression blocks predicates from being pushed past it if its results for
/// the subset where the predicate evaluates as true becomes different compared
/// to if it was performed before the predicate was applied. This is in general
/// any expression that produces outputs based on groups of values
/// (i.e. groups-wise) rather than individual values (i.e. element-wise).
///
/// Examples of expressions whose results would change, and thus block push-down:
/// - any aggregation - sum, mean, first, last, min, max etc.
/// - sorting - as the sort keys would change between filters
pub(super) fn aexpr_blocks_predicate_pushdown(node: Node, expr_arena: &Arena<AExpr>) -> bool {
let mut stack = Vec::<Node>::with_capacity(4);
stack.push(node);

// Cannot use `has_aexpr` because we need to ignore any literals in the RHS
// of an `is_in` operation.
while let Some(node) = stack.pop() {
let ae = expr_arena.get(node);

if !check_and_extend_predicate_pd_nodes(&mut stack, ae, expr_arena) {
return true;
}
}
false
}

/// * `col(A).alias(B).alias(C) => (C, A)`
/// * `col(A) => (A, A)`
/// * `col(A).sum().alias(B) => None`
Expand Down Expand Up @@ -240,6 +234,7 @@ pub enum PushdownEligibility {
#[allow(clippy::type_complexity)]
pub fn pushdown_eligibility(
projection_nodes: &[ExprIR],
new_predicates: &[(Arc<str>, ExprIR)],
acc_predicates: &PlHashMap<Arc<str>, ExprIR>,
expr_arena: &mut Arena<AExpr>,
) -> PolarsResult<(PushdownEligibility, PlHashMap<Arc<str>, Arc<str>>)> {
Expand Down Expand Up @@ -376,7 +371,7 @@ pub fn pushdown_eligibility(
common_window_inputs = new;
}

for e in acc_predicates.values() {
for (_, e) in new_predicates.iter() {
debug_assert!(ae_nodes_stack.is_empty());
ae_nodes_stack.push(e.node());

Expand Down Expand Up @@ -447,13 +442,3 @@ pub fn pushdown_eligibility(
_ => Ok((PushdownEligibility::Partial { to_local }, alias_to_col_map)),
}
}

/// Used in places that previously handled blocking exprs before refactoring.
/// Can probably be eventually removed if it isn't catching anything.
#[inline(always)]
pub(super) fn debug_assert_aexpr_allows_predicate_pushdown(node: Node, expr_arena: &Arena<AExpr>) {
debug_assert!(
!aexpr_blocks_predicate_pushdown(node, expr_arena),
"Predicate pushdown: Did not expect blocking exprs at this point, please open an issue."
);
}
16 changes: 16 additions & 0 deletions py-polars/tests/unit/test_predicates.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,3 +495,19 @@ def test_predicate_push_down_with_alias_15442() -> None:
.collect(predicate_pushdown=True)
)
assert output.to_dict(as_series=False) == {"a": [1]}


def test_predicate_push_down_list_gather_17492() -> None:
lf = pl.LazyFrame({"val": [[1], [1, 1]], "len": [1, 2]})

assert_frame_equal(
lf.filter(pl.col("len") == 2).filter(pl.col("val").list.get(1) == 1),
lf.slice(1, 1),
)

# null_on_oob=True can pass
assert "FILTER" not in (
lf.filter(pl.col("len") == 2)
.filter(pl.col("val").list.get(1, null_on_oob=True) == 1)
.explain()
)

0 comments on commit a5724d6

Please sign in to comment.