diff --git a/siuba/sql/verbs.py b/siuba/sql/verbs.py index 18ce57bd..8102ebe5 100644 --- a/siuba/sql/verbs.py +++ b/siuba/sql/verbs.py @@ -134,8 +134,7 @@ def col_expr_requires_cte(call, sel, is_mutate = False): call_vars = set(call.op_vars(attr_calls = False)) - columns = lift_inner_cols(sel) - sel_labs = set(k for k,v in columns.items() if isinstance(v, sql.elements.Label)) + sel_labs = get_inner_labels(sel) # I use the acronym fwg sol (frog soul) to remember sql clause eval order # from, where, group by, select, order by, limit @@ -147,6 +146,11 @@ def col_expr_requires_cte(call, sel, is_mutate = False): or not sel_labs.isdisjoint(call_vars) ) +def get_inner_labels(sel): + columns = lift_inner_cols(sel) + sel_labs = set(k for k,v in columns.items() if isinstance(v, sql.elements.Label)) + return sel_labs + def get_missing_columns(call, columns): missing_cols = set(call.op_vars(attr_calls = False)) - set(columns.keys()) return missing_cols @@ -465,6 +469,7 @@ def _filter(__data, *args): conds = [] windows = [] for ii, arg in enumerate(args): + if isinstance(arg, Call): new_call = __data.shape_call(arg, verb_name = "Filter", arg_name = ii) #var_cols = new_call.op_vars(attr_calls = False) @@ -690,9 +695,10 @@ def _summarize(__data, **kwargs): ) needs_cte = [col_expr_requires_cte(call, sel) for call in new_calls.values()] + group_on_labels = set(__data.group_by) & get_inner_labels(sel) # create select statement ---- - if any(needs_cte): + if any(needs_cte) or len(group_on_labels): # need a cte, due to alias cols or existing group by # current select stmt has group by clause, so need to make it subquery cte = sel.alias() diff --git a/siuba/tests/test_verb_summarize.py b/siuba/tests/test_verb_summarize.py index aaa80633..2e290e04 100644 --- a/siuba/tests/test_verb_summarize.py +++ b/siuba/tests/test_verb_summarize.py @@ -4,7 +4,7 @@ https://github.com/tidyverse/dbplyr/blob/master/tests/testthat/test-verb-mutate.R """ -from siuba import _, mutate, select, group_by, summarize, filter +from siuba import _, mutate, select, group_by, summarize, filter, show_query from siuba.dply.vector import row_number, n import pytest @@ -35,7 +35,7 @@ def test_summarize_ungrouped(df, query, output): @pytest.mark.skip("TODO: should return 1 row (#63)") -def test_ungrouped_summarize_literal(df, query, output): +def test_ungrouped_summarize_literal(df): assert_equal_query(df, summarize(y = 1), data_frame(y = 1)) @@ -121,3 +121,31 @@ def test_summarize_removes_series_index(): df.assign(res = df.x + df.y).drop(columns = ["x", "y"]) ) + +@backend_sql +def test_summarize_subquery_group_vars(backend, df): + query = mutate(g2 = _.g.str.upper()) >> group_by(_.g2) >> summarize(low = _.x.min()) + assert_equal_query( + df, + query, + data_frame(g2 = ['A', 'B'], low = [1, 3]) + ) + + # check that is uses a subquery, since g2 is defined in first query + text = str(query(df).last_op) + assert text.count('FROM') == 2 + + +@backend_sql +def test_summarize_subquery_op_vars(backend, df): + query = mutate(x2 = _.x + 1) >> group_by(_.g) >> summarize(low = _.x2.min()) + assert_equal_query( + df, + query, + data_frame(g = ['a', 'b'], low = [2, 4]) + ) + + # check that is uses a subquery, since x2 is defined in first query + text = str(query(df).last_op) + assert text.count('FROM') == 2 +