Skip to content

Commit

Permalink
Merge pull request #278 from machow/fix-sql-summarize-subqueries
Browse files Browse the repository at this point in the history
Fix sql summarize subqueries
  • Loading branch information
machow authored Aug 29, 2020
2 parents a5ec8be + ef0c5dc commit a2d2239
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 5 deletions.
12 changes: 9 additions & 3 deletions siuba/sql/verbs.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,7 @@ def col_expr_requires_cte(call, sel, is_mutate = False):

call_vars = set(call.op_vars(attr_calls = False))

columns = lift_inner_cols(sel)
sel_labs = set(k for k,v in columns.items() if isinstance(v, sql.elements.Label))
sel_labs = get_inner_labels(sel)

# I use the acronym fwg sol (frog soul) to remember sql clause eval order
# from, where, group by, select, order by, limit
Expand All @@ -147,6 +146,11 @@ def col_expr_requires_cte(call, sel, is_mutate = False):
or not sel_labs.isdisjoint(call_vars)
)

def get_inner_labels(sel):
columns = lift_inner_cols(sel)
sel_labs = set(k for k,v in columns.items() if isinstance(v, sql.elements.Label))
return sel_labs

def get_missing_columns(call, columns):
missing_cols = set(call.op_vars(attr_calls = False)) - set(columns.keys())
return missing_cols
Expand Down Expand Up @@ -465,6 +469,7 @@ def _filter(__data, *args):
conds = []
windows = []
for ii, arg in enumerate(args):

if isinstance(arg, Call):
new_call = __data.shape_call(arg, verb_name = "Filter", arg_name = ii)
#var_cols = new_call.op_vars(attr_calls = False)
Expand Down Expand Up @@ -690,9 +695,10 @@ def _summarize(__data, **kwargs):
)

needs_cte = [col_expr_requires_cte(call, sel) for call in new_calls.values()]
group_on_labels = set(__data.group_by) & get_inner_labels(sel)

# create select statement ----
if any(needs_cte):
if any(needs_cte) or len(group_on_labels):
# need a cte, due to alias cols or existing group by
# current select stmt has group by clause, so need to make it subquery
cte = sel.alias()
Expand Down
32 changes: 30 additions & 2 deletions siuba/tests/test_verb_summarize.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
https://github.com/tidyverse/dbplyr/blob/master/tests/testthat/test-verb-mutate.R
"""

from siuba import _, mutate, select, group_by, summarize, filter
from siuba import _, mutate, select, group_by, summarize, filter, show_query
from siuba.dply.vector import row_number, n

import pytest
Expand Down Expand Up @@ -35,7 +35,7 @@ def test_summarize_ungrouped(df, query, output):


@pytest.mark.skip("TODO: should return 1 row (#63)")
def test_ungrouped_summarize_literal(df, query, output):
def test_ungrouped_summarize_literal(df):
assert_equal_query(df, summarize(y = 1), data_frame(y = 1))


Expand Down Expand Up @@ -121,3 +121,31 @@ def test_summarize_removes_series_index():
df.assign(res = df.x + df.y).drop(columns = ["x", "y"])
)


@backend_sql
def test_summarize_subquery_group_vars(backend, df):
query = mutate(g2 = _.g.str.upper()) >> group_by(_.g2) >> summarize(low = _.x.min())
assert_equal_query(
df,
query,
data_frame(g2 = ['A', 'B'], low = [1, 3])
)

# check that is uses a subquery, since g2 is defined in first query
text = str(query(df).last_op)
assert text.count('FROM') == 2


@backend_sql
def test_summarize_subquery_op_vars(backend, df):
query = mutate(x2 = _.x + 1) >> group_by(_.g) >> summarize(low = _.x2.min())
assert_equal_query(
df,
query,
data_frame(g = ['a', 'b'], low = [2, 4])
)

# check that is uses a subquery, since x2 is defined in first query
text = str(query(df).last_op)
assert text.count('FROM') == 2

0 comments on commit a2d2239

Please sign in to comment.