From 26a4e07ef85b13535bc444505790c574328905a8 Mon Sep 17 00:00:00 2001 From: Michael Chow Date: Sun, 3 Nov 2019 23:21:50 -0500 Subject: [PATCH 1/7] working draft implementation of UDF translations --- siuba/dply/vector.py | 6 ++ siuba/experimental/pd_groups/dialect.py | 3 +- .../experimental/pd_groups/test_pd_groups.py | 20 +++++ siuba/siu.py | 64 ++++++++++++--- siuba/tests/test_siu.py | 79 ++++++++++++++++++- 5 files changed, 158 insertions(+), 14 deletions(-) diff --git a/siuba/dply/vector.py b/siuba/dply/vector.py index 830cc098..5139e687 100644 --- a/siuba/dply/vector.py +++ b/siuba/dply/vector.py @@ -2,6 +2,7 @@ import numpy as np from functools import singledispatch from siuba.siu import symbolic_dispatch +from siuba.experimental.pd_groups.groupby import DataFrameGroupBy, SeriesGroupBy, GroupByAgg def _expand_bool(x, f): @@ -232,6 +233,11 @@ def n(x): return len(x) +@n.register(SeriesGroupBy) +@n.register(DataFrameGroupBy) +def _n_grouped(x): + return GroupByAgg.from_result(x.size(), x) + @symbolic_dispatch def n_distinct(x): diff --git a/siuba/experimental/pd_groups/dialect.py b/siuba/experimental/pd_groups/dialect.py index 3cf478f6..594a49af 100644 --- a/siuba/experimental/pd_groups/dialect.py +++ b/siuba/experimental/pd_groups/dialect.py @@ -21,7 +21,8 @@ call_listener = CallTreeLocal( out, call_sub_attr = ('str', 'dt', 'cat'), - chain_sub_attr = True + chain_sub_attr = True, + dispatch_cls = SeriesGroupBy ) diff --git a/siuba/experimental/pd_groups/test_pd_groups.py b/siuba/experimental/pd_groups/test_pd_groups.py index 3104732b..9b78cb3a 100644 --- a/siuba/experimental/pd_groups/test_pd_groups.py +++ b/siuba/experimental/pd_groups/test_pd_groups.py @@ -90,3 +90,23 @@ def test_agg_groupby_broadcasted_equal_to_transform(f_op, f_dst): assert_series_equal(broadcasted, dst, check_names = False) +# Test user-defined functions ================================================= + +from .dialect import fast_mutate +from siuba.siu import symbolic_dispatch, _ + +def test_fast_grouped_custom_user_funcs(): + @symbolic_dispatch + def f(x): + return x.mean() + + @f.register(SeriesGroupBy) + def _f_grouped(x): + return GroupByAgg.from_result(x.mean() + 10, x) + + gdf = data_default.groupby('g') + g_out = fast_mutate(gdf, result1 = f(_.x), result2 = _.x.mean() + 10) + out = g_out.obj + assert (out.result1 == out.result2).all() + + diff --git a/siuba/siu.py b/siuba/siu.py index 05cea808..c3c97c02 100644 --- a/siuba/siu.py +++ b/siuba/siu.py @@ -161,8 +161,7 @@ def __repr__(self): op_repr = "." fmt = "({args[0]}{func}{args[1]})" else: - op_repr, rest = self.args[0], self.args[1:] - arg_str = map(repr, rest) + op_repr, *arg_str = map(repr, self.args) kwarg_str = (str(k) + " = " + repr(v) for k,v in self.kwargs.items()) combined_arg_str = ",".join(itertools.chain(arg_str, kwarg_str)) @@ -326,12 +325,21 @@ def map_subcalls(self, f): def __call__(self, x): return self.args[1] +# Special kinds of call arguments ---- +# These functions insure that when using siu expressions generated by _, +# that call.args[0] is always another call. This allows them to trivially +# respond to calling a siu expression, map_subcalls, etc.. +# +# In the future, could make a parent class for Call, with a restricted +# set of behavior similar to theirs. +# +# TODO: validate that call.args[0] is a Call in tree visitors? class MetaArg(Call): def __init__(self, func, *args, **kwargs): self.func = "_" - self.args = args - self.kwargs = kwargs + self.args = tuple() + self.kwargs = {} def __repr__(self): return self.func @@ -339,6 +347,22 @@ def __repr__(self): def __call__(self, x): return x +class FuncArg(Call): + def __init__(self, func, *args, **kwargs): + self.func = '__custom_func__' + + if func == '__custom_func__': + func = args[0] + + self.args = tuple([func]) + self.kwargs = {} + + def __repr__(self): + return repr(self.args[0]) + + def __call__(self, x): + return self.args[0] + # Trees and Visitors ========================================================== from .error import ShortException @@ -421,7 +445,7 @@ def get_attr_chain(node, max_n): return list(reversed(out)), crnt_node -from inspect import isclass +from inspect import isclass, isfunction class CallTreeLocal(CallListener): def __init__( @@ -429,7 +453,7 @@ def __init__( local, call_sub_attr = None, chain_sub_attr = False, - replace_calls = True + dispatch_cls = None ): """ Arguments: @@ -438,12 +462,13 @@ def __init__( methods. Eg. {'dt'} to signify in _.dt.year, year is a property call. chain_sub_attr: whether to included the attributes in the above argument, when looking up up a replacement for the property call. E.g. does local have a 'dt.year' entry. - replace_calls: whether all calls, including custom call objects should be replaced. + dispatch_cls: if custom calls are dispatchers, dispatch on this class. If none, use their name + to try and get their corresponding local function. """ self.local = local self.call_sub_attr = set(call_sub_attr or []) self.chain_sub_attr = chain_sub_attr - self.replace_calls = replace_calls + self.dispatch_cls = dispatch_cls def create_local_call(self, name, prev_obj, cls, func_args = None, func_kwargs = None): # need call attr name (arg[0].args[1]) @@ -469,6 +494,7 @@ def create_local_call(self, name, prev_obj, cls, func_args = None, func_kwargs = def enter(self, node): # if no enter metthod for operators, like __invert__, try to get from local + # TODO: want to only do this if func is the name of an infix op's method method = 'enter_' + node.func if not hasattr(self, method) and node.func in self.local: args, kwargs = node.map_subcalls(self.enter) @@ -490,6 +516,20 @@ def enter___getattr__(self, node): return self.generic_enter(node) + def enter___custom_func__(self, node): + func = node(None) + + # TODO: not robust at all, need class for singledispatch? unique attr flag? + if (hasattr(func, 'registry') + and hasattr(func, 'dispatch') + and self.dispatch_cls is not None + ): + # allow custom functions that dispatch on dispatch_cls + f_for_cls = func.registry[self.dispatch_cls] + return node.__class__(f_for_cls) + + return self.generic_enter(node) + def enter___call__(self, node): """ Overview: @@ -519,9 +559,9 @@ def enter___call__(self, node): call_name = attr_chain[-1] entered_target = self.enter_if_call(obj.args[0]) - elif node.obj_name() is not None and self.replace_calls: + elif isinstance(obj, FuncArg) and self.dispatch_cls is None: # want function(_.x) -> new_function(_.x), has form - call_name = node.obj_name() + call_name = obj.obj_name() # the first argument is basically "self" entered_target, *args = args else: @@ -632,13 +672,13 @@ def symbolic_dispatch(f): f = singledispatch(f) @f.register(Symbolic) def _dispatch_symbol(__data, *args, **kwargs): - return create_sym_call(f, __data.source, *args, **kwargs) + return create_sym_call(FuncArg(f), __data.source, *args, **kwargs) @f.register(Call) def _dispatch_call(__data, *args, **kwargs): # TODO: want to just create call, for now use hack of creating a symbolic # call and getting the source off of it... - return create_sym_call(f, __data, *args, **kwargs).source + return create_sym_call(FuncArg(f), __data, *args, **kwargs).source return f diff --git a/siuba/tests/test_siu.py b/siuba/tests/test_siu.py index 22d6d8a1..5161bac2 100644 --- a/siuba/tests/test_siu.py +++ b/siuba/tests/test_siu.py @@ -1,10 +1,52 @@ -from siuba.siu import _, strip_symbolic, CallTreeLocal, FunctionLookupError +from siuba.siu import _, strip_symbolic, FunctionLookupError, Symbolic, MetaArg import pytest def test_op_vars_slice(): assert strip_symbolic(_.a[_.b:_.c]).op_vars() == {'a', 'b', 'c'} +# Symbolic dispatch =========================================================== +from siuba.siu import symbolic_dispatch, Call, FuncArg + +def test_FuncArg(): + f = lambda x: 1 + expr = FuncArg(f) + + assert expr(None) is f + +def test_FuncArg_in_call(): + call = Call( + '__call__', + FuncArg(lambda x, y: x + y), + 1, y = 2 + ) + + assert call(None) == 3 + + + +def test_symbolic_dispatch(): + @symbolic_dispatch + def f(x, y = 2): + return x + y + + # w/ simple Call + call1 = f(strip_symbolic(_), 3) + assert isinstance(call1, Call) + assert call1(2) == 5 + + # w/ simple Symbol + sym2 = f(_, 3) + assert isinstance(sym2, Symbolic) + assert sym2(2) == 5 + + # w/ complex Call + sym3 = f(_['a'], 3) + assert sym3({'a': 2}) == 5 + + + # Call Tree Local ============================================================= +from siuba.siu import CallTreeLocal @pytest.fixture def ctl(): @@ -38,3 +80,38 @@ def test_call_tree_local_sub_attr_property_missing(ctl): with pytest.raises(FunctionLookupError): ctl.enter(strip_symbolic(_.str.f_b)) +class SomeClass: pass + +@pytest.fixture +def f_dispatch(): + @symbolic_dispatch + def f(x): + return 'default' + + @f.register(SomeClass) + def _f_some_class(x): + return 'some class' + + yield f + + +def test_call_tree_local_dispatch_cls_object(f_dispatch): + ctl = CallTreeLocal( + {'f_a': lambda self: self}, + dispatch_cls = object + ) + + call = Call("__call__", FuncArg(f_dispatch), MetaArg('_')) + new_call = ctl.enter(call) + assert new_call('na') == 'default' + + +def test_call_tree_local_dispatch_cls_subclass(f_dispatch): + ctl = CallTreeLocal( + {'f_a': lambda self: self}, + dispatch_cls = SomeClass + ) + + call = Call("__call__", FuncArg(f_dispatch), MetaArg('_')) + new_call = ctl.enter(call) + assert new_call('na') == 'some class' From f0a4556ed4fe49720f4b6312f3bb93b8183f316c Mon Sep 17 00:00:00 2001 From: Michael Chow Date: Sun, 10 Nov 2019 20:29:05 -0500 Subject: [PATCH 2/7] expand UDF translations to include a few for vector funcs --- siuba/dply/vector.py | 121 ++++++++++++++++++++++++++++++++++++------- siuba/siu.py | 37 +++++++++---- 2 files changed, 128 insertions(+), 30 deletions(-) diff --git a/siuba/dply/vector.py b/siuba/dply/vector.py index 5139e687..2cce5708 100644 --- a/siuba/dply/vector.py +++ b/siuba/dply/vector.py @@ -2,13 +2,43 @@ import numpy as np from functools import singledispatch from siuba.siu import symbolic_dispatch -from siuba.experimental.pd_groups.groupby import DataFrameGroupBy, SeriesGroupBy, GroupByAgg +from pandas.core.groupby import SeriesGroupBy, GroupBy +from pandas.core.frame import NDFrame +from pandas import Series +from siuba.experimental.pd_groups.groupby import GroupByAgg, _regroup +from siuba.experimental.pd_groups.translate import method_agg_op + +# Utils ======================================================================= def _expand_bool(x, f): return x.expanding().apply(f, raw = True).astype(bool) -@symbolic_dispatch +def group_value_splits(g, to_series = False): + indices = g.grouper.indices + for g_key, inds in indices.items(): + array = g.obj.values[inds] + if to_series: + indx = pd.RangeIndex._simple_new(range(len(array))) + yield pd.Series(array, index = indx, dtype = g.obj.dtype, fastpath = True) + + else: + yield array + + +def alias_series_agg(name): + method = method_agg_op(name, is_property = False, accessor = False) + + def decorator(dispatcher): + dispatcher.register(SeriesGroupBy, method) + return dispatcher + + return decorator + + +# Single dispatch functions =================================================== + +@symbolic_dispatch(cls = Series) def cumall(x): """Return a same-length array. For each entry, indicates whether that entry and all previous are True-like. @@ -23,7 +53,7 @@ def cumall(x): return _expand_bool(x, np.all) -@symbolic_dispatch +@symbolic_dispatch(cls = Series) def cumany(x): """Return a same-length array. For each entry, indicates whether that entry or any previous are True-like. @@ -38,18 +68,29 @@ def cumany(x): return _expand_bool(x, np.any) -@symbolic_dispatch +@symbolic_dispatch(cls = Series) def cummean(x): """Return a same-length array, containing the cumulative mean.""" return x.expanding().mean() -@symbolic_dispatch + +@cummean.register(SeriesGroupBy) +def _cummean_grouped(x): + grouper = x.grouper + n_entries = x.obj.notna().groupby(grouper).cumsum() + + res = x.cumsum() / n_entries + + return res.groupby(grouper) + + +@symbolic_dispatch(cls = Series) def desc(x): """Return array sorted in descending order.""" return x.sort_values(ascending = False).reset_index(drop = True) -@symbolic_dispatch +@symbolic_dispatch(cls = Series) def dense_rank(x): """Return the dense rank. @@ -94,6 +135,8 @@ def cume_dist(x): return x.rank(method = "max") / x.count() +# row_number ------------------------------------------------------------------ + @symbolic_dispatch def row_number(x): """Return the row number (position) for each value in x, beginning with 1. @@ -117,15 +160,30 @@ def row_number(x): if isinstance(x, pd.Series): return x._constructor(arr, pd.RangeIndex(n), fastpath = True) - return arr + return pd.Series(arr) +@row_number.register(GroupBy) +def _row_number_grouped(g: GroupBy) -> GroupBy: + out = np.ones(len(g.obj), dtype = int) + + indices = g.grouper.indices + for g_key, inds in indices.items(): + out[inds] = np.arange(1, len(inds) + 1, dtype = int) + + return _regroup(out, g) + + +# ntile ----------------------------------------------------------------------- + @symbolic_dispatch def ntile(x, n): """TODO: Not Implemented""" NotImplementedError("ntile not implemented") +# between --------------------------------------------------------------------- + @symbolic_dispatch def between(x, left, right): """Return whether a value is between left and right (including either side). @@ -145,12 +203,16 @@ def between(x, left, right): return x.between(left, right) +# coalesce -------------------------------------------------------------------- + @symbolic_dispatch def coalesce(*args): """TODO: Not Implemented""" NotImplementedError("coalesce not implemented") +# lead ------------------------------------------------------------------------ + @symbolic_dispatch def lead(x, n = 1, default = None): """Return an array with each value replaced by the next (or further forward) value in the array. @@ -168,20 +230,26 @@ def lead(x, n = 1, default = None): dtype: float64 >>> lead(pd.Series([1,2,3]), n=1, default = 99) - 0 2.0 - 1 3.0 - 2 99.0 - dtype: float64 + 0 2 + 1 3 + 2 99 + dtype: int64 """ - res = x.shift(-1*n) - - if default is not None: - res.iloc[-n:] = default + res = x.shift(-1*n, fill_value = default) return res +@lead.register(SeriesGroupBy) +def _lead_grouped(x, n = 1, default = None): + res = x.shift(-1*n, fill_value = default) + + return _regroup(res, x) + + +# lag ------------------------------------------------------------------------- + @symbolic_dispatch def lag(x, n = 1, default = None): """Return an array with each value replaced by the previous (or further backward) value in the array. @@ -214,7 +282,15 @@ def lag(x, n = 1, default = None): return res -@symbolic_dispatch +@lag.register(SeriesGroupBy) +def _lag_grouped(x, n = 1, default = None): + res = x.shift(n, fill_value = default) + + return _regroup(res, x) + +# n --------------------------------------------------------------------------- + +@symbolic_dispatch(cls = NDFrame) def n(x): """Return the total number of elements in the array (or rows in a DataFrame). @@ -233,12 +309,15 @@ def n(x): return len(x) -@n.register(SeriesGroupBy) -@n.register(DataFrameGroupBy) -def _n_grouped(x): + +@n.register(GroupBy) +def _n_grouped(x: GroupBy) -> GroupByAgg: return GroupByAgg.from_result(x.size(), x) +# n_distinct ------------------------------------------------------------------ + +@alias_series_agg('nunique') @symbolic_dispatch def n_distinct(x): """Return the total number of distinct (i.e. unique) elements in an array. @@ -248,8 +327,10 @@ def n_distinct(x): 2 """ - return len(x.unique()) + return x.nunique() + +# na_if ----------------------------------------------------------------------- @symbolic_dispatch def na_if(x, y): diff --git a/siuba/siu.py b/siuba/siu.py index c3c97c02..4692c86b 100644 --- a/siuba/siu.py +++ b/siuba/siu.py @@ -665,25 +665,42 @@ def explain(symbol): # symbolic dispatch wrapper --------------------------------------------------- -from functools import singledispatch +from functools import singledispatch, update_wrapper +import inspect + +def _dispatch_not_impl(func_name): + def f(x, *args, **kwargs): + raise TypeError("singledispatch function {func_name} not implemented for type {type}" + .format(func_name = func_name, type = type(x)) + ) + + return f + +def symbolic_dispatch(f = None, cls = object): + if f is None: + return lambda f: symbolic_dispatch(f, cls) -def symbolic_dispatch(f): # TODO: don't use singledispatch if it has already been done - f = singledispatch(f) - @f.register(Symbolic) + dispatch_func = singledispatch(f) + + if cls is not object: + dispatch_func.register(cls, f) + dispatch_func.register(object, _dispatch_not_impl(dispatch_func.__name__)) + + + @dispatch_func.register(Symbolic) def _dispatch_symbol(__data, *args, **kwargs): - return create_sym_call(FuncArg(f), __data.source, *args, **kwargs) + return create_sym_call(FuncArg(dispatch_func), __data.source, *args, **kwargs) - @f.register(Call) + @dispatch_func.register(Call) def _dispatch_call(__data, *args, **kwargs): # TODO: want to just create call, for now use hack of creating a symbolic # call and getting the source off of it... - return create_sym_call(FuncArg(f), __data, *args, **kwargs).source - - return f - + return create_sym_call(FuncArg(dispatch_func), __data, *args, **kwargs).source + return dispatch_func + # Do some gnarly method setting ----------------------------------------------- def create_binary_op(op_name): From 841cee8de3c1c1cbf4801d59aed846a0323aea79 Mon Sep 17 00:00:00 2001 From: Michael Chow Date: Sun, 10 Nov 2019 20:43:24 -0500 Subject: [PATCH 3/7] enable fast window methods --- siuba/experimental/pd_groups/translate.py | 22 ++++++++++++++++++++-- siuba/tests/test_dply_series_methods.py | 4 ++-- 2 files changed, 22 insertions(+), 4 deletions(-) diff --git a/siuba/experimental/pd_groups/translate.py b/siuba/experimental/pd_groups/translate.py index 06319b2d..1b49353f 100644 --- a/siuba/experimental/pd_groups/translate.py +++ b/siuba/experimental/pd_groups/translate.py @@ -63,12 +63,30 @@ def f(x, y): f.__qualname__ = name return f +def method_win_op(name, is_property, accessor): + def f(__ser, *args, **kwargs): + if not isinstance(__ser, SeriesGroupBy): + raise TypeError("All methods must operate on a grouped Series objects") + + if accessor: + method = getattr(getattr(__ser, accessor), name) + else: + method = getattr(__ser, name) + + res = method(*args, **kwargs) if not is_property else method + return _regroup(res, __ser) + + f.__name__ = name + f.__qualname__ = name + return f + + GROUP_METHODS = { ("Elwise", 1): method_el_op, ("Elwise", 2): method_el_op2, ("Agg", 1): method_agg_op, - ("Window", 1): not_implemented, - ("Window", 2): not_implemented, + ("Window", 1): method_win_op, + ("Window", 2): method_win_op, ("Singleton", 1): not_implemented } diff --git a/siuba/tests/test_dply_series_methods.py b/siuba/tests/test_dply_series_methods.py index 4e5a2185..27702417 100644 --- a/siuba/tests/test_dply_series_methods.py +++ b/siuba/tests/test_dply_series_methods.py @@ -13,8 +13,8 @@ def filter_on_result(spec, types): return [k for k,v in spec.items() if v['result']['type'] in types] -SPEC_IMPLEMENTED = filter_on_result(spec, {"Agg", "Elwise"}) -SPEC_NOTIMPLEMENTED = filter_on_result(spec, {"Window", "Singleton"}) +SPEC_IMPLEMENTED = filter_on_result(spec, {"Agg", "Elwise", "Window"}) +SPEC_NOTIMPLEMENTED = filter_on_result(spec, {"Singleton"}) SPEC_AGG = filter_on_result(spec, {"Agg"}) _ = Symbolic() From 306338b30123bc77ecd1680951782c60417366c2 Mon Sep 17 00:00:00 2001 From: Michael Chow Date: Sun, 10 Nov 2019 20:31:05 -0500 Subject: [PATCH 4/7] remove method that's not on grouped df from spec --- siuba/spec/series.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/siuba/spec/series.py b/siuba/spec/series.py index 3640a2b0..7f15150a 100644 --- a/siuba/spec/series.py +++ b/siuba/spec/series.py @@ -178,7 +178,7 @@ class WontImplement(Result): pass 'abs': _.abs() >> Elwise(), 'all': _.all() >> Agg(), 'any': _.any() >> Agg(), - 'autocorr': _.autocorr() >> Window(), + # 'autocorr': _.autocorr() >> Window(), # TODO: doesn't exist on GDF 'between': _.between(2, 5) >> Elwise(), 'clip': _.clip(2, 5) >> Elwise(), # clip_lower # TODO: deprecated From b0dba7649eddcc5348e48e8bb1b60f4889ce0831 Mon Sep 17 00:00:00 2001 From: Michael Chow Date: Mon, 11 Nov 2019 14:05:10 -0500 Subject: [PATCH 5/7] remove unused code --- siuba/dply/vector.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/siuba/dply/vector.py b/siuba/dply/vector.py index 2cce5708..713833f5 100644 --- a/siuba/dply/vector.py +++ b/siuba/dply/vector.py @@ -14,18 +14,6 @@ def _expand_bool(x, f): return x.expanding().apply(f, raw = True).astype(bool) -def group_value_splits(g, to_series = False): - indices = g.grouper.indices - for g_key, inds in indices.items(): - array = g.obj.values[inds] - if to_series: - indx = pd.RangeIndex._simple_new(range(len(array))) - yield pd.Series(array, index = indx, dtype = g.obj.dtype, fastpath = True) - - else: - yield array - - def alias_series_agg(name): method = method_agg_op(name, is_property = False, accessor = False) From 05d40a781c8d258e44c15f6b4cd3a4946e961fd3 Mon Sep 17 00:00:00 2001 From: Michael Chow Date: Mon, 11 Nov 2019 23:09:40 -0500 Subject: [PATCH 6/7] CallTreeLocal can check return type annotations, used in fast grouped ops --- siuba/dply/vector.py | 30 +++++------ siuba/experimental/pd_groups/dialect.py | 3 +- .../experimental/pd_groups/test_pd_groups.py | 25 ++++++++- siuba/siu.py | 28 ++++++++-- siuba/tests/test_siu.py | 34 ++++++++++++ siuba/utils.py | 52 +++++++++++++++++++ 6 files changed, 151 insertions(+), 21 deletions(-) create mode 100644 siuba/utils.py diff --git a/siuba/dply/vector.py b/siuba/dply/vector.py index 713833f5..5b00e226 100644 --- a/siuba/dply/vector.py +++ b/siuba/dply/vector.py @@ -99,13 +99,13 @@ def dense_rank(x): return x.rank(method = "dense") -@symbolic_dispatch +@symbolic_dispatch(cls = Series) def percent_rank(x): """TODO: Not Implemented""" NotImplementedError("PRs welcome") -@symbolic_dispatch +@symbolic_dispatch(cls = Series) def min_rank(x): """Return the min rank. See pd.Series.rank for details. @@ -113,7 +113,7 @@ def min_rank(x): return x.rank(method = "min") -@symbolic_dispatch +@symbolic_dispatch(cls = Series) def cume_dist(x): """Return the cumulative distribution corresponding to each value in x. @@ -125,7 +125,7 @@ def cume_dist(x): # row_number ------------------------------------------------------------------ -@symbolic_dispatch +@symbolic_dispatch(cls = NDFrame) def row_number(x): """Return the row number (position) for each value in x, beginning with 1. @@ -164,7 +164,7 @@ def _row_number_grouped(g: GroupBy) -> GroupBy: # ntile ----------------------------------------------------------------------- -@symbolic_dispatch +@symbolic_dispatch(cls = Series) def ntile(x, n): """TODO: Not Implemented""" NotImplementedError("ntile not implemented") @@ -172,7 +172,7 @@ def ntile(x, n): # between --------------------------------------------------------------------- -@symbolic_dispatch +@symbolic_dispatch(cls = Series) def between(x, left, right): """Return whether a value is between left and right (including either side). @@ -193,7 +193,7 @@ def between(x, left, right): # coalesce -------------------------------------------------------------------- -@symbolic_dispatch +@symbolic_dispatch(cls = Series) def coalesce(*args): """TODO: Not Implemented""" NotImplementedError("coalesce not implemented") @@ -201,7 +201,7 @@ def coalesce(*args): # lead ------------------------------------------------------------------------ -@symbolic_dispatch +@symbolic_dispatch(cls = Series) def lead(x, n = 1, default = None): """Return an array with each value replaced by the next (or further forward) value in the array. @@ -238,7 +238,7 @@ def _lead_grouped(x, n = 1, default = None): # lag ------------------------------------------------------------------------- -@symbolic_dispatch +@symbolic_dispatch(cls = Series) def lag(x, n = 1, default = None): """Return an array with each value replaced by the previous (or further backward) value in the array. @@ -306,7 +306,7 @@ def _n_grouped(x: GroupBy) -> GroupByAgg: # n_distinct ------------------------------------------------------------------ @alias_series_agg('nunique') -@symbolic_dispatch +@symbolic_dispatch(cls = Series) def n_distinct(x): """Return the total number of distinct (i.e. unique) elements in an array. @@ -320,7 +320,7 @@ def n_distinct(x): # na_if ----------------------------------------------------------------------- -@symbolic_dispatch +@symbolic_dispatch(cls = Series) def na_if(x, y): """Return a array like x, but with values in y replaced by NAs. @@ -340,25 +340,25 @@ def na_if(x, y): return tmp_x -@symbolic_dispatch +@symbolic_dispatch(cls = Series) def near(x): """TODO: Not Implemented""" NotImplementedError("near not implemented") -@symbolic_dispatch +@symbolic_dispatch(cls = Series) def nth(x): """TODO: Not Implemented""" NotImplementedError("nth not implemented") -@symbolic_dispatch +@symbolic_dispatch(cls = Series) def first(x): """TODO: Not Implemented""" NotImplementedError("first not implemented") -@symbolic_dispatch +@symbolic_dispatch(cls = Series) def last(x): """TODO: Not Implemented""" NotImplementedError("last not implemented") diff --git a/siuba/experimental/pd_groups/dialect.py b/siuba/experimental/pd_groups/dialect.py index 594a49af..cdab18c4 100644 --- a/siuba/experimental/pd_groups/dialect.py +++ b/siuba/experimental/pd_groups/dialect.py @@ -22,7 +22,8 @@ out, call_sub_attr = ('str', 'dt', 'cat'), chain_sub_attr = True, - dispatch_cls = SeriesGroupBy + dispatch_cls = GroupByAgg, + result_cls = SeriesGroupBy ) diff --git a/siuba/experimental/pd_groups/test_pd_groups.py b/siuba/experimental/pd_groups/test_pd_groups.py index 9b78cb3a..102f8182 100644 --- a/siuba/experimental/pd_groups/test_pd_groups.py +++ b/siuba/experimental/pd_groups/test_pd_groups.py @@ -93,7 +93,8 @@ def test_agg_groupby_broadcasted_equal_to_transform(f_op, f_dst): # Test user-defined functions ================================================= from .dialect import fast_mutate -from siuba.siu import symbolic_dispatch, _ +from siuba.siu import symbolic_dispatch, _, FunctionLookupError +from typing import Any def test_fast_grouped_custom_user_funcs(): @symbolic_dispatch @@ -101,7 +102,7 @@ def f(x): return x.mean() @f.register(SeriesGroupBy) - def _f_grouped(x): + def _f_grouped(x) -> GroupByAgg: return GroupByAgg.from_result(x.mean() + 10, x) gdf = data_default.groupby('g') @@ -110,3 +111,23 @@ def _f_grouped(x): assert (out.result1 == out.result2).all() +def test_fast_grouped_custom_user_func_default(): + @symbolic_dispatch + def f(x) -> Any: + return GroupByAgg.from_result(x.mean() + 10, x) + + gdf = data_default.groupby('g') + g_out = fast_mutate(gdf, result1 = f(_.x), result2 = _.x.mean() + 10) + out = g_out.obj + assert (out.result1 == out.result2).all() + +def test_fast_grouped_custom_user_func_fail(): + @symbolic_dispatch + def f(x): + return GroupByAgg.from_result(x.mean(), x) + + gdf = data_default.groupby('g') + with pytest.raises(FunctionLookupError): + g_out = fast_mutate(gdf, result1 = f(_.x), result2 = _.x.mean() + 10) + + diff --git a/siuba/siu.py b/siuba/siu.py index 4692c86b..4e6362c5 100644 --- a/siuba/siu.py +++ b/siuba/siu.py @@ -446,6 +446,8 @@ def get_attr_chain(node, max_n): from inspect import isclass, isfunction +from typing import get_type_hints +from .utils import is_dispatch_func_subtype class CallTreeLocal(CallListener): def __init__( @@ -453,7 +455,8 @@ def __init__( local, call_sub_attr = None, chain_sub_attr = False, - dispatch_cls = None + dispatch_cls = None, + result_cls = None ): """ Arguments: @@ -464,11 +467,14 @@ def __init__( up a replacement for the property call. E.g. does local have a 'dt.year' entry. dispatch_cls: if custom calls are dispatchers, dispatch on this class. If none, use their name to try and get their corresponding local function. + result_cls: if custom calls are dispatchers, require their result annotation to be a subclass + of this class. """ self.local = local self.call_sub_attr = set(call_sub_attr or []) self.chain_sub_attr = chain_sub_attr self.dispatch_cls = dispatch_cls + self.result_cls = result_cls def create_local_call(self, name, prev_obj, cls, func_args = None, func_kwargs = None): # need call attr name (arg[0].args[1]) @@ -525,9 +531,25 @@ def enter___custom_func__(self, node): and self.dispatch_cls is not None ): # allow custom functions that dispatch on dispatch_cls - f_for_cls = func.registry[self.dispatch_cls] - return node.__class__(f_for_cls) + f_for_cls = func.dispatch(self.dispatch_cls) + if (self.result_cls is None + or is_dispatch_func_subtype(f_for_cls, self.dispatch_cls, self.result_cls) + ): + # matches return annotation type (or not required) + return node.__class__(f_for_cls) + + raise FunctionLookupError( + "External function {name} can dispatch on the class {dispatch_cls}, but " + "must also have result annotation of (sub)type {result_cls}" + .format( + name = func.__name__, + dispatch_cls = self.dispatch_cls, + result_cls = self.result_cls + ) + ) + # doesn't raise an error so we can look in locals for now + # TODO: remove behavior, once all SQL dispatch funcs moved from locals return self.generic_enter(node) def enter___call__(self, node): diff --git a/siuba/tests/test_siu.py b/siuba/tests/test_siu.py index 5161bac2..7f5b1726 100644 --- a/siuba/tests/test_siu.py +++ b/siuba/tests/test_siu.py @@ -80,6 +80,7 @@ def test_call_tree_local_sub_attr_property_missing(ctl): with pytest.raises(FunctionLookupError): ctl.enter(strip_symbolic(_.str.f_b)) +# symbolic dispatch and call tree local ---- class SomeClass: pass @pytest.fixture @@ -115,3 +116,36 @@ def test_call_tree_local_dispatch_cls_subclass(f_dispatch): call = Call("__call__", FuncArg(f_dispatch), MetaArg('_')) new_call = ctl.enter(call) assert new_call('na') == 'some class' + + +# strict symbolic dispatch and call tree local ---- + +@pytest.fixture +def f_dispatch_strict(): + @symbolic_dispatch(cls = SomeClass) + def f(x): + return 'some class' + + yield f + +def test_strict_dispatch_strict_default_fail(f_dispatch_strict): + class Other(object): pass + + obj = Other() + + with pytest.raises(TypeError): + f_dispatch_strict(obj) + +def test_call_tree_local_dispatch_fail(f_dispatch_strict): + ctl = CallTreeLocal( + {'f_a': lambda self: self}, + dispatch_cls = object + ) + + call = Call("__call__", FuncArg(f_dispatch_strict), MetaArg('_')) + + # should be the default failure dispatch for object + new_call = ctl.enter(call) + with pytest.raises(TypeError): + new_call('na') + diff --git a/siuba/utils.py b/siuba/utils.py new file mode 100644 index 00000000..28c25e85 --- /dev/null +++ b/siuba/utils.py @@ -0,0 +1,52 @@ +# TODO: move siu.py into its own folder, add this to it (w/ Call Trees) +from typing import _Any, _Union, TypeVar +import inspect + +def is_flex_subclass(x, cls): + if isinstance(x, _Any): + return True + + return issubclass(x, cls) + +def is_dispatch_func_subtype(f, input_cls, output_cls): + """Returns whether a singledispatch function is subtype of some input and result class. + + A function is a subtype if it is input contravariant, and result covariant. + + Rules for evaluating return types <= output_cls: + * Any always returns True + * Union[A, B] returns true if either A or B are covariant + * f(arg_name:TypeVar) -> TypeVar compares input_cls and output_cls + * Simple return types checked via issubclass + + Args: + input_cls - input class for first argument to function + output_cls - output class for function result + + """ + sig = inspect.signature(f) + # result annotation + res_type = sig.return_annotation + + # first parameter annotation + par0 = next(iter(sig.parameters.values())) + par_type0 = par0.annotation + + # Case 1: no annotation + if res_type is None: + return False + + # Case 2: fancy annotations: Union, generic TypeVar + if isinstance(res_type, _Union) and hasattr(res_type, '__args__'): + # passes if any unioned types are subclasses + sub_types = res_type.__args__ + return any(map(lambda x: is_flex_subclass(x, output_cls), sub_types)) + elif isinstance(res_type, TypeVar): + if res_type == par_type0: + # using a generic type variable as first arg and result + # return type must be covariant on input_cls + return issubclass(input_cls, output_cls) and res_type.__covariant__ + else: + raise TypeError("Generic type used as result, but not as first parameter") + + return is_flex_subclass(res_type, output_cls) From e10a2143914e260c50722df94c0955327f89e5fa Mon Sep 17 00:00:00 2001 From: Michael Chow Date: Mon, 11 Nov 2019 23:37:06 -0500 Subject: [PATCH 7/7] fix typing lib compatibilty with py3.7 --- siuba/utils.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/siuba/utils.py b/siuba/utils.py index 28c25e85..8ee6f4b5 100644 --- a/siuba/utils.py +++ b/siuba/utils.py @@ -1,9 +1,15 @@ # TODO: move siu.py into its own folder, add this to it (w/ Call Trees) -from typing import _Any, _Union, TypeVar +from typing import Any, Union, TypeVar import inspect +def is_union(x): + return getattr(x, '__origin__', None) is Union + +def get_union_args(x): + return getattr(x, '__args__', getattr(x, '__union_args__', None)) + def is_flex_subclass(x, cls): - if isinstance(x, _Any): + if x is Any: return True return issubclass(x, cls) @@ -37,9 +43,9 @@ def is_dispatch_func_subtype(f, input_cls, output_cls): return False # Case 2: fancy annotations: Union, generic TypeVar - if isinstance(res_type, _Union) and hasattr(res_type, '__args__'): + if is_union(res_type) and get_union_args(res_type): + sub_types = get_union_args(res_type) # passes if any unioned types are subclasses - sub_types = res_type.__args__ return any(map(lambda x: is_flex_subclass(x, output_cls), sub_types)) elif isinstance(res_type, TypeVar): if res_type == par_type0: