Skip to content

Commit

Permalink
fix(pandas): ensure verbs accept DataFrameGroupBy
Browse files Browse the repository at this point in the history
  • Loading branch information
machow committed Sep 21, 2022
1 parent 73f9a2f commit c6ef607
Showing 1 changed file with 34 additions and 8 deletions.
42 changes: 34 additions & 8 deletions siuba/dply/verbs.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from functools import singledispatch
from functools import singledispatch, wraps
from pandas import DataFrame
import pandas as pd
import numpy as np
Expand Down Expand Up @@ -81,6 +81,21 @@ def _repr_grouped_df_console_(self):
return "(grouped data frame)\n" + repr(self.obj)


def _bounce_groupby(f):
@wraps(f)
def wrapper(__data: "pd.DataFrame | DataFrameGroupBy", *args, **kwargs):
if isinstance(__data, pd.DataFrame):
return res(__data, *args, **kwargs)

groupings = __data.grouper.groupings
group_cols = [ping.name for ping in groupings]

res = f(__data.obj, *args, **kwargs)

return res.groupby(group_cols)

return wrapper


def _regroup(df):
# try to regroup after an apply, when user kept index (e.g. group_keys = True)
Expand Down Expand Up @@ -1101,7 +1116,7 @@ def count(__data, *args, wt = None, sort = False, **kwargs):
return counts


@singledispatch2(pd.DataFrame)
@singledispatch2((pd.DataFrame, DataFrameGroupBy))
def add_count(__data, *args, wt = None, sort = False, **kwargs):
"""Add a column that is the number of observations for each grouping of data.
Expand Down Expand Up @@ -1152,8 +1167,8 @@ def add_count(__data, *args, wt = None, sort = False, **kwargs):
"""
counts = count(__data, *args, wt = wt, sort = sort, **kwargs)

on = list(counts.columns)[:-1]
return __data.merge(counts, on = on)
by = list(counts.columns)[:-1]
return inner_join(__data, counts, by = by)



Expand Down Expand Up @@ -1323,7 +1338,8 @@ def _convert_nested_entry(x):


# TODO: will need to use multiple dispatch
@singledispatch2(pd.DataFrame)
@singledispatch2((pd.DataFrame, DataFrameGroupBy))
@_bounce_groupby
def join(left, right, on = None, how = None, *args, by = None, **kwargs):
"""Join two tables together, by matching on specified columns.
Expand Down Expand Up @@ -1428,6 +1444,8 @@ def join(left, right, on = None, how = None, *args, by = None, **kwargs):
"""

if isinstance(right DataFrameGroupBy):
right = right.obj
if not isinstance(right, DataFrame):
raise Exception("right hand table must be a DataFrame")
if how is None:
Expand Down Expand Up @@ -1455,7 +1473,8 @@ def _join(left, right, on = None, how = None):
raise Exception("Unsupported type %s" %type(left))


@singledispatch2(pd.DataFrame)
@singledispatch2((pd.DataFrame, DataFrameGroupBy))
@_bounce_groupby
def semi_join(left, right = None, on = None):
"""Return the left table with every row that would be kept in an inner join.
Expand Down Expand Up @@ -1528,7 +1547,8 @@ def semi_join(left, right = None, on = None):
return left.loc[range_indx.isin(l_indx)]


@singledispatch2(pd.DataFrame)
@singledispatch2((pd.DataFrame, DataFrameGroupBy))
@_bounce_groupby
def anti_join(left, right = None, on = None):
"""Return the left table with every row that would *not* be kept in an inner join.
Expand Down Expand Up @@ -1571,6 +1591,9 @@ def anti_join(left, right = None, on = None):
else:
left_on = right_on = on

if isinstance(right, DataFrameGroupBy):
right = right.obj

# manually perform merge, up to getting pieces need for indexing
merger = _MergeOperation(left, right, left_on = left_on, right_on = right_on)
_, l_indx, _ = merger._get_join_info()
Expand Down Expand Up @@ -1681,7 +1704,7 @@ def top_n(__data, n, wt = None):

# Gather ======================================================================

@singledispatch2(pd.DataFrame)
@singledispatch2((pd.DataFrame, DataFrameGroupBy))
def gather(__data, key = "key", value = "value", *args, drop_na = False, convert = False):
"""Reshape table by gathering it in to long format.
Expand Down Expand Up @@ -1730,6 +1753,9 @@ def gather(__data, key = "key", value = "value", *args, drop_na = False, convert
if convert:
raise NotImplementedError("convert not yet implemented")

if isinstance(__data, DataFrameGroupBy):
__data = __data.obj

# TODO: copied from nest and select
var_list = var_create(*args)
od = var_select(__data.columns, *var_list)
Expand Down

0 comments on commit c6ef607

Please sign in to comment.