diff --git a/siuba/sql/verbs.py b/siuba/sql/verbs.py index 6b4731ca..8caeeec1 100644 --- a/siuba/sql/verbs.py +++ b/siuba/sql/verbs.py @@ -94,11 +94,14 @@ def exit(self, node): name = self._get_unique_name('win', self.window_cte.columns) label = col_expr.label(name) - # optionally put into CTE, and return its resulting column - self.window_cte.append_column(label) - win_col = self.window_cte.c.values()[-1] - + # put into CTE, and return its resulting column, so that subsequent + # operations will refer to the window column on window_cte. Note that + # the operations will use the actual column, so may need to use the + # ClauseAdaptor to make it a reference to the label + self.window_cte = self.window_cte.column(label) + win_col = lift_inner_cols(self.window_cte).values()[-1] self.windows.append(win_col) + return win_col return col_expr @@ -120,7 +123,7 @@ def _get_unique_name(prefix, columns): def track_call_windows(call, columns, group_by, order_by, window_cte = None): listener = WindowReplacer(columns, group_by, order_by, window_cte) col = listener.enter(call) - return col, listener.windows + return col, listener.windows, listener.window_cte def lift_inner_cols(tbl): @@ -474,7 +477,8 @@ def _filter(__data, *args): new_call = __data.shape_call(arg, verb_name = "Filter", arg_name = ii) #var_cols = new_call.op_vars(attr_calls = False) - col_expr, win_cols = __data.track_call_windows( + # note that a new win_sel is returned, w/ window columns appended + col_expr, win_cols, win_sel = __data.track_call_windows( new_call, sel.columns, window_cte = win_sel @@ -490,16 +494,20 @@ def _filter(__data, *args): # first cte, windows ---- if len(windows): + + for col in windows: + win_sel = win_sel.column(col) win_alias = win_sel.alias() + # because track_call_windows in the loop above used select.append_column # multiple times, sqlalchemy doesn't know our window columns are the ones # in the final mutated for of win_sel - col_key_map = {col.key: col for col in win_alias.columns.values()} - equivalents = {col: [col_key_map[col.key]] for col in windows} + #col_key_map = {col.key: col for col in win_alias.columns.values()} + #equivalents = {col: [col_key_map[col.key]] for col in windows} # move non-window functions to refer to win_sel clause (not the innermost) --- - bool_clause = sql.util.ClauseAdapter(win_alias, equivalents = equivalents) \ + bool_clause = sql.util.ClauseAdapter(win_alias) \ .traverse(bool_clause) orig_cols = [win_alias.columns[k] for k in sel.columns.keys()] @@ -554,7 +562,7 @@ def _mutate_select(sel, colname, func, labs, __data): sel = cte.select() # evaluate call expr on columns, making sure to use group vars - new_col, windows = __data.track_call_windows(func, columns) + new_col, windows, _ = __data.track_call_windows(func, columns) # replacing an existing column, so strip it from select statement if replace_col: @@ -749,13 +757,13 @@ def _group_by(__data, *args, add = False, **kwargs): else: data = __data - cols = data.last_op.columns - # put kwarg grouping vars last, so similar order to function call groups = tuple(simple_varname(arg) for arg in args) + tuple(kwargs) if None in groups: raise NotImplementedError("Complex expressions not supported in sql group_by") + # ensure group_by variables are in the select columns + cols = data.last_op.alias().columns unmatched = set(groups) - set(cols.keys()) if unmatched: raise KeyError("group_by specifies columns missing from table: %s" %unmatched) @@ -939,8 +947,11 @@ def _anti_join(left, right = None, on = None, *args, sql_on = None): # create inner join ---- #not_exists = ~sql.exists([1], from_obj = right_sel).where(bool_clause) - not_exists = ~_sql_select([1]).select_from(right_sel).where(bool_clause).exists() - sel = left_sel.select().where(not_exists) + exists_clause = _sql_select([sql.literal(1)]) \ + .select_from(right_sel) \ + .where(bool_clause) + + sel = left_sel.select().where(~sql.exists(exists_clause)) return left.append_op(sel, order_by = tuple())