Skip to content

Commit

Permalink
Merge pull request #160 from machow/feat-custom-ops
Browse files Browse the repository at this point in the history
Support user defined functions with fast group by
  • Loading branch information
machow authored Nov 12, 2019
2 parents 7a4092e + e10a214 commit dba8be7
Show file tree
Hide file tree
Showing 9 changed files with 442 additions and 58 deletions.
137 changes: 106 additions & 31 deletions siuba/dply/vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,31 @@
import numpy as np
from functools import singledispatch
from siuba.siu import symbolic_dispatch
from pandas.core.groupby import SeriesGroupBy, GroupBy
from pandas.core.frame import NDFrame
from pandas import Series

from siuba.experimental.pd_groups.groupby import GroupByAgg, _regroup
from siuba.experimental.pd_groups.translate import method_agg_op

# Utils =======================================================================

def _expand_bool(x, f):
    """Apply ``f`` over an expanding window of ``x`` and cast the result to bool.

    ``f`` is handed each window as a raw array (``raw=True``), so reducers such
    as ``np.all`` or ``np.any`` can be passed directly.
    """
    expanding_window = x.expanding()
    reduced = expanding_window.apply(f, raw=True)
    return reduced.astype(bool)

@symbolic_dispatch
def alias_series_agg(name):
    """Build a decorator that registers a grouped aggregation on a dispatcher.

    The SeriesGroupBy implementation is created once, up front, by
    ``method_agg_op(name, is_property = False, accessor = False)``. The
    returned decorator registers it for SeriesGroupBy and hands the dispatcher
    back unchanged.
    """
    grouped_impl = method_agg_op(name, is_property=False, accessor=False)

    def register_grouped(dispatcher):
        # attach the grouped implementation, then return the same dispatcher
        dispatcher.register(SeriesGroupBy, grouped_impl)
        return dispatcher

    return register_grouped


# Single dispatch functions ===================================================

@symbolic_dispatch(cls = Series)
def cumall(x):
"""Return a same-length array. For each entry, indicates whether that entry and all previous are True-like.
Expand All @@ -22,7 +41,7 @@ def cumall(x):
return _expand_bool(x, np.all)


@symbolic_dispatch
@symbolic_dispatch(cls = Series)
def cumany(x):
"""Return a same-length array. For each entry, indicates whether that entry or any previous are True-like.
Expand All @@ -37,18 +56,29 @@ def cumany(x):
return _expand_bool(x, np.any)


@symbolic_dispatch
@symbolic_dispatch(cls = Series)
def cummean(x):
    """Compute the running (cumulative) mean of x; the result has the same length as x."""
    running = x.expanding()
    return running.mean()

@symbolic_dispatch

@cummean.register(SeriesGroupBy)
def _cummean_grouped(x):
    """Cumulative mean within each group of a SeriesGroupBy.

    Computed as the grouped cumulative sum divided by the grouped running
    count of non-missing entries. The result is grouped again by the same
    grouper before being returned.
    """
    by = x.grouper

    # running count of non-NA values per group is the denominator
    not_missing = x.obj.notna()
    running_count = not_missing.groupby(by).cumsum()

    running_mean = x.cumsum() / running_count

    return running_mean.groupby(by)


@symbolic_dispatch(cls = Series)
def desc(x):
    """Return the values of x sorted from largest to smallest, with the old index dropped."""
    descending = x.sort_values(ascending=False)
    return descending.reset_index(drop=True)


@symbolic_dispatch
@symbolic_dispatch(cls = Series)
def dense_rank(x):
"""Return the dense rank.
Expand All @@ -69,21 +99,21 @@ def dense_rank(x):
return x.rank(method = "dense")


@symbolic_dispatch
@symbolic_dispatch(cls = Series)
def percent_rank(x):
    """TODO: Not Implemented

    Raises:
        NotImplementedError: always, since no implementation exists yet.
    """
    # the exception has to be raised, not merely constructed; otherwise the
    # call silently returns None
    raise NotImplementedError("PRs welcome")


@symbolic_dispatch
@symbolic_dispatch(cls = Series)
def min_rank(x):
    """Rank the values of x, giving tied values the lowest rank among their ties.

    Thin wrapper around pd.Series.rank with method "min"; see its documentation for details.
    """
    rank_method = "min"
    return x.rank(method=rank_method)


@symbolic_dispatch
@symbolic_dispatch(cls = Series)
def cume_dist(x):
"""Return the cumulative distribution corresponding to each value in x.
Expand All @@ -93,7 +123,9 @@ def cume_dist(x):
return x.rank(method = "max") / x.count()


@symbolic_dispatch
# row_number ------------------------------------------------------------------

@symbolic_dispatch(cls = NDFrame)
def row_number(x):
"""Return the row number (position) for each value in x, beginning with 1.
Expand All @@ -116,16 +148,31 @@ def row_number(x):
if isinstance(x, pd.Series):
return x._constructor(arr, pd.RangeIndex(n), fastpath = True)

return arr
return pd.Series(arr)


@symbolic_dispatch
@row_number.register(GroupBy)
def _row_number_grouped(g: GroupBy) -> GroupBy:
    """Number the rows inside each group of g, starting from 1.

    The numbers are written into one array covering every row of ``g.obj``,
    which is then passed through ``_regroup`` together with g.
    """
    row_numbers = np.ones(len(g.obj), dtype=int)

    # grouper.indices maps each group key to the positions of its rows;
    # only the positions are needed here
    for positions in g.grouper.indices.values():
        row_numbers[positions] = np.arange(1, len(positions) + 1, dtype=int)

    return _regroup(row_numbers, g)


# ntile -----------------------------------------------------------------------

@symbolic_dispatch(cls = Series)
def ntile(x, n):
    """TODO: Not Implemented

    Raises:
        NotImplementedError: always, since no implementation exists yet.
    """
    # raise the error instead of building and discarding it, so callers do not
    # silently receive None
    raise NotImplementedError("ntile not implemented")


@symbolic_dispatch
# between ---------------------------------------------------------------------

@symbolic_dispatch(cls = Series)
def between(x, left, right):
"""Return whether a value is between left and right (including either side).
Expand All @@ -144,13 +191,17 @@ def between(x, left, right):
return x.between(left, right)


@symbolic_dispatch
# coalesce --------------------------------------------------------------------

@symbolic_dispatch(cls = Series)
def coalesce(*args):
    """TODO: Not Implemented

    Raises:
        NotImplementedError: always, since no implementation exists yet.
    """
    # raise the error instead of building and discarding it, so callers do not
    # silently receive None
    raise NotImplementedError("coalesce not implemented")


@symbolic_dispatch
# lead ------------------------------------------------------------------------

@symbolic_dispatch(cls = Series)
def lead(x, n = 1, default = None):
"""Return an array with each value replaced by the next (or further forward) value in the array.
Expand All @@ -167,21 +218,27 @@ def lead(x, n = 1, default = None):
dtype: float64
>>> lead(pd.Series([1,2,3]), n=1, default = 99)
0 2.0
1 3.0
2 99.0
dtype: float64
0 2
1 3
2 99
dtype: int64
"""
res = x.shift(-1*n)

if default is not None:
res.iloc[-n:] = default
res = x.shift(-1*n, fill_value = default)

return res


@symbolic_dispatch
@lead.register(SeriesGroupBy)
def _lead_grouped(x, n = 1, default = None):
    """Grouped version of lead: shift x by ``-n`` periods within its groups.

    Positions left without a value are filled with ``default``. The shifted
    result is passed through ``_regroup`` together with x.
    """
    periods = -1 * n
    shifted = x.shift(periods, fill_value=default)

    return _regroup(shifted, x)


# lag -------------------------------------------------------------------------

@symbolic_dispatch(cls = Series)
def lag(x, n = 1, default = None):
"""Return an array with each value replaced by the previous (or further backward) value in the array.
Expand Down Expand Up @@ -213,7 +270,15 @@ def lag(x, n = 1, default = None):
return res


@symbolic_dispatch
@lag.register(SeriesGroupBy)
def _lag_grouped(x, n = 1, default = None):
    """Grouped version of lag: shift x by ``n`` periods within its groups.

    Positions left without a value are filled with ``default``. The shifted
    result is passed through ``_regroup`` together with x.
    """
    shifted = x.shift(n, fill_value=default)

    return _regroup(shifted, x)

# n ---------------------------------------------------------------------------

@symbolic_dispatch(cls = NDFrame)
def n(x):
"""Return the total number of elements in the array (or rows in a DataFrame).
Expand All @@ -233,7 +298,15 @@ def n(x):
return len(x)


@symbolic_dispatch
@n.register(GroupBy)
def _n_grouped(x: GroupBy) -> GroupByAgg:
    """Grouped version of n: the size of each group, built via GroupByAgg.from_result."""
    group_sizes = x.size()
    return GroupByAgg.from_result(group_sizes, x)


# n_distinct ------------------------------------------------------------------

@alias_series_agg('nunique')
@symbolic_dispatch(cls = Series)
def n_distinct(x):
"""Return the total number of distinct (i.e. unique) elements in an array.
Expand All @@ -242,10 +315,12 @@ def n_distinct(x):
2
"""
return len(x.unique())
return x.nunique()


# na_if -----------------------------------------------------------------------

@symbolic_dispatch
@symbolic_dispatch(cls = Series)
def na_if(x, y):
"""Return a array like x, but with values in y replaced by NAs.
Expand All @@ -265,25 +340,25 @@ def na_if(x, y):
return tmp_x


@symbolic_dispatch
@symbolic_dispatch(cls = Series)
def near(x):
    """TODO: Not Implemented

    Raises:
        NotImplementedError: always, since no implementation exists yet.
    """
    # raise the error instead of building and discarding it, so callers do not
    # silently receive None
    raise NotImplementedError("near not implemented")


@symbolic_dispatch
@symbolic_dispatch(cls = Series)
def nth(x):
    """TODO: Not Implemented

    Raises:
        NotImplementedError: always, since no implementation exists yet.
    """
    # raise the error instead of building and discarding it, so callers do not
    # silently receive None
    raise NotImplementedError("nth not implemented")


@symbolic_dispatch
@symbolic_dispatch(cls = Series)
def first(x):
    """TODO: Not Implemented

    Raises:
        NotImplementedError: always, since no implementation exists yet.
    """
    # raise the error instead of building and discarding it, so callers do not
    # silently receive None
    raise NotImplementedError("first not implemented")


@symbolic_dispatch
@symbolic_dispatch(cls = Series)
def last(x):
    """TODO: Not Implemented

    Raises:
        NotImplementedError: always, since no implementation exists yet.
    """
    # raise the error instead of building and discarding it, so callers do not
    # silently receive None
    raise NotImplementedError("last not implemented")
4 changes: 3 additions & 1 deletion siuba/experimental/pd_groups/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@
# NOTE(review): dispatch_cls / result_cls presumably control how custom
# dispatch functions are looked up for grouped data — confirm against
# CallTreeLocal.
call_listener = CallTreeLocal(
    out,
    call_sub_attr = ('str', 'dt', 'cat'),
    chain_sub_attr = True,
    dispatch_cls = GroupByAgg,
    result_cls = SeriesGroupBy
    )


Expand Down
41 changes: 41 additions & 0 deletions siuba/experimental/pd_groups/test_pd_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,3 +90,44 @@ def test_agg_groupby_broadcasted_equal_to_transform(f_op, f_dst):
assert_series_equal(broadcasted, dst, check_names = False)


# Test user-defined functions =================================================

from .dialect import fast_mutate
from siuba.siu import symbolic_dispatch, _, FunctionLookupError
from typing import Any

def test_fast_grouped_custom_user_funcs():
    """A custom dispatch function with an explicit SeriesGroupBy registration works in fast_mutate."""
    @symbolic_dispatch
    def f(x):
        return x.mean()

    @f.register(SeriesGroupBy)
    def _f_grouped(x) -> GroupByAgg:
        return GroupByAgg.from_result(x.mean() + 10, x)

    grouped = data_default.groupby('g')
    mutated = fast_mutate(grouped, result1 = f(_.x), result2 = _.x.mean() + 10)

    # the custom function must agree with the equivalent built-in expression
    frame = mutated.obj
    assert (frame.result1 == frame.result2).all()


def test_fast_grouped_custom_user_func_default():
    """A default implementation annotated as returning Any is accepted by fast_mutate."""
    @symbolic_dispatch
    def f(x) -> Any:
        return GroupByAgg.from_result(x.mean() + 10, x)

    grouped = data_default.groupby('g')
    mutated = fast_mutate(grouped, result1 = f(_.x), result2 = _.x.mean() + 10)

    # the custom function must agree with the equivalent built-in expression
    frame = mutated.obj
    assert (frame.result1 == frame.result2).all()

def test_fast_grouped_custom_user_func_fail():
    """A default implementation without a return annotation makes fast_mutate raise FunctionLookupError."""
    @symbolic_dispatch
    def f(x):
        return GroupByAgg.from_result(x.mean(), x)

    grouped = data_default.groupby('g')
    with pytest.raises(FunctionLookupError):
        fast_mutate(grouped, result1 = f(_.x), result2 = _.x.mean() + 10)


22 changes: 20 additions & 2 deletions siuba/experimental/pd_groups/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,30 @@ def f(x, y):
f.__qualname__ = name
return f

def method_win_op(name, is_property, accessor):
    """Create a function that runs attribute ``name`` on a SeriesGroupBy and regroups the result.

    Args:
        name: attribute to look up; also used as the returned function's
            ``__name__`` and ``__qualname__``.
        is_property: when truthy, the attribute value itself is the result;
            otherwise the attribute is called with the supplied arguments.
        accessor: optional attribute of the grouped series on which ``name``
            is looked up; when falsy, ``name`` is looked up on the grouped
            series directly.

    Returns:
        A function ``f(__ser, *args, **kwargs)``. It raises TypeError unless
        ``__ser`` is a SeriesGroupBy, and otherwise passes the result through
        ``_regroup`` together with ``__ser``.
    """
    def f(__ser, *args, **kwargs):
        if not isinstance(__ser, SeriesGroupBy):
            raise TypeError("All methods must operate on a grouped Series objects")

        # resolve the attribute on the accessor when one is given
        target = getattr(__ser, accessor) if accessor else __ser
        attr = getattr(target, name)

        res = attr if is_property else attr(*args, **kwargs)
        return _regroup(res, __ser)

    f.__name__ = name
    f.__qualname__ = name
    return f


# Maps an (operation kind, number) pair to the factory used to build its
# grouped implementation.
# NOTE(review): the number looks like an argument count — confirm.
GROUP_METHODS = {
    ("Elwise", 1): method_el_op,
    ("Elwise", 2): method_el_op2,
    ("Agg", 1): method_agg_op,
    ("Window", 1): method_win_op,
    ("Window", 2): method_win_op,
    ("Singleton", 1): not_implemented
    }

Expand Down
Loading

0 comments on commit dba8be7

Please sign in to comment.