From 2685a8668beb8f0e38518a9cd00fb9aa6173ee35 Mon Sep 17 00:00:00 2001 From: Ritchie Vink Date: Fri, 27 Dec 2024 09:24:18 +0100 Subject: [PATCH] fix: Validate asof join by args in IR resolving phase (#20473) --- .../polars-plan/src/plans/conversion/join.rs | 36 +++++++++++++------ py-polars/polars/lazyframe/frame.py | 7 ++-- .../tests/unit/operations/test_join_asof.py | 18 ++++++++++ 3 files changed, 48 insertions(+), 13 deletions(-) diff --git a/crates/polars-plan/src/plans/conversion/join.rs b/crates/polars-plan/src/plans/conversion/join.rs index 6eff0f951338..9c973d446ead 100644 --- a/crates/polars-plan/src/plans/conversion/join.rs +++ b/crates/polars-plan/src/plans/conversion/join.rs @@ -43,6 +43,16 @@ pub fn resolve_join( } let owned = Arc::unwrap_or_clone; + let mut input_left = input_left.map_right(Ok).right_or_else(|input| { + to_alp_impl(owned(input), ctxt).map_err(|e| e.context(failed_here!(join left))) + })?; + let mut input_right = input_right.map_right(Ok).right_or_else(|input| { + to_alp_impl(owned(input), ctxt).map_err(|e| e.context(failed_here!(join right))) + })?; + + let schema_left = ctxt.lp_arena.get(input_left).schema(ctxt.lp_arena); + let schema_right = ctxt.lp_arena.get(input_right).schema(ctxt.lp_arena); + if options.args.how.is_cross() { polars_ensure!(left_on.len() + right_on.len() == 0, InvalidOperation: "a 'cross' join doesn't expect any join keys"); } else { @@ -65,6 +75,21 @@ pub fn resolve_join( options.args.validation.is_valid_join(&options.args.how)?; + #[cfg(feature = "asof_join")] + if let JoinType::AsOf(opt) = &options.args.how { + match (&opt.left_by, &opt.right_by) { + (None, None) => {}, + (Some(l), Some(r)) => { + polars_ensure!(l.len() == r.len(), InvalidOperation: "expected equal number of columns in 'by_left' and 'by_right' in 'asof_join'"); + validate_columns_in_input(l, &schema_left, "asof_join")?; + validate_columns_in_input(r, &schema_right, "asof_join")?; + }, + _ => { + polars_bail!(InvalidOperation: "expected both 'by_left' and 'by_right' to be set in 'asof_join'") + }, + } + } + polars_ensure!( left_on.len() == right_on.len(), InvalidOperation: @@ -76,16 +101,6 @@ pub fn resolve_join( ); } - let mut input_left = input_left.map_right(Ok).right_or_else(|input| { - to_alp_impl(owned(input), ctxt).map_err(|e| e.context(failed_here!(join left))) - })?; - let mut input_right = input_right.map_right(Ok).right_or_else(|input| { - to_alp_impl(owned(input), ctxt).map_err(|e| e.context(failed_here!(join right))) - })?; - - let schema_left = ctxt.lp_arena.get(input_left).schema(ctxt.lp_arena); - let schema_right = ctxt.lp_arena.get(input_right).schema(ctxt.lp_arena); - let schema = det_join_schema(&schema_left, &schema_right, &left_on, &right_on, &options) .map_err(|e| e.context(failed_here!(join schema resolving)))?; @@ -120,6 +135,7 @@ pub fn resolve_join( .coerce_types(ctxt.expr_arena, ctxt.lp_arena, input_right) .map_err(|e| e.context("'join' failed".into()))?; + // Re-evaluate because of mutable borrows earlier. let schema_left = ctxt.lp_arena.get(input_left).schema(ctxt.lp_arena); let schema_right = ctxt.lp_arena.get(input_right).schema(ctxt.lp_arena); diff --git a/py-polars/polars/lazyframe/frame.py b/py-polars/polars/lazyframe/frame.py index 887a7a63faa8..de33fe1bdcfc 100644 --- a/py-polars/polars/lazyframe/frame.py +++ b/py-polars/polars/lazyframe/frame.py @@ -4568,9 +4568,10 @@ def join_asof( if by is not None: by_left_ = [by] if isinstance(by, str) else by by_right_ = by_left_ - elif (by_left is not None) and (by_right is not None): - by_left_ = [by_left] if isinstance(by_left, str) else by_left - by_right_ = [by_right] if isinstance(by_right, str) else by_right + elif (by_left is not None) or (by_right is not None): + by_left_ = [by_left] if isinstance(by_left, str) else by_left # type: ignore[assignment] + by_right_ = [by_right] if isinstance(by_right, str) else by_right # type: ignore[assignment] + else: # no by by_left_ = None diff --git a/py-polars/tests/unit/operations/test_join_asof.py b/py-polars/tests/unit/operations/test_join_asof.py index 61cd067a1e02..aa75c3eae313 100644 --- a/py-polars/tests/unit/operations/test_join_asof.py +++ b/py-polars/tests/unit/operations/test_join_asof.py @@ -1196,3 +1196,21 @@ def test_asof_join_by_schema() -> None: ) assert q.collect_schema() == q.collect().schema + + +def test_raise_invalid_by_arg_13020() -> None: + df1 = pl.DataFrame({"asOfDate": [date(2020, 1, 1)]}) + df2 = pl.DataFrame( + { + "endityId": [date(2020, 1, 1)], + "eventDate": ["A"], + } + ) + with pytest.raises(pl.exceptions.InvalidOperationError, match="expected both"): + df1.sort("asOfDate").join_asof( + df2.sort("eventDate"), + left_on="asOfDate", + right_on="eventDate", + by_left=None, + by_right=["entityId"], + )