Skip to content

Commit

Permalink
Improve function composition
Browse files Browse the repository at this point in the history
Resolves   #744.
evhub committed May 20, 2023

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
1 parent 861fda4 commit b061817
Showing 7 changed files with 158 additions and 33 deletions.
2 changes: 2 additions & 0 deletions DOCS.md
Original file line number Diff line number Diff line change
@@ -726,6 +726,8 @@ The `..` operator has lower precedence than `::` but higher precedence than infi

All function composition operators also have in-place versions (e.g. `..=`).

Since all forms of function composition always call the first function in the composition (`f` in `f ..> g` and `g` in `f <.. g`) with exactly the arguments passed into the composition, all forms of function composition will preserve all metadata attached to the first function in the composition, including the function's [signature](https://docs.python.org/3/library/inspect.html#inspect.signature) and any of that function's attributes.

##### Example

**Coconut:**
23 changes: 17 additions & 6 deletions __coconut__/__init__.pyi
Original file line number Diff line number Diff line change
@@ -169,23 +169,29 @@ enumerate = enumerate

_coconut_py_str = py_str
_coconut_super = super
_coconut_enumerate = enumerate
_coconut_filter = filter
_coconut_range = range
_coconut_reversed = reversed
_coconut_zip = zip


zip_longest = _coconut.zip_longest
memoize = _lru_cache


reduce = _coconut.functools.reduce
takewhile = _coconut.itertools.takewhile
dropwhile = _coconut.itertools.dropwhile
tee = _coconut_tee = _coconut.itertools.tee
starmap = _coconut_starmap = _coconut.itertools.starmap
tee = _coconut.itertools.tee
starmap = _coconut.itertools.starmap
cartesian_product = _coconut.itertools.product
multiset = _coconut_multiset = _coconut.collections.Counter

multiset = _coconut.collections.Counter

_coconut_tee = tee
_coconut_starmap = starmap
_coconut_cartesian_product = cartesian_product
_coconut_multiset = multiset


parallel_map = concurrent_map = _coconut_map = map


@@ -200,6 +206,7 @@ def scan(
iterable: _t.Iterable[_U],
initial: _T = ...,
) -> _t.Iterable[_T]: ...
_coconut_scan = scan


class MatchError(Exception):
@@ -968,6 +975,7 @@ class cycle(_t.Iterable[_T]):
def __fmap__(self, func: _t.Callable[[_T], _U]) -> _t.Iterable[_U]: ...
def __copy__(self) -> cycle[_T]: ...
def __len__(self) -> int: ...
_coconut_cycle = cycle


class groupsof(_t.Generic[_T]):
@@ -981,6 +989,7 @@ class groupsof(_t.Generic[_T]):
def __copy__(self) -> groupsof[_T]: ...
def __len__(self) -> int: ...
def __fmap__(self, func: _t.Callable[[_t.Tuple[_T, ...]], _U]) -> _t.Iterable[_U]: ...
_coconut_groupsof = groupsof


class windowsof(_t.Generic[_T]):
@@ -996,6 +1005,7 @@ class windowsof(_t.Generic[_T]):
def __copy__(self) -> windowsof[_T]: ...
def __len__(self) -> int: ...
def __fmap__(self, func: _t.Callable[[_t.Tuple[_T, ...]], _U]) -> _t.Iterable[_U]: ...
_coconut_windowsof = windowsof


class flatten(_t.Iterable[_T]):
@@ -1228,6 +1238,7 @@ def lift(func: _t.Callable[[_T, _U], _W]) -> _coconut_lifted_2[_T, _U, _W]: ...
def lift(func: _t.Callable[[_T, _U, _V], _W]) -> _coconut_lifted_3[_T, _U, _V, _W]: ...
@_t.overload
def lift(func: _t.Callable[..., _W]) -> _t.Callable[..., _t.Callable[..., _W]]: ...
_coconut_lift = lift


def all_equal(iterable: _Iterable) -> bool: ...
144 changes: 118 additions & 26 deletions coconut/compiler/templates/header.py_template
Original file line number Diff line number Diff line change
@@ -184,7 +184,7 @@ def tee(iterable, n=2):
class _coconut_has_iter(_coconut_baseclass):
__slots__ = ("lock", "iter")
def __new__(cls, iterable):
self = _coconut.object.__new__(cls)
self = _coconut.super(_coconut_has_iter, cls).__new__(cls)
self.lock = _coconut.threading.Lock()
self.iter = iterable
return self
@@ -201,7 +201,7 @@ class reiterable(_coconut_has_iter):
def __new__(cls, iterable):
if _coconut.isinstance(iterable, _coconut.reiterables):
return iterable
return _coconut_has_iter.__new__(cls, iterable)
return _coconut.super({_coconut_}reiterable, cls).__new__(cls, iterable)
def get_new_iter(self):
"""Tee the underlying iterator."""
with self.lock:
@@ -331,21 +331,28 @@ def _coconut_iter_getitem(iterable, index):
return ()
iterable = _coconut.itertools.islice(iterable, 0, n)
return _coconut.tuple(iterable)[i::step]
class _coconut_base_compose(_coconut_baseclass):
__slots__ = ("func", "func_infos")
class _coconut_base_compose(_coconut_baseclass):{COMMENT.no_slots_to_allow_update_wrapper}{COMMENT.must_use_coconut_attrs_to_avoid_interacting_with_update_wrapper}
def __init__(self, func, *func_infos):
self.func = func
self.func_infos = []
try:
_coconut.functools.update_wrapper(self, func)
except _coconut.AttributeError:
pass
if _coconut.isinstance(func, _coconut_base_compose):
self._coconut_func = func._coconut_func
func_infos = func._coconut_func_infos + func_infos
else:
self._coconut_func = func
self._coconut_func_infos = []
for f, stars, none_aware in func_infos:
if _coconut.isinstance(f, _coconut_base_compose):
self.func_infos.append((f.func, stars, none_aware))
self.func_infos += f.func_infos
self._coconut_func_infos.append((f._coconut_func, stars, none_aware))
self._coconut_func_infos += f._coconut_func_infos
else:
self.func_infos.append((f, stars, none_aware))
self.func_infos = _coconut.tuple(self.func_infos)
self._coconut_func_infos.append((f, stars, none_aware))
self._coconut_func_infos = _coconut.tuple(self._coconut_func_infos)
def __call__(self, *args, **kwargs):
arg = self.func(*args, **kwargs)
for f, stars, none_aware in self.func_infos:
arg = self._coconut_func(*args, **kwargs)
for f, stars, none_aware in self._coconut_func_infos:
if none_aware and arg is None:
return arg
if stars == 0:
@@ -358,9 +365,9 @@ class _coconut_base_compose(_coconut_baseclass):
raise _coconut.RuntimeError("invalid internal stars value " + _coconut.repr(stars) + " in " + _coconut.repr(self) + " {report_this_text}")
return arg
def __repr__(self):
return _coconut.repr(self.func) + " " + " ".join(".." + "?"*none_aware + "*"*stars + "> " + _coconut.repr(f) for f, stars, none_aware in self.func_infos)
return _coconut.repr(self._coconut_func) + " " + " ".join(".." + "?"*none_aware + "*"*stars + "> " + _coconut.repr(f) for f, stars, none_aware in self._coconut_func_infos)
def __reduce__(self):
return (self.__class__, (self.func,) + self.func_infos)
return (self.__class__, (self._coconut_func,) + self._coconut_func_infos)
def __get__(self, obj, objtype=None):
if obj is None:
return self
@@ -501,7 +508,7 @@ class scan(_coconut_has_iter):
optionally starting from initial."""
__slots__ = ("func", "initial")
def __new__(cls, function, iterable, initial=_coconut_sentinel):
self = _coconut_has_iter.__new__(cls, iterable)
self = _coconut.super({_coconut_}scan, cls).__new__(cls, iterable)
self.func = function
self.initial = initial
return self
@@ -532,8 +539,7 @@ class reversed(_coconut_has_iter):
if _coconut.isinstance(iterable, _coconut.range):
return iterable[::-1]
if _coconut.getattr(iterable, "__reversed__", None) is None or _coconut.isinstance(iterable, (_coconut.list, _coconut.tuple)):
self = _coconut_has_iter.__new__(cls, iterable)
return self
return _coconut.super({_coconut_}reversed, cls).__new__(cls, iterable)
return _coconut.reversed(iterable)
def __repr__(self):
return "reversed(%s)" % (_coconut.repr(self.iter),)
@@ -574,7 +580,7 @@ class flatten(_coconut_has_iter):{COMMENT.cant_implement_len_else_list_calls_bec
raise _coconut.ValueError("flatten: levels cannot be negative")
if levels == 0:
return iterable
self = _coconut_has_iter.__new__(cls, iterable)
self = _coconut.super({_coconut_}flatten, cls).__new__(cls, iterable)
self.levels = levels
self._made_reit = False
return self
@@ -673,7 +679,7 @@ Additionally supports Cartesian products of numpy arrays."""
for i, a in _coconut.enumerate(numpy.ix_(*iterables)):
arr[..., i] = a
return arr.reshape(-1, _coconut.len(iterables))
self = _coconut.object.__new__(cls)
self = _coconut.super({_coconut_}cartesian_product, cls).__new__(cls)
self.iters = iterables
self.repeat = repeat
return self
@@ -775,7 +781,7 @@ class _coconut_base_parallel_concurrent_map(map):
def get_pool_stack(cls):
return cls.threadlocal_ns.__dict__.setdefault("pool_stack", [None])
def __new__(cls, function, *iterables, **kwargs):
self = {_coconut_}map.__new__(cls, function, *iterables)
self = _coconut.super(_coconut_base_parallel_concurrent_map, cls).__new__(cls, function, *iterables)
self.result = None
self.chunksize = kwargs.pop("chunksize", 1)
self.strict = kwargs.pop("strict", False)
@@ -870,7 +876,7 @@ class zip_longest(zip):
__slots__ = ("fillvalue",)
__doc__ = getattr(_coconut.zip_longest, "__doc__", "Version of zip that fills in missing values with fillvalue.")
def __new__(cls, *iterables, **kwargs):
self = {_coconut_}zip.__new__(cls, *iterables, strict=False)
self = _coconut.super({_coconut_}zip_longest, cls).__new__(cls, *iterables, strict=False)
self.fillvalue = kwargs.pop("fillvalue", None)
if kwargs:
raise _coconut.TypeError(cls.__name__ + "() got unexpected keyword arguments " + _coconut.repr(kwargs))
@@ -1081,7 +1087,7 @@ class cycle(_coconut_has_iter):
before stopping."""
__slots__ = ("times",)
def __new__(cls, iterable, times=None):
self = _coconut_has_iter.__new__(cls, iterable)
self = _coconut.super({_coconut_}cycle, cls).__new__(cls, iterable)
if times is None:
self.times = None
else:
@@ -1136,7 +1142,7 @@ class windowsof(_coconut_has_iter):
If that is not the desired behavior, fillvalue can be passed and will be used in place of missing values."""
__slots__ = ("size", "fillvalue", "step")
def __new__(cls, size, iterable, fillvalue=_coconut_sentinel, step=1):
self = _coconut_has_iter.__new__(cls, iterable)
self = _coconut.super({_coconut_}windowsof, cls).__new__(cls, iterable)
self.size = _coconut.operator.index(size)
if self.size < 1:
raise _coconut.ValueError("windowsof: size must be >= 1; not %r" % (self.size,))
@@ -1178,7 +1184,7 @@ class groupsof(_coconut_has_iter):
"""
__slots__ = ("group_size", "fillvalue")
def __new__(cls, n, iterable, fillvalue=_coconut_sentinel):
self = _coconut_has_iter.__new__(cls, iterable)
self = _coconut.super({_coconut_}groupsof, cls).__new__(cls, iterable)
self.group_size = _coconut.operator.index(n)
if self.group_size < 1:
raise _coconut.ValueError("group size must be >= 1; not %r" % (self.group_size,))
@@ -1755,7 +1761,7 @@ class lift(_coconut_baseclass):
"""
__slots__ = ("func",)
def __new__(cls, func, *func_args, **func_kwargs):
self = _coconut.object.__new__(cls)
self = _coconut.super({_coconut_}lift, cls).__new__(cls)
self.func = func
if func_args or func_kwargs:
self = self(*func_args, **func_kwargs)
@@ -1879,48 +1885,134 @@ def _coconut_call_or_coefficient(func, *args):
func = func * x{COMMENT.no_times_equals_to_avoid_modification}
return func
class _coconut_SupportsAdd(_coconut.typing.Protocol):
"""Coconut (+) Protocol. Equivalent to:

class SupportsAdd[T, U, V](Protocol):
def __add__(self: T, other: U) -> V:
raise NotImplementedError(...)
"""
def __add__(self, other):
raise NotImplementedError("Protocol methods cannot be called at runtime ((+) in a typing context is a Protocol)")
class _coconut_SupportsMinus(_coconut.typing.Protocol):
"""Coconut (-) Protocol. Equivalent to:

class SupportsMinus[T, U, V](Protocol):
def __sub__(self: T, other: U) -> V:
raise NotImplementedError
def __neg__(self: T) -> V:
raise NotImplementedError
"""
def __sub__(self, other):
raise NotImplementedError("Protocol methods cannot be called at runtime ((-) in a typing context is a Protocol)")
def __neg__(self):
raise NotImplementedError("Protocol methods cannot be called at runtime ((-) in a typing context is a Protocol)")
class _coconut_SupportsMul(_coconut.typing.Protocol):
"""Coconut (*) Protocol. Equivalent to:

class SupportsMul[T, U, V](Protocol):
def __mul__(self: T, other: U) -> V:
raise NotImplementedError(...)
"""
def __mul__(self, other):
raise NotImplementedError("Protocol methods cannot be called at runtime ((*) in a typing context is a Protocol)")
class _coconut_SupportsPow(_coconut.typing.Protocol):
"""Coconut (**) Protocol. Equivalent to:

class SupportsPow[T, U, V](Protocol):
def __pow__(self: T, other: U) -> V:
raise NotImplementedError(...)
"""
def __pow__(self, other):
raise NotImplementedError("Protocol methods cannot be called at runtime ((**) in a typing context is a Protocol)")
class _coconut_SupportsTruediv(_coconut.typing.Protocol):
"""Coconut (/) Protocol. Equivalent to:

class SupportsTruediv[T, U, V](Protocol):
def __truediv__(self: T, other: U) -> V:
raise NotImplementedError(...)
"""
def __truediv__(self, other):
raise NotImplementedError("Protocol methods cannot be called at runtime ((/) in a typing context is a Protocol)")
class _coconut_SupportsFloordiv(_coconut.typing.Protocol):
"""Coconut (//) Protocol. Equivalent to:

class SupportsFloordiv[T, U, V](Protocol):
def __floordiv__(self: T, other: U) -> V:
raise NotImplementedError(...)
"""
def __floordiv__(self, other):
raise NotImplementedError("Protocol methods cannot be called at runtime ((//) in a typing context is a Protocol)")
class _coconut_SupportsMod(_coconut.typing.Protocol):
"""Coconut (%) Protocol. Equivalent to:

class SupportsMod[T, U, V](Protocol):
def __mod__(self: T, other: U) -> V:
raise NotImplementedError(...)
"""
def __mod__(self, other):
raise NotImplementedError("Protocol methods cannot be called at runtime ((%) in a typing context is a Protocol)")
class _coconut_SupportsAnd(_coconut.typing.Protocol):
"""Coconut (&) Protocol. Equivalent to:

class SupportsAnd[T, U, V](Protocol):
def __and__(self: T, other: U) -> V:
raise NotImplementedError(...)
"""
def __and__(self, other):
raise NotImplementedError("Protocol methods cannot be called at runtime ((&) in a typing context is a Protocol)")
class _coconut_SupportsXor(_coconut.typing.Protocol):
"""Coconut (^) Protocol. Equivalent to:

class SupportsXor[T, U, V](Protocol):
def __xor__(self: T, other: U) -> V:
raise NotImplementedError(...)
"""
def __xor__(self, other):
raise NotImplementedError("Protocol methods cannot be called at runtime ((^) in a typing context is a Protocol)")
class _coconut_SupportsOr(_coconut.typing.Protocol):
"""Coconut (|) Protocol. Equivalent to:

class SupportsOr[T, U, V](Protocol):
def __or__(self: T, other: U) -> V:
raise NotImplementedError(...)
"""
def __or__(self, other):
raise NotImplementedError("Protocol methods cannot be called at runtime ((|) in a typing context is a Protocol)")
class _coconut_SupportsLshift(_coconut.typing.Protocol):
"""Coconut (<<) Protocol. Equivalent to:

class SupportsLshift[T, U, V](Protocol):
def __lshift__(self: T, other: U) -> V:
raise NotImplementedError(...)
"""
def __lshift__(self, other):
raise NotImplementedError("Protocol methods cannot be called at runtime ((<<) in a typing context is a Protocol)")
class _coconut_SupportsRshift(_coconut.typing.Protocol):
"""Coconut (>>) Protocol. Equivalent to:

class SupportsRshift[T, U, V](Protocol):
def __rshift__(self: T, other: U) -> V:
raise NotImplementedError(...)
"""
def __rshift__(self, other):
raise NotImplementedError("Protocol methods cannot be called at runtime ((>>) in a typing context is a Protocol)")
class _coconut_SupportsMatmul(_coconut.typing.Protocol):
"""Coconut (@) Protocol. Equivalent to:

class SupportsMatmul[T, U, V](Protocol):
def __matmul__(self: T, other: U) -> V:
raise NotImplementedError(...)
"""
def __matmul__(self, other):
raise NotImplementedError("Protocol methods cannot be called at runtime ((@) in a typing context is a Protocol)")
class _coconut_SupportsInv(_coconut.typing.Protocol):
"""Coconut (~) Protocol. Equivalent to:

class SupportsInv[T, V](Protocol):
def __invert__(self: T) -> V:
raise NotImplementedError(...)
"""
def __invert__(self):
raise NotImplementedError("Protocol methods cannot be called at runtime ((~) in a typing context is a Protocol)")
_coconut_self_match_types = {self_match_types}
_coconut_Expected, _coconut_MatchError, _coconut_count, _coconut_enumerate, _coconut_flatten, _coconut_filter, _coconut_ident, _coconut_map, _coconut_multiset, _coconut_range, _coconut_reiterable, _coconut_reversed, _coconut_starmap, _coconut_tee, _coconut_zip, TYPE_CHECKING, reduce, takewhile, dropwhile = Expected, MatchError, count, enumerate, flatten, filter, ident, map, multiset, range, reiterable, reversed, starmap, tee, zip, False, _coconut.functools.reduce, _coconut.itertools.takewhile, _coconut.itertools.dropwhile
_coconut_Expected, _coconut_MatchError, _coconut_cartesian_product, _coconut_count, _coconut_cycle, _coconut_enumerate, _coconut_flatten, _coconut_filter, _coconut_groupsof, _coconut_ident, _coconut_lift, _coconut_map, _coconut_multiset, _coconut_range, _coconut_reiterable, _coconut_reversed, _coconut_scan, _coconut_starmap, _coconut_tee, _coconut_windowsof, _coconut_zip, _coconut_zip_longest, TYPE_CHECKING, reduce, takewhile, dropwhile = Expected, MatchError, cartesian_product, count, cycle, enumerate, flatten, filter, groupsof, ident, lift, map, multiset, range, reiterable, reversed, scan, starmap, tee, windowsof, zip, zip_longest, False, _coconut.functools.reduce, _coconut.itertools.takewhile, _coconut.itertools.dropwhile{COMMENT.anything_added_here_should_be_copied_to_stub_file}
2 changes: 1 addition & 1 deletion coconut/root.py
Original file line number Diff line number Diff line change
@@ -26,7 +26,7 @@
VERSION = "3.0.0"
VERSION_NAME = None
# False for release, int >= 1 for develop
DEVELOP = 5
DEVELOP = 6
ALPHA = False # for pre releases rather than post releases

assert DEVELOP is False or DEVELOP >= 1, "DEVELOP must be False or an int >= 1"
3 changes: 3 additions & 0 deletions coconut/tests/src/cocotest/agnostic/main.coco
Original file line number Diff line number Diff line change
@@ -56,6 +56,7 @@ def run_main(outer_MatchError, test_easter_eggs=False) -> bool:
non_py26_test,
non_py32_test,
py3_spec_test,
py33_spec_test,
py36_spec_test,
py37_spec_test,
py38_spec_test,
@@ -66,6 +67,8 @@ def run_main(outer_MatchError, test_easter_eggs=False) -> bool:
assert non_py32_test() is True
if sys.version_info >= (3,):
assert py3_spec_test() is True
if sys.version_info >= (3, 3):
assert py33_spec_test() is True
if sys.version_info >= (3, 6):
assert py36_spec_test(tco=using_tco) is True
if sys.version_info >= (3, 7):
4 changes: 4 additions & 0 deletions coconut/tests/src/cocotest/agnostic/primary.coco
Original file line number Diff line number Diff line change
@@ -1590,4 +1590,8 @@ def primary_test() -> bool:
assert {"a":0, "b":1}$[0] == "a"
assert (|0, NotImplemented, 2|)$[1] is NotImplemented
assert m{1, 1, 2} |> fmap$(.+1) == m{2, 2, 3}
assert (+) ..> ((*) ..> (/)) == (+) ..> (*) ..> (/) == ((+) ..> (*)) ..> (/)
def f(x, y=1) = x, y # type: ignore
f.is_f = True # type: ignore
assert (f ..*> (+)).is_f # type: ignore
return True
13 changes: 13 additions & 0 deletions coconut/tests/src/cocotest/agnostic/specific.coco
Original file line number Diff line number Diff line change
@@ -44,6 +44,19 @@ def py3_spec_test() -> bool:
return True


def py33_spec_test() -> bool:
"""Tests for any py33+ version."""
from inspect import signature
def f(x, y=1) = x, y
def g(a, b=2) = a, b
assert signature(f ..*> g) == signature(f) == signature(f ..> g)
assert signature(f <*.. g) == signature(g) == signature(f <.. g)
assert signature(f$(0) ..> g) == signature(f$(0))
assert signature(f ..*> (+)) == signature(f)
assert signature((f ..*> g) ..*> g) == signature(f)
return True


def py36_spec_test(tco: bool) -> bool:
"""Tests for any py36+ version."""
from dataclasses import dataclass

0 comments on commit b061817

Please sign in to comment.