Skip to content

Commit

Permalink
fix: sqlalchemy 1.3 backwards compat for anti_join, filter
Browse files Browse the repository at this point in the history
  • Loading branch information
machow committed May 6, 2021
1 parent bdac387 commit ab780ed
Showing 1 changed file with 25 additions and 14 deletions.
39 changes: 25 additions & 14 deletions siuba/sql/verbs.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,14 @@ def exit(self, node):
name = self._get_unique_name('win', self.window_cte.columns)
label = col_expr.label(name)

# optionally put into CTE, and return its resulting column
self.window_cte.append_column(label)
win_col = self.window_cte.c.values()[-1]

# put into CTE, and return its resulting column, so that subsequent
# operations will refer to the window column on window_cte. Note that
# the operations will use the actual column, so may need to use the
# ClauseAdapter to make it a reference to the label
self.window_cte = self.window_cte.column(label)
win_col = lift_inner_cols(self.window_cte).values()[-1]
self.windows.append(win_col)

return win_col

return col_expr
Expand All @@ -120,7 +123,7 @@ def _get_unique_name(prefix, columns):
def track_call_windows(call, columns, group_by, order_by, window_cte = None):
listener = WindowReplacer(columns, group_by, order_by, window_cte)
col = listener.enter(call)
return col, listener.windows
return col, listener.windows, listener.window_cte


def lift_inner_cols(tbl):
Expand Down Expand Up @@ -474,7 +477,8 @@ def _filter(__data, *args):
new_call = __data.shape_call(arg, verb_name = "Filter", arg_name = ii)
#var_cols = new_call.op_vars(attr_calls = False)

col_expr, win_cols = __data.track_call_windows(
# note that a new win_sel is returned, w/ window columns appended
col_expr, win_cols, win_sel = __data.track_call_windows(
new_call,
sel.columns,
window_cte = win_sel
Expand All @@ -490,16 +494,20 @@ def _filter(__data, *args):

# first cte, windows ----
if len(windows):

for col in windows:
win_sel = win_sel.column(col)
win_alias = win_sel.alias()


# because track_call_windows in the loop above used select.append_column
# multiple times, sqlalchemy doesn't know our window columns are the ones
# in the final mutated form of win_sel
col_key_map = {col.key: col for col in win_alias.columns.values()}
equivalents = {col: [col_key_map[col.key]] for col in windows}
#col_key_map = {col.key: col for col in win_alias.columns.values()}
#equivalents = {col: [col_key_map[col.key]] for col in windows}

# move non-window functions to refer to win_sel clause (not the innermost) ---
bool_clause = sql.util.ClauseAdapter(win_alias, equivalents = equivalents) \
bool_clause = sql.util.ClauseAdapter(win_alias) \
.traverse(bool_clause)

orig_cols = [win_alias.columns[k] for k in sel.columns.keys()]
Expand Down Expand Up @@ -554,7 +562,7 @@ def _mutate_select(sel, colname, func, labs, __data):
sel = cte.select()

# evaluate call expr on columns, making sure to use group vars
new_col, windows = __data.track_call_windows(func, columns)
new_col, windows, _ = __data.track_call_windows(func, columns)

# replacing an existing column, so strip it from select statement
if replace_col:
Expand Down Expand Up @@ -749,13 +757,13 @@ def _group_by(__data, *args, add = False, **kwargs):
else:
data = __data

cols = data.last_op.columns

# put kwarg grouping vars last, so similar order to function call
groups = tuple(simple_varname(arg) for arg in args) + tuple(kwargs)
if None in groups:
raise NotImplementedError("Complex expressions not supported in sql group_by")

# ensure group_by variables are in the select columns
cols = data.last_op.alias().columns
unmatched = set(groups) - set(cols.keys())
if unmatched:
raise KeyError("group_by specifies columns missing from table: %s" %unmatched)
Expand Down Expand Up @@ -939,8 +947,11 @@ def _anti_join(left, right = None, on = None, *args, sql_on = None):

# create anti join (NOT EXISTS subquery against right_sel) ----
#not_exists = ~sql.exists([1], from_obj = right_sel).where(bool_clause)
not_exists = ~_sql_select([1]).select_from(right_sel).where(bool_clause).exists()
sel = left_sel.select().where(not_exists)
exists_clause = _sql_select([sql.literal(1)]) \
.select_from(right_sel) \
.where(bool_clause)

sel = left_sel.select().where(~sql.exists(exists_clause))

return left.append_op(sel, order_by = tuple())

Expand Down

0 comments on commit ab780ed

Please sign in to comment.