diff --git a/pytest.ini b/pytest.ini index df3eb518..2d89728d 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,2 +1,4 @@ [pytest] addopts = --doctest-modules +markers = + skip_backend diff --git a/siuba/experimental/pd_groups/dialect.py b/siuba/experimental/pd_groups/dialect.py index 5dd5558d..02e22229 100644 --- a/siuba/experimental/pd_groups/dialect.py +++ b/siuba/experimental/pd_groups/dialect.py @@ -1,5 +1,5 @@ from siuba.spec.series import spec -from siuba.siu import CallTreeLocal +from siuba.siu import CallTreeLocal, FunctionLookupError from siuba.experimental.pd_groups.translate import SeriesGroupBy, GroupByAgg, GROUP_METHODS @@ -12,18 +12,25 @@ #if entry['result']['type']: continue kind = entry['action'].get('kind') or entry['action'].get('status') key = (kind.title(), entry['action']['data_arity']) - meth = GROUP_METHODS[key] # add properties like df.dtype, so we know they are method calls if entry['is_property'] and not entry['accessor']: call_props.add(name) - out[name] = meth( + + meth = GROUP_METHODS[key]( name = name.split('.')[-1], is_property = entry['is_property'], accessor = entry['accessor'] ) + # TODO: returning this exception class from group methods is weird, but I + # think also used in tests + if meth is NotImplementedError: + continue + + out[name] = meth + call_listener = CallTreeLocal( out, call_sub_attr = ('str', 'dt', 'cat', 'sparse'), @@ -39,13 +46,27 @@ from siuba.siu import Call from siuba.dply.verbs import mutate, filter, summarize, singledispatch2, DataFrameGroupBy, _regroup from pandas.core.dtypes.inference import is_scalar +import warnings + +def fallback_warning(expr, reason): + warnings.warn( + "The expression below cannot be executed quickly. " + "Using the slow (but general) pandas apply method." + "\n\nExpression: {}\nReason: {}" + .format(expr, reason) + ) + def grouped_eval(__data, expr, require_agg = False): if is_scalar(expr): return expr if isinstance(expr, Call): - call = call_listener.enter(expr) + try: + call = call_listener.enter(expr) + except FunctionLookupError as e: + fallback_warning(expr, str(e)) + call = expr # grouped_res = call(__data) @@ -75,13 +96,40 @@ def grouped_eval(__data, expr, require_agg = False): # Fast mutate ---- +def _transform_args(args): + out = [] + for expr in args: + if is_scalar(expr): + out.append(expr) + elif isinstance(expr, Call): + try: + call = call_listener.enter(expr) + out.append(call) + except FunctionLookupError as e: + fallback_warning(expr, str(e)) + return None + elif callable(expr): + return None + + return out + @singledispatch2(DataFrameGroupBy) def fast_mutate(__data, **kwargs): """Warning: this function is experimental""" + + # transform call trees, potentially bail out to slow method -------- + new_vals = _transform_args(kwargs.values()) + + if new_vals is None: + return mutate(__data, **kwargs) + + + # perform fast method ---- out = __data.obj.copy() groupings = __data.grouper.groupings - for name, expr in kwargs.items(): + + for name, expr in zip(kwargs, new_vals): res = grouped_eval(__data, expr) out[name] = res @@ -102,7 +150,14 @@ def _fast_mutate_default(__data, **kwargs): @singledispatch2(DataFrameGroupBy) def fast_filter(__data, *args): """Warning: this function is experimental""" - import pandas as pd + + # transform call trees, potentially bail out to slow method -------- + new_vals = _transform_args(args) + + if new_vals is None: + return filter(__data, *args) + + # perform fast method ---- out = [] groupings = __data.grouper.groupings @@ -110,7 +165,7 @@ def fast_filter(__data, *args): res = grouped_eval(__data, expr) out.append(res) - filter_df = filter.registry[pd.DataFrame] + filter_df = filter.registry[__data.obj.__class__] df_result = filter_df(__data.obj, *out) @@ -133,6 +188,14 @@ def _fast_filter_default(__data, *args, **kwargs): @singledispatch2(DataFrameGroupBy) def fast_summarize(__data, **kwargs): """Warning: this function is experimental""" + + # transform call trees, potentially bail out to slow method -------- + new_vals = _transform_args(kwargs.values()) + + if new_vals is None: + return summarize(__data, **kwargs) + + # perform fast method ---- groupings = __data.grouper.groupings # TODO: better way of getting this frame? diff --git a/siuba/experimental/pd_groups/test_pd_groups.py b/siuba/experimental/pd_groups/test_pd_groups.py index b06af3db..f27bd1e3 100644 --- a/siuba/experimental/pd_groups/test_pd_groups.py +++ b/siuba/experimental/pd_groups/test_pd_groups.py @@ -92,10 +92,14 @@ def test_agg_groupby_broadcasted_equal_to_transform(f_op, f_dst): # Test user-defined functions ================================================= -from .dialect import fast_mutate, fast_summarize, fast_filter +from .dialect import fast_mutate, fast_summarize, fast_filter, _transform_args from siuba.siu import symbolic_dispatch, _, FunctionLookupError from typing import Any +def test_transform_args(): + pass + + def test_fast_grouped_custom_user_funcs(): @symbolic_dispatch def f(x): @@ -124,10 +128,17 @@ def f(x) -> Any: def test_fast_grouped_custom_user_func_fail(): @symbolic_dispatch def f(x): + return x.mean() + + @f.register(GroupByAgg) + def _f_gser(x): + # note, no return annotation, so translator will raise an error return GroupByAgg.from_result(x.mean(), x) + gdf = data_default.groupby('g') - with pytest.raises(FunctionLookupError): + + with pytest.warns(UserWarning): g_out = fast_mutate(gdf, result1 = f(_.x), result2 = _.x.mean() + 10) @@ -157,3 +168,30 @@ def test_fast_methods_constant(): ) +def test_fast_methods_lambda(): + # testing ways to do operations via slower apply route + + gdf = data_default.groupby('g') + + # mutate ---- + out = fast_mutate(gdf, y = lambda d: len(d['x'])) + assert_frame_equal( + gdf.obj.assign(y = gdf['x'].transform('size')), + out.obj + ) + + # summarize ---- + out = fast_summarize(gdf, y = lambda d: len(d['x'])) + + agg_frame = gdf.grouper.result_index.to_frame() + assert_frame_equal( + agg_frame.assign(y = gdf['x'].agg('size')).reset_index(drop = True), + out + ) + + # filter ---- + out = fast_filter(gdf, lambda d: d['x'] > d['x'].values.min()) + assert_frame_equal( + gdf.obj[gdf.obj['x'] > gdf['x'].transform('min')], + out.obj + ) diff --git a/siuba/spec/series.yml b/siuba/spec/series.yml index 03a780dd..cb4979b3 100644 --- a/siuba/spec/series.yml +++ b/siuba/spec/series.yml @@ -284,13 +284,13 @@ add_prefix: status: wontdo backends: {} category: reindexing - example: _.add_prefix() + example: _.add_prefix('pre_') add_suffix: action: status: wontdo backends: {} category: reindexing - example: _.add_suffix() + example: _.add_suffix('_suff') agg: action: status: wontdo @@ -330,7 +330,7 @@ append: status: todo backends: {} category: combining - example: _.append() + example: _.append(_) priority: 1 apply: action: diff --git a/siuba/tests/test_dply_series_methods.py b/siuba/tests/test_dply_series_methods.py index 4d4bdca0..7dc75382 100644 --- a/siuba/tests/test_dply_series_methods.py +++ b/siuba/tests/test_dply_series_methods.py @@ -192,8 +192,13 @@ def test_pandas_grouped_frame_fast_not_implemented(notimpl_entry): if notimpl_entry['action']['status'] in ["todo", "maydo", "wontdo"] and notimpl_entry["is_property"]: pytest.xfail() - with pytest.raises(NotImplementedError): - res = fast_mutate(gdf, result = call_expr) + with pytest.warns(UserWarning): + try: + # not implemented functions are punted to apply, and + # not guaranteed to work (e.g. many lengthen arrays, etc..) + res = fast_mutate(gdf, result = call_expr) + except: + pass