fix: summarize raising error when a grouping col is all NA (or mostly NA) #459

Merged: 5 commits, Nov 16, 2022
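For context, a minimal plain-pandas illustration (not siuba code; column names are made up) of the behavior behind the reported error: with pandas' default dropna=True, an all-NA grouping column yields zero groups, so a grouped aggregation such as summarize comes back empty, while dropna=False keeps the NA key as its own group.

import pandas as pd

df = pd.DataFrame({"x": [1, 2, 3], "g": [None, None, None]})

df.groupby("g")["x"].min()                  # empty result: the NaN group is dropped entirely
df.groupby("g", dropna=False)["x"].min()    # the g=NaN group is kept, min == 1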
19 changes: 16 additions & 3 deletions siuba/dply/verbs.py
@@ -154,7 +154,7 @@ def _mutate_cols(__data, args, kwargs):


def _make_groupby_safe(gdf):
-    return gdf.obj.groupby(gdf.grouper, group_keys=False)
+    return gdf.obj.groupby(gdf.grouper, group_keys=False, dropna=False)


MSG_TYPE_ERROR = "The first argument to {func} must be one of: {types}"
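The helper above re-groups an existing DataFrameGroupBy from its underlying frame (.obj) and grouper (pandas 1.x attributes, as used in the diff); judging by the hunk context it sits on the grouped mutate path. A sketch of the pattern, with illustrative names. Note that reusing an existing grouper only preserves NA groups if that grouper was itself built with dropna=False, which is why group_by below sets it as well.

import pandas as pd

gdf = pd.DataFrame({"x": [1, 2], "g": [None, None]}).groupby("g", dropna=False)

# Rebuild the groupby from its underlying frame and grouper, adding the safe options.
safe = gdf.obj.groupby(gdf.grouper, group_keys=False, dropna=False)
safe.apply(lambda d: d.assign(res=d["x"] + 1))   # both rows survive despite the NA key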
@@ -363,9 +363,9 @@ def group_by(__data, *args, add = False, **kwargs):
# ensures group levels are recalculated if varname was in transmute
groupings[varname] = varname

-        return tmp_df.groupby(list(groupings.values()))
+        return tmp_df.groupby(list(groupings.values()), dropna=False, group_keys=True)

-    return tmp_df.groupby(by = by_vars)
+    return tmp_df.groupby(by = by_vars, dropna=False, group_keys=True)


@singledispatch2((pd.DataFrame, DataFrameGroupBy))
@@ -563,6 +563,19 @@ def summarize(__data, *args, **kwargs):

@summarize.register(DataFrameGroupBy)
def _summarize(__data, *args, **kwargs):
    if __data.dropna or not __data.group_keys:
        warnings.warn(
            f"Grouped data passed to summarize must have dropna=False and group_keys=True."
            " Regrouping with these arguments set."
        )

        if __data.grouper.dropna:
            # will need to recalculate groupings, otherwise it ignores dropna
            group_cols = [ping.name for ping in __data.grouper.groupings]
        else:
            group_cols = __data.grouper.groupings
        __data = __data.obj.groupby(group_cols, dropna=False, group_keys=True)

    df_summarize = summarize.registry[pd.DataFrame]

    df = __data.apply(df_summarize, *args, **kwargs)
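Two details of the guard above, shown with plain pandas (illustrative variable names; pandas 1.x attributes): a grouper that was built with dropna=True would keep ignoring NA keys even if reused with dropna=False, hence the group columns are recovered by name, and group_keys=True is what puts the group label back onto apply()'s result so the grouping values can be recovered afterwards.

import pandas as pd

df = pd.DataFrame({"x": [1, 2, 3], "g": [None, None, None]})
gdf = df.groupby("g")   # built with the default dropna=True

# Recover the grouping columns by name, then regroup with the safe options.
group_cols = [g.name for g in gdf.grouper.groupings]    # ["g"]
regrouped = df.groupby(group_cols, dropna=False, group_keys=True)

# With group_keys=True, apply() indexes the result by the group label (here NaN),
# so the grouping column is still recoverable from the summarized output.
regrouped.apply(lambda d: pd.DataFrame({"res": [d["x"].min()]}))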
11 changes: 2 additions & 9 deletions siuba/tests/test_sql_misc.py
@@ -31,17 +31,10 @@ def test_raw_sql_mutate_grouped(backend, df):
)


@pytest.mark.skip_backend("snowflake") # supported by snowflake
@pytest.mark.skip_backend("snowflake", "duckdb") # they support this behavior
@backend_sql
def test_raw_sql_mutate_refer_previous_raise_dberror(backend, skip_backend, df):
# Note: unlikely will be able to support this case. Normally we analyze
if backend.name == "duckdb":
# duckdb dialect re-raises the engines exception, which is RuntimeError
# the expression to know whether we need to create a subquery.
import duckdb
exc = duckdb.BinderException
else:
exc = sqlalchemy.exc.DatabaseError
exc = sqlalchemy.exc.DatabaseError

with pytest.raises(exc):
assert_equal_query(
11 changes: 11 additions & 0 deletions siuba/tests/test_verb_mutate.py
@@ -75,6 +75,17 @@ def test_mutate_reassign_all_cols_keeps_rowsize(dfs):
data_frame(a = [1,1,1], b = [2,2,2])
)


def test_mutate_grouped_pandas_no_dropna():
    src = data_frame(x = [1, 2], g = [None, None])

    assert_equal_query(
        src,
        group_by(_.g) >> mutate(res = _.x + 1),
        data_frame(x = [1, 2], g = [None, None], res = [2, 3])
    )


@backend_sql
def test_mutate_window_funcs(backend):
data = data_frame(idx = range(0, 4), x = range(1, 5), g = [1,1,2,2])
48 changes: 48 additions & 0 deletions siuba/tests/test_verb_summarize.py
@@ -3,6 +3,8 @@

https://github.com/tidyverse/dbplyr/blob/master/tests/testthat/test-verb-mutate.R
"""

import numpy as np

from siuba import _, mutate, select, group_by, summarize, filter, show_query, arrange
from siuba.dply.vector import row_number, n
@@ -47,6 +49,52 @@ def test_summarize_after_mutate_cuml_win(backend, df_float):
)


def test_summarize_keeps_na_grouping_cols(backend):
    df = data_frame(x = [1, 2, 3], g = [None, None, None])
    src = backend.load_df(df)

    if backend.name == "pandas":
        missing = np.nan
    else:
        missing = None

    assert_equal_query(
        src,
        group_by(_.g) >> summarize(res = _.x.min()),
        data_frame(g = [missing], res = [1])
    )


def test_summarize_regroups_group_keys():
    df = data_frame(x = [1, 2, 3], g = [None, None, None])

    # bad group_keys choice
    g_df = df.groupby("g", group_keys=False, dropna=False)

    with pytest.warns(UserWarning, match="group_keys=True"):

        assert_equal_query(
            g_df,
            summarize(res = _.x.min()),
            data_frame(g = [np.nan], res = [1])
        )


def test_summarize_regroups_dropna():
    df = data_frame(x = [1, 2, 3], g = [None, None, None])

    # bad dropna choice
    g_df = df.groupby("g", group_keys=True, dropna=True)

    with pytest.warns(UserWarning, match="dropna=False"):

        assert_equal_query(
            g_df,
            summarize(res = _.x.min()),
            data_frame(g = [np.nan], res = [1])
        )


@backend_sql
def test_summarize_keeps_group_vars(backend, gdf):
q = gdf >> summarize(n = n(_))