From f1e3ceb9e1cd7cef56bc4d942a00ed8d05859b45 Mon Sep 17 00:00:00 2001 From: Michael Chow Date: Sat, 29 Aug 2020 18:56:52 -0400 Subject: [PATCH] fix: sql right join switching lhs and rhs --- siuba/sql/verbs.py | 27 ++++++++++++++++++++------- siuba/tests/test_verb_join.py | 4 ++-- 2 files changed, 22 insertions(+), 9 deletions(-) diff --git a/siuba/sql/verbs.py b/siuba/sql/verbs.py index 8102ebe5..e680e4b2 100644 --- a/siuba/sql/verbs.py +++ b/siuba/sql/verbs.py @@ -798,13 +798,14 @@ def _case_when(__data, cases): from collections.abc import Mapping -def _joined_cols(left_cols, right_cols, on_keys, full = False): +def _joined_cols(left_cols, right_cols, on_keys, how, suffix = ("_x", "_y")): """Return labeled columns, according to selection rules for joins. Rules: 1. For join keys, keep left table's column 2. When keys have the same labels, add suffix """ + # TODO: remove sets, so uses stable ordering # when left and right cols have same name, suffix with _x / _y keep_right = set(right_cols.keys()) - set(on_keys.values()) @@ -813,15 +814,21 @@ def _joined_cols(left_cols, right_cols, on_keys, full = False): right_cols_no_keys = {k: right_cols[k] for k in keep_right} # for an outer join, have key columns coalesce values - if full: - left_cols = {**left_cols} + + left_cols = {**left_cols} + if how == "full": for lk, rk in on_keys.items(): col = sql.functions.coalesce(left_cols[lk], right_cols[rk]) left_cols[lk] = col.label(lk) + elif how == "right": + for lk, rk in on_keys.items(): + # Make left key columns actually be right ones (which contain left + extra) + left_cols[lk] = right_cols[rk].label(lk) + # create labels ---- - l_labs = _relabeled_cols(left_cols, shared_labs, "_x") - r_labs = _relabeled_cols(right_cols_no_keys, shared_labs, "_y") + l_labs = _relabeled_cols(left_cols, shared_labs, suffix[0]) + r_labs = _relabeled_cols(right_cols_no_keys, shared_labs, suffix[1]) return l_labs + r_labs @@ -855,6 +862,7 @@ def _join(left, right, on = None, *args, how = "inner", sql_on = None): # switch joins, since sqlalchemy doesn't have right join arg # see https://stackoverflow.com/q/11400307/1144523 left_sel, right_sel = right_sel, left_sel + on = {v:k for k,v in on.items()} # create join conditions ---- bool_clause = _create_join_conds(left_sel, right_sel, on) @@ -866,7 +874,12 @@ def _join(left, right, on = None, *args, how = "inner", sql_on = None): isouter = how != "inner", full = how == "full" ) - + + # if right join, set selects back + if how == "right": + left_sel, right_sel = right_sel, left_sel + on = {v:k for k,v in on.items()} + # note, shared_keys assumes on is a mapping... # TODO: shared_keys appears to be for when on is not specified, but was unused #shared_keys = [k for k,v in on.items() if k == v] @@ -874,7 +887,7 @@ def _join(left, right, on = None, *args, how = "inner", sql_on = None): left_sel.columns, right_sel.columns, on_keys = consolidate_keys, - full = how == "full" + how = how ) sel = sql.select(labeled_cols, from_obj = join) diff --git a/siuba/tests/test_verb_join.py b/siuba/tests/test_verb_join.py index bd09db4a..97e9dbc8 100644 --- a/siuba/tests/test_verb_join.py +++ b/siuba/tests/test_verb_join.py @@ -139,8 +139,8 @@ def test_basic_left_join(df1, df2): @backend_sql("TODO: pandas returns columns in rev name order") def test_basic_right_join(backend, df1, df2): # same as left join, but flip df arguments - out = right_join(df2, df1, {"ii": "ii"}) >> collect() - target = DF1.assign(y = ["a", "b", None, None]) + out = right_join(df1, df2, {"ii": "ii"}) >> collect() + target = DF2.assign(x = ["a", "b", None])[["ii", "x", "y"]] assert_frame_sort_equal(out, target) def test_basic_inner_join(df1, df2):