Skip to content

Commit

Permalink
Merge pull request #268 from machow/feat-pdgroups-fallback
Browse files Browse the repository at this point in the history
Feat pdgroups fallback
  • Loading branch information
machow authored Aug 19, 2020
2 parents 03978f3 + c8e3d9d commit aa359a3
Show file tree
Hide file tree
Showing 5 changed files with 122 additions and 14 deletions.
2 changes: 2 additions & 0 deletions pytest.ini
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
[pytest]
addopts = --doctest-modules
markers =
skip_backend
77 changes: 70 additions & 7 deletions siuba/experimental/pd_groups/dialect.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from siuba.spec.series import spec
from siuba.siu import CallTreeLocal
from siuba.siu import CallTreeLocal, FunctionLookupError

from siuba.experimental.pd_groups.translate import SeriesGroupBy, GroupByAgg, GROUP_METHODS

Expand All @@ -12,18 +12,25 @@
#if entry['result']['type']: continue
kind = entry['action'].get('kind') or entry['action'].get('status')
key = (kind.title(), entry['action']['data_arity'])
meth = GROUP_METHODS[key]

# add properties like df.dtype, so we know they are method calls
if entry['is_property'] and not entry['accessor']:
call_props.add(name)

out[name] = meth(

meth = GROUP_METHODS[key](
name = name.split('.')[-1],
is_property = entry['is_property'],
accessor = entry['accessor']
)

# TODO: it is odd for group methods to return the NotImplementedError class
# itself (rather than raise it), but tests may also rely on this behavior
if meth is NotImplementedError:
continue

out[name] = meth

call_listener = CallTreeLocal(
out,
call_sub_attr = ('str', 'dt', 'cat', 'sparse'),
Expand All @@ -39,13 +46,27 @@
from siuba.siu import Call
from siuba.dply.verbs import mutate, filter, summarize, singledispatch2, DataFrameGroupBy, _regroup
from pandas.core.dtypes.inference import is_scalar
import warnings

def fallback_warning(expr, reason):
    """Warn that an expression is being handed to the slow evaluation route.

    Parameters
    ----------
    expr
        The expression that could not be run quickly; it is interpolated
        into the warning message with ``str.format``.
    reason
        Explanation of why the fast route was not used (callers in this
        module pass the text of the lookup error).

    Notes
    -----
    ``warnings.warn`` is called without an explicit category, so the warning
    is issued as a ``UserWarning``.
    """
    template = (
        "The expression below cannot be executed quickly. "
        "Using the slow (but general) pandas apply method."
        "\n\nExpression: {}\nReason: {}"
    )
    warnings.warn(template.format(expr, reason))


def grouped_eval(__data, expr, require_agg = False):
if is_scalar(expr):
return expr

if isinstance(expr, Call):
call = call_listener.enter(expr)
try:
call = call_listener.enter(expr)
except FunctionLookupError as e:
fallback_warning(expr, str(e))
call = expr

#
grouped_res = call(__data)
Expand Down Expand Up @@ -75,13 +96,40 @@ def grouped_eval(__data, expr, require_agg = False):

# Fast mutate ----

def _transform_args(args):
out = []
for expr in args:
if is_scalar(expr):
out.append(expr)
elif isinstance(expr, Call):
try:
call = call_listener.enter(expr)
out.append(call)
except FunctionLookupError as e:
fallback_warning(expr, str(e))
return None
elif callable(expr):
return None

return out

@singledispatch2(DataFrameGroupBy)
def fast_mutate(__data, **kwargs):
"""Warning: this function is experimental"""

# transform call trees, potentially bail out to slow method --------
new_vals = _transform_args(kwargs.values())

if new_vals is None:
return mutate(__data, **kwargs)


# perform fast method ----
out = __data.obj.copy()
groupings = __data.grouper.groupings

for name, expr in kwargs.items():

for name, expr in zip(kwargs, new_vals):
res = grouped_eval(__data, expr)
out[name] = res

Expand All @@ -102,15 +150,22 @@ def _fast_mutate_default(__data, **kwargs):
@singledispatch2(DataFrameGroupBy)
def fast_filter(__data, *args):
"""Warning: this function is experimental"""
import pandas as pd

# transform call trees, potentially bail out to slow method --------
new_vals = _transform_args(args)

if new_vals is None:
return filter(__data, *args)

# perform fast method ----
out = []
groupings = __data.grouper.groupings

for expr in args:
res = grouped_eval(__data, expr)
out.append(res)

filter_df = filter.registry[pd.DataFrame]
filter_df = filter.registry[__data.obj.__class__]

df_result = filter_df(__data.obj, *out)

Expand All @@ -133,6 +188,14 @@ def _fast_filter_default(__data, *args, **kwargs):
@singledispatch2(DataFrameGroupBy)
def fast_summarize(__data, **kwargs):
"""Warning: this function is experimental"""

# transform call trees, potentially bail out to slow method --------
new_vals = _transform_args(kwargs.values())

if new_vals is None:
return summarize(__data, **kwargs)

# perform fast method ----
groupings = __data.grouper.groupings

# TODO: better way of getting this frame?
Expand Down
42 changes: 40 additions & 2 deletions siuba/experimental/pd_groups/test_pd_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,10 +92,14 @@ def test_agg_groupby_broadcasted_equal_to_transform(f_op, f_dst):

# Test user-defined functions =================================================

from .dialect import fast_mutate, fast_summarize, fast_filter
from .dialect import fast_mutate, fast_summarize, fast_filter, _transform_args
from siuba.siu import symbolic_dispatch, _, FunctionLookupError
from typing import Any

def test_transform_args():
    """Placeholder: no assertions for ``_transform_args`` are written yet."""


def test_fast_grouped_custom_user_funcs():
@symbolic_dispatch
def f(x):
Expand Down Expand Up @@ -124,10 +128,17 @@ def f(x) -> Any:
def test_fast_grouped_custom_user_func_fail():
@symbolic_dispatch
def f(x):
return x.mean()

@f.register(GroupByAgg)
def _f_gser(x):
# note, no return annotation, so translator will raise an error
return GroupByAgg.from_result(x.mean(), x)


gdf = data_default.groupby('g')
with pytest.raises(FunctionLookupError):

with pytest.warns(UserWarning):
g_out = fast_mutate(gdf, result1 = f(_.x), result2 = _.x.mean() + 10)


Expand Down Expand Up @@ -157,3 +168,30 @@ def test_fast_methods_constant():
)


def test_fast_methods_lambda():
    """Check the fast verbs when given plain lambdas.

    Exercises ways of doing operations via the slower apply route, comparing
    each result against the equivalent grouped pandas operation.
    """
    grouped = data_default.groupby('g')
    count_x = lambda d: len(d['x'])

    # mutate ----
    mutated = fast_mutate(grouped, y = count_x)
    expected_mutate = grouped.obj.assign(y = grouped['x'].transform('size'))
    assert_frame_equal(expected_mutate, mutated.obj)

    # summarize ----
    summarized = fast_summarize(grouped, y = count_x)

    group_keys = grouped.grouper.result_index.to_frame()
    expected_summary = (
        group_keys
        .assign(y = grouped['x'].agg('size'))
        .reset_index(drop = True)
    )
    assert_frame_equal(expected_summary, summarized)

    # filter ----
    filtered = fast_filter(grouped, lambda d: d['x'] > d['x'].values.min())
    above_group_min = grouped.obj['x'] > grouped['x'].transform('min')
    assert_frame_equal(grouped.obj[above_group_min], filtered.obj)
6 changes: 3 additions & 3 deletions siuba/spec/series.yml
Original file line number Diff line number Diff line change
Expand Up @@ -284,13 +284,13 @@ add_prefix:
status: wontdo
backends: {}
category: reindexing
example: _.add_prefix()
example: _.add_prefix('pre_')
add_suffix:
action:
status: wontdo
backends: {}
category: reindexing
example: _.add_suffix()
example: _.add_suffix('_suff')
agg:
action:
status: wontdo
Expand Down Expand Up @@ -330,7 +330,7 @@ append:
status: todo
backends: {}
category: combining
example: _.append()
example: _.append(_)
priority: 1
apply:
action:
Expand Down
9 changes: 7 additions & 2 deletions siuba/tests/test_dply_series_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,8 +192,13 @@ def test_pandas_grouped_frame_fast_not_implemented(notimpl_entry):
if notimpl_entry['action']['status'] in ["todo", "maydo", "wontdo"] and notimpl_entry["is_property"]:
pytest.xfail()

with pytest.raises(NotImplementedError):
res = fast_mutate(gdf, result = call_expr)
with pytest.warns(UserWarning):
try:
# functions that are not implemented are punted to apply, and are
# not guaranteed to work (e.g. many of them lengthen arrays, etc.)
res = fast_mutate(gdf, result = call_expr)
except:
pass



Expand Down

0 comments on commit aa359a3

Please sign in to comment.