From 59d25294c646aef0c228c4c93870695aa1fff9c1 Mon Sep 17 00:00:00 2001 From: Ritchie Vink Date: Sun, 30 Jun 2024 15:09:17 +0200 Subject: [PATCH] raise if keys passed to cross join (#17305) --- .../src/plans/conversion/dsl_to_ir.rs | 52 ++++++++++--------- py-polars/polars/lazyframe/frame.py | 3 ++ .../tests/unit/operations/test_cross_join.py | 9 ++++ 3 files changed, 40 insertions(+), 24 deletions(-) diff --git a/crates/polars-plan/src/plans/conversion/dsl_to_ir.rs b/crates/polars-plan/src/plans/conversion/dsl_to_ir.rs index 5cefebad7aac..30b3f9feb4ac 100644 --- a/crates/polars-plan/src/plans/conversion/dsl_to_ir.rs +++ b/crates/polars-plan/src/plans/conversion/dsl_to_ir.rs @@ -399,33 +399,37 @@ pub fn to_alp_impl( right_on, mut options, } => { - let mut turn_off_coalesce = false; - for e in left_on.iter().chain(right_on.iter()) { - if has_expr(e, |e| matches!(e, Expr::Alias(_, _))) { - polars_bail!( - ComputeError: - "'alias' is not allowed in a join key, use 'with_columns' first", - ) + if matches!(options.args.how, JoinType::Cross) { + polars_ensure!(left_on.len() + right_on.len() == 0, InvalidOperation: "a 'cross' join doesn't expect any join keys"); + } else { + let mut turn_off_coalesce = false; + for e in left_on.iter().chain(right_on.iter()) { + if has_expr(e, |e| matches!(e, Expr::Alias(_, _))) { + polars_bail!( + ComputeError: + "'alias' is not allowed in a join key, use 'with_columns' first", + ) + } + // Any expression that is not a simple column expression will turn of coalescing. + turn_off_coalesce |= has_expr(e, |e| !matches!(e, Expr::Column(_))); + } + if turn_off_coalesce { + let options = Arc::make_mut(&mut options); + options.args.coalesce = JoinCoalesce::KeepColumns; } - // Any expression that is not a simple column expression will turn of coalescing. - turn_off_coalesce |= has_expr(e, |e| !matches!(e, Expr::Column(_))); - } - if turn_off_coalesce { - let options = Arc::make_mut(&mut options); - options.args.coalesce = JoinCoalesce::KeepColumns; - } - options.args.validation.is_valid_join(&options.args.how)?; + options.args.validation.is_valid_join(&options.args.how)?; - polars_ensure!( - left_on.len() == right_on.len(), - ComputeError: - format!( - "the number of columns given as join key (left: {}, right:{}) should be equal", - left_on.len(), - right_on.len() - ) - ); + polars_ensure!( + left_on.len() == right_on.len(), + ComputeError: + format!( + "the number of columns given as join key (left: {}, right:{}) should be equal", + left_on.len(), + right_on.len() + ) + ); + } let input_left = to_alp_impl(owned(input_left), expr_arena, lp_arena, convert) .map_err(|e| e.context(failed_input!(join left)))?; diff --git a/py-polars/polars/lazyframe/frame.py b/py-polars/polars/lazyframe/frame.py index bdcf20c9ad4b..096e5e40f74e 100644 --- a/py-polars/polars/lazyframe/frame.py +++ b/py-polars/polars/lazyframe/frame.py @@ -4201,6 +4201,9 @@ def join( ) elif how == "cross": + if left_on is not None or right_on is not None: + msg = "cross join should not pass join keys" + raise ValueError(msg) return self._from_pyldf( self._ldf.join( other._ldf, diff --git a/py-polars/tests/unit/operations/test_cross_join.py b/py-polars/tests/unit/operations/test_cross_join.py index c83d0c3c7215..ddd3d36b226b 100644 --- a/py-polars/tests/unit/operations/test_cross_join.py +++ b/py-polars/tests/unit/operations/test_cross_join.py @@ -1,6 +1,8 @@ import sys from datetime import datetime +import pytest + from polars.dependencies import _ZONEINFO_AVAILABLE if sys.version_info >= (3, 9): @@ -44,3 +46,10 @@ def test_cross_join_predicate_pushdown_block_16956() -> None: datetime(2024, 6, 19, 16, 0, tzinfo=ZoneInfo(key="Europe/Amsterdam")), ], } + + +def test_cross_join_raise_on_keys() -> None: + df = pl.DataFrame({"a": [0, 1], "b": ["x", "y"]}) + + with pytest.raises(ValueError): + df.join(df, how="cross", left_on="a", right_on="b")