Skip to content

Commit

Permalink
fix(pandas,sql): ensure verbs accept grouped data
Browse files Browse the repository at this point in the history
  • Loading branch information
machow committed Sep 27, 2022
1 parent cc00e71 commit 7d64101
Show file tree
Hide file tree
Showing 3 changed files with 243 additions and 44 deletions.
147 changes: 125 additions & 22 deletions siuba/dply/verbs.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from functools import singledispatch
from functools import singledispatch, wraps
from pandas import DataFrame

import pandas as pd
import numpy as np
import warnings


from pandas.core.groupby import DataFrameGroupBy
from pandas.core.dtypes.inference import is_scalar
Expand Down Expand Up @@ -81,6 +84,21 @@ def _repr_grouped_df_console_(self):
return "(grouped data frame)\n" + repr(self.obj)


def _bounce_groupby(f):
@wraps(f)
def wrapper(__data: "pd.DataFrame | DataFrameGroupBy", *args, **kwargs):
if isinstance(__data, pd.DataFrame):
return f(__data, *args, **kwargs)

groupings = __data.grouper.groupings
group_cols = [ping.name for ping in groupings]

res = f(__data.obj, *args, **kwargs)

return res.groupby(group_cols)

return wrapper


def _regroup(df):
# try to regroup after an apply, when user kept index (e.g. group_keys = True)
Expand Down Expand Up @@ -538,6 +556,26 @@ def _transmute(__data, *args, **kwargs):

# Select ======================================================================

def _insert_missing_groups(dst, orig, missing_groups):
if missing_groups:
warnings.warn(f"Adding missing grouping variables: {missing_groups}")

for ii, colname in enumerate(missing_groups):
dst.insert(ii, colname, orig[colname])


def _select_group_renames(selection: dict, group_cols):
"""Returns a 2-tuple: groups missing in the select, new group keys."""
renamed = {k: v for k,v in selection.items() if v is not None}

sel_groups = [
renamed[colname] or colname for colname in group_cols if colname in renamed
]
missing_groups = [colname for colname in group_cols if colname not in selection]

return missing_groups, (*missing_groups, *sel_groups)


@singledispatch2(DataFrame)
def select(__data, *args, **kwargs):
"""Select columns of a table to keep or drop (and optionally rename).
Expand Down Expand Up @@ -610,10 +648,21 @@ def select(__data, *args, **kwargs):

return __data[list(od)].rename(columns = to_rename)


@select.register(DataFrameGroupBy)
def _select(__data, *args, **kwargs):
raise Exception("Selecting columns of grouped DataFrame currently not allowed")
# tidyselect
var_list = var_create(*args)
od = var_select(__data.obj.columns, *var_list)

group_cols = [ping.name for ping in __data.grouper.groupings]

res = select(__data.obj, *args, **kwargs)

missing_groups, group_keys = _select_group_renames(od, group_cols)
_insert_missing_groups(res, __data.obj, missing_groups)

return res.groupby(list(group_keys))


# Rename ======================================================================
Expand Down Expand Up @@ -654,10 +703,17 @@ def rename(__data, **kwargs):

return __data.rename(columns = col_names)


@rename.register(DataFrameGroupBy)
def _rename(__data, **kwargs):
raise NotImplementedError("Selecting columns of grouped DataFrame currently not allowed")
col_names = {simple_varname(v):k for k,v in kwargs.items()}
group_cols = [ping.name for ping in __data.grouper.groupings]

res = rename(__data.obj, **kwargs)

missing_groups, group_keys = _select_group_renames(col_names, group_cols)

return res.groupby(list(group_keys))


# Arrange =====================================================================
Expand Down Expand Up @@ -787,6 +843,19 @@ def _arrange(__data, *args):

# Distinct ====================================================================


def _var_select_simple(args) -> "dict[str, bool]":
"""Return an 'ordered set' of selected column names."""
cols = {simple_varname(x): True for x in args}
if None in cols:
raise Exception(
"Positional arguments must be simple column. "
"e.g. _.colname or _['colname']\n\n"
f"Received: {repr(cols[None])}"
)

return cols

@singledispatch2(DataFrame)
def distinct(__data, *args, _keep_all = False, **kwargs):
"""Keep only distinct (unique) rows from a table.
Expand Down Expand Up @@ -831,11 +900,7 @@ def distinct(__data, *args, _keep_all = False, **kwargs):
2 Chinstrap Dream 46.5 17.9
"""
# using dict as ordered set
cols = {simple_varname(x): True for x in args}
if None in cols:
raise Exception("positional arguments must be simple column, "
"e.g. _.colname or _['colname']"
)
cols = _var_select_simple(args)

# mutate kwargs
cols.update(kwargs)
Expand All @@ -850,12 +915,33 @@ def distinct(__data, *args, _keep_all = False, **kwargs):

return tmp_data


@distinct.register(DataFrameGroupBy)
def _distinct(__data, *args, _keep_all = False, **kwargs):
df = __data.apply(lambda x: distinct(x, *args, _keep_all = _keep_all, **kwargs))
return _regroup(df)

# if_else
cols = _var_select_simple(args)
cols.update(kwargs)

# special case: use all variables when none are specified
if not len(cols): cols = __data.columns

group_cols_ordered = {ping.name: True for ping in __data.grouper.groupings}
final_cols = list({**group_cols_ordered, **cols, **kwargs})

mutated = mutate(__data, **kwargs).obj

if not _keep_all:
pre_df = mutated[final_cols]
else:
pre_df = mutated

res = pre_df.drop_duplicates(list(final_cols)).reset_index(drop = True)
return res.groupby(list(group_cols_ordered))



# if_else, case_when ==========================================================

# TODO: move to vector.py
@singledispatch
def if_else(condition, true, false):
Expand Down Expand Up @@ -1101,7 +1187,7 @@ def count(__data, *args, wt = None, sort = False, **kwargs):
return counts


@singledispatch2(pd.DataFrame)
@singledispatch2((pd.DataFrame, DataFrameGroupBy))
def add_count(__data, *args, wt = None, sort = False, **kwargs):
"""Add a column that is the number of observations for each grouping of data.
Expand Down Expand Up @@ -1152,8 +1238,8 @@ def add_count(__data, *args, wt = None, sort = False, **kwargs):
"""
counts = count(__data, *args, wt = wt, sort = sort, **kwargs)

on = list(counts.columns)[:-1]
return __data.merge(counts, on = on)
by = list(counts.columns)[:-1]
return inner_join(__data, counts, by = by)



Expand Down Expand Up @@ -1323,7 +1409,8 @@ def _convert_nested_entry(x):


# TODO: will need to use multiple dispatch
@singledispatch2(pd.DataFrame)
@singledispatch2((pd.DataFrame, DataFrameGroupBy))
@_bounce_groupby
def join(left, right, on = None, how = None, *args, by = None, **kwargs):
"""Join two tables together, by matching on specified columns.
Expand Down Expand Up @@ -1428,6 +1515,8 @@ def join(left, right, on = None, how = None, *args, by = None, **kwargs):
"""

if isinstance(right, DataFrameGroupBy):
right = right.obj
if not isinstance(right, DataFrame):
raise Exception("right hand table must be a DataFrame")
if how is None:
Expand Down Expand Up @@ -1455,8 +1544,9 @@ def _join(left, right, on = None, how = None):
raise Exception("Unsupported type %s" %type(left))


@singledispatch2(pd.DataFrame)
def semi_join(left, right = None, on = None):
@singledispatch2((pd.DataFrame, DataFrameGroupBy))
@_bounce_groupby
def semi_join(left, right = None, on = None, *args, by = None):
"""Return the left table with every row that would be kept in an inner join.
Parameters
Expand Down Expand Up @@ -1492,6 +1582,10 @@ def semi_join(left, right = None, on = None):
id x
0 1 a
"""

if on is None and by is not None:
on = by

if isinstance(on, Mapping):
# coerce colnames to list, to avoid indexing with tuples
on_cols, right_on = map(list, zip(*on.items()))
Expand Down Expand Up @@ -1528,8 +1622,9 @@ def semi_join(left, right = None, on = None):
return left.loc[range_indx.isin(l_indx)]


@singledispatch2(pd.DataFrame)
def anti_join(left, right = None, on = None):
@singledispatch2((pd.DataFrame, DataFrameGroupBy))
@_bounce_groupby
def anti_join(left, right = None, on = None, *args, by = None):
"""Return the left table with every row that would *not* be kept in an inner join.
Parameters
Expand Down Expand Up @@ -1565,12 +1660,19 @@ def anti_join(left, right = None, on = None):
id x
0 1 a
"""

if on is None and by is not None:
on = by

# copied from semi_join
if isinstance(on, Mapping):
left_on, right_on = zip(*on.items())
else:
left_on = right_on = on

if isinstance(right, DataFrameGroupBy):
right = right.obj

# manually perform merge, up to getting pieces need for indexing
merger = _MergeOperation(left, right, left_on = left_on, right_on = right_on)
_, l_indx, _ = merger._get_join_info()
Expand Down Expand Up @@ -1681,7 +1783,7 @@ def top_n(__data, n, wt = None):

# Gather ======================================================================

@singledispatch2(pd.DataFrame)
@singledispatch2((pd.DataFrame, DataFrameGroupBy))
def gather(__data, key = "key", value = "value", *args, drop_na = False, convert = False):
"""Reshape table by gathering it in to long format.
Expand Down Expand Up @@ -1730,6 +1832,9 @@ def gather(__data, key = "key", value = "value", *args, drop_na = False, convert
if convert:
raise NotImplementedError("convert not yet implemented")

if isinstance(__data, DataFrameGroupBy):
__data = __data.obj

# TODO: copied from nest and select
var_list = var_create(*args)
od = var_select(__data.columns, *var_list)
Expand Down Expand Up @@ -1953,8 +2058,6 @@ def complete(__data, *args, fill = None):

# Separate/Unit/Extract ============================================================

import warnings

@singledispatch2(pd.DataFrame)
def separate(__data, col, into, sep = r"[^a-zA-Z0-9]",
remove = True, convert = False,
Expand Down
Loading

0 comments on commit 7d64101

Please sign in to comment.