Skip to content

Commit

Permalink
raise if keys passed to cross join (pola-rs#17305)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Jun 30, 2024
1 parent 7392dc3 commit 59d2529
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 24 deletions.
52 changes: 28 additions & 24 deletions crates/polars-plan/src/plans/conversion/dsl_to_ir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -399,33 +399,37 @@ pub fn to_alp_impl(
right_on,
mut options,
} => {
let mut turn_off_coalesce = false;
for e in left_on.iter().chain(right_on.iter()) {
if has_expr(e, |e| matches!(e, Expr::Alias(_, _))) {
polars_bail!(
ComputeError:
"'alias' is not allowed in a join key, use 'with_columns' first",
)
if matches!(options.args.how, JoinType::Cross) {
polars_ensure!(left_on.len() + right_on.len() == 0, InvalidOperation: "a 'cross' join doesn't expect any join keys");
} else {
let mut turn_off_coalesce = false;
for e in left_on.iter().chain(right_on.iter()) {
if has_expr(e, |e| matches!(e, Expr::Alias(_, _))) {
polars_bail!(
ComputeError:
"'alias' is not allowed in a join key, use 'with_columns' first",
)
}
// Any expression that is not a simple column expression will turn off coalescing.
turn_off_coalesce |= has_expr(e, |e| !matches!(e, Expr::Column(_)));
}
if turn_off_coalesce {
let options = Arc::make_mut(&mut options);
options.args.coalesce = JoinCoalesce::KeepColumns;
}
// Any expression that is not a simple column expression will turn off coalescing.
turn_off_coalesce |= has_expr(e, |e| !matches!(e, Expr::Column(_)));
}
if turn_off_coalesce {
let options = Arc::make_mut(&mut options);
options.args.coalesce = JoinCoalesce::KeepColumns;
}

options.args.validation.is_valid_join(&options.args.how)?;
options.args.validation.is_valid_join(&options.args.how)?;

polars_ensure!(
left_on.len() == right_on.len(),
ComputeError:
format!(
"the number of columns given as join key (left: {}, right:{}) should be equal",
left_on.len(),
right_on.len()
)
);
polars_ensure!(
left_on.len() == right_on.len(),
ComputeError:
format!(
"the number of columns given as join key (left: {}, right:{}) should be equal",
left_on.len(),
right_on.len()
)
);
}

let input_left = to_alp_impl(owned(input_left), expr_arena, lp_arena, convert)
.map_err(|e| e.context(failed_input!(join left)))?;
Expand Down
3 changes: 3 additions & 0 deletions py-polars/polars/lazyframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -4201,6 +4201,9 @@ def join(
)

elif how == "cross":
if left_on is not None or right_on is not None:
msg = "cross join should not pass join keys"
raise ValueError(msg)
return self._from_pyldf(
self._ldf.join(
other._ldf,
Expand Down
9 changes: 9 additions & 0 deletions py-polars/tests/unit/operations/test_cross_join.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import sys
from datetime import datetime

import pytest

from polars.dependencies import _ZONEINFO_AVAILABLE

if sys.version_info >= (3, 9):
Expand Down Expand Up @@ -44,3 +46,10 @@ def test_cross_join_predicate_pushdown_block_16956() -> None:
datetime(2024, 6, 19, 16, 0, tzinfo=ZoneInfo(key="Europe/Amsterdam")),
],
}


def test_cross_join_raise_on_keys() -> None:
    """Check that a cross join given explicit join keys raises ``ValueError``."""
    frame = pl.DataFrame({"a": [0, 1], "b": ["x", "y"]})

    # Supplying left_on/right_on together with how="cross" is expected to be
    # rejected rather than silently ignored.
    with pytest.raises(ValueError):
        frame.join(frame, how="cross", left_on="a", right_on="b")

0 comments on commit 59d2529

Please sign in to comment.