Skip to content

Commit

Permalink
fix: Don't pushdown predicates in cross join if the refer to both tab…
Browse files Browse the repository at this point in the history
…les (#16983)
  • Loading branch information
ritchie46 authored Jun 16, 2024
1 parent 1effc24 commit 73cc712
Show file tree
Hide file tree
Showing 5 changed files with 88 additions and 2 deletions.
4 changes: 4 additions & 0 deletions crates/polars-plan/src/logical_plan/optimizer/join_utils.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
pub(super) fn split_suffix<'a>(name: &'a str, suffix: &str) -> &'a str {
let (original, _) = name.split_at(name.len() - suffix.len());
original
}
1 change: 1 addition & 0 deletions crates/polars-plan/src/logical_plan/optimizer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ mod cse;
mod flatten_union;
#[cfg(feature = "fused")]
mod fused;
mod join_utils;
mod predicate_pushdown;
mod projection_pushdown;
mod simplify_expr;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use super::*;
use crate::logical_plan::optimizer::join_utils::split_suffix;

// Information concerning individual sides of a join.
#[derive(PartialEq, Eq)]
Expand Down Expand Up @@ -96,6 +97,28 @@ fn all_pred_cols_in_left_on(
})
}

// Checks if a predicate refers to columns in both tables
fn predicate_applies_to_both_tables(
predicate: Node,
expr_arena: &Arena<AExpr>,
schema_left: &Schema,
schema_right: &Schema,
suffix: &str,
) -> bool {
let mut left_used = false;
let mut right_used = false;
for name in aexpr_to_leaf_names_iter(predicate, expr_arena) {
if schema_left.contains(name.as_ref()) {
left_used |= true;
} else {
right_used |= schema_right.contains(name.as_ref())
|| name.ends_with(suffix)
&& schema_right.contains(split_suffix(name.as_ref(), suffix))
}
}
left_used && right_used
}

#[allow(clippy::too_many_arguments)]
pub(super) fn process_join(
opt: &PredicatePushDown,
Expand Down Expand Up @@ -128,6 +151,20 @@ pub(super) fn process_join(
let mut local_predicates = Vec::with_capacity(acc_predicates.len());

for (_, predicate) in acc_predicates {
// Cross joins produce a cartesian product, so if a predicate combines columns from both tables, we should not push down.
if matches!(options.args.how, JoinType::Cross)
&& predicate_applies_to_both_tables(
predicate.node(),
expr_arena,
&schema_left,
&schema_right,
options.args.suffix(),
)
{
local_predicates.push(predicate);
continue;
}

// check if predicate can pass the joins node
let block_pushdown_left = has_aexpr(predicate.node(), expr_arena, |ae| {
should_block_join_specific(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
use std::borrow::Cow;

use super::*;
use crate::prelude::optimizer::join_utils::split_suffix;

fn add_keys_to_accumulated_state(
expr: Node,
Expand Down Expand Up @@ -414,8 +415,7 @@ fn process_projection(
// suffix.
if leaf_column_name.ends_with(suffix) && join_schema.contains(leaf_column_name.as_ref()) {
// downwards name is the name without the _right i.e. "foo".
let (downwards_name, _) =
leaf_column_name.split_at(leaf_column_name.len() - suffix.len());
let downwards_name = split_suffix(leaf_column_name.as_ref(), suffix);

let downwards_name_column = expr_arena.add(AExpr::Column(Arc::from(downwards_name)));
// project downwards and locally immediately alias to prevent wrong projections
Expand Down
44 changes: 44 additions & 0 deletions py-polars/tests/unit/operations/test_cross_join.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import sys
from datetime import datetime

from polars.dependencies import _ZONEINFO_AVAILABLE

if sys.version_info >= (3, 9):
from zoneinfo import ZoneInfo
elif _ZONEINFO_AVAILABLE:
# Import from submodule due to typing issue with backports.zoneinfo package:
# https://github.com/pganssle/zoneinfo/issues/125
from backports.zoneinfo._zoneinfo import ZoneInfo

import polars as pl


def test_cross_join_predicate_pushdown_block_16956() -> None:
lf = pl.LazyFrame(
[
[1718085600000, 1718172000000, 1718776800000],
[1718114400000, 1718200800000, 1718805600000],
],
schema=["start_datetime", "end_datetime"],
).cast(pl.Datetime("ms", "Europe/Amsterdam"))

assert (
lf.join(lf, on="start_datetime", how="full")
.filter(
pl.col.end_datetime_right.is_between(
pl.col.start_datetime, pl.col.start_datetime.dt.offset_by("132h")
)
)
.select("start_datetime", "end_datetime_right")
).collect(predicate_pushdown=True).to_dict(as_series=False) == {
"start_datetime": [
datetime(2024, 6, 11, 8, 0, tzinfo=ZoneInfo(key="Europe/Amsterdam")),
datetime(2024, 6, 12, 8, 0, tzinfo=ZoneInfo(key="Europe/Amsterdam")),
datetime(2024, 6, 19, 8, 0, tzinfo=ZoneInfo(key="Europe/Amsterdam")),
],
"end_datetime_right": [
datetime(2024, 6, 11, 16, 0, tzinfo=ZoneInfo(key="Europe/Amsterdam")),
datetime(2024, 6, 12, 16, 0, tzinfo=ZoneInfo(key="Europe/Amsterdam")),
datetime(2024, 6, 19, 16, 0, tzinfo=ZoneInfo(key="Europe/Amsterdam")),
],
}

0 comments on commit 73cc712

Please sign in to comment.