Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: sql right join switching lhs and rhs #279

Merged
merged 1 commit into from
Aug 29, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 20 additions & 7 deletions siuba/sql/verbs.py
Original file line number Diff line number Diff line change
Expand Up @@ -798,13 +798,14 @@ def _case_when(__data, cases):

from collections.abc import Mapping

def _joined_cols(left_cols, right_cols, on_keys, full = False):
def _joined_cols(left_cols, right_cols, on_keys, how, suffix = ("_x", "_y")):
"""Return labeled columns, according to selection rules for joins.

Rules:
1. For join keys, keep left table's column
2. When keys have the same labels, add suffix
"""

# TODO: remove sets, so uses stable ordering
# when left and right cols have same name, suffix with _x / _y
keep_right = set(right_cols.keys()) - set(on_keys.values())
Expand All @@ -813,15 +814,21 @@ def _joined_cols(left_cols, right_cols, on_keys, full = False):
right_cols_no_keys = {k: right_cols[k] for k in keep_right}

# for an outer join, have key columns coalesce values
if full:
left_cols = {**left_cols}

left_cols = {**left_cols}
if how == "full":
for lk, rk in on_keys.items():
col = sql.functions.coalesce(left_cols[lk], right_cols[rk])
left_cols[lk] = col.label(lk)
elif how == "right":
for lk, rk in on_keys.items():
# Make left key columns actually be right ones (which contain left + extra)
left_cols[lk] = right_cols[rk].label(lk)


# create labels ----
l_labs = _relabeled_cols(left_cols, shared_labs, "_x")
r_labs = _relabeled_cols(right_cols_no_keys, shared_labs, "_y")
l_labs = _relabeled_cols(left_cols, shared_labs, suffix[0])
r_labs = _relabeled_cols(right_cols_no_keys, shared_labs, suffix[1])

return l_labs + r_labs

Expand Down Expand Up @@ -855,6 +862,7 @@ def _join(left, right, on = None, *args, how = "inner", sql_on = None):
# switch joins, since sqlalchemy doesn't have right join arg
# see https://stackoverflow.com/q/11400307/1144523
left_sel, right_sel = right_sel, left_sel
on = {v:k for k,v in on.items()}

# create join conditions ----
bool_clause = _create_join_conds(left_sel, right_sel, on)
Expand All @@ -866,15 +874,20 @@ def _join(left, right, on = None, *args, how = "inner", sql_on = None):
isouter = how != "inner",
full = how == "full"
)


# if right join, set selects back
if how == "right":
left_sel, right_sel = right_sel, left_sel
on = {v:k for k,v in on.items()}

# note, shared_keys assumes on is a mapping...
# TODO: shared_keys appears to be for when on is not specified, but was unused
#shared_keys = [k for k,v in on.items() if k == v]
labeled_cols = _joined_cols(
left_sel.columns,
right_sel.columns,
on_keys = consolidate_keys,
full = how == "full"
how = how
)

sel = sql.select(labeled_cols, from_obj = join)
Expand Down
4 changes: 2 additions & 2 deletions siuba/tests/test_verb_join.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,8 @@ def test_basic_left_join(df1, df2):
@backend_sql("TODO: pandas returns columns in rev name order")
def test_basic_right_join(backend, df1, df2):
# same as left join, but flip df arguments
out = right_join(df2, df1, {"ii": "ii"}) >> collect()
target = DF1.assign(y = ["a", "b", None, None])
out = right_join(df1, df2, {"ii": "ii"}) >> collect()
target = DF2.assign(x = ["a", "b", None])[["ii", "x", "y"]]
assert_frame_sort_equal(out, target)

def test_basic_inner_join(df1, df2):
Expand Down