Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add maybe_apply function #223

Merged
merged 8 commits into from
Mar 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ dev = [
"furo",
"invoke",
"mypy",
"pytest",
# pytest 8+ is not supported by pytest-mypy-testing
"pytest <8",
"pytest-mypy-testing",
"pytest-cov",
"sphinx",
Expand Down
11 changes: 9 additions & 2 deletions src/pydash/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@
zip_object_deep,
zip_with,
)
from .chaining import _Dash, chain, tap, thru
from .chaining import _Dash, chain, tap
from .collections import (
at,
count_by,
Expand Down Expand Up @@ -168,6 +168,10 @@
zscore,
)
from .objects import (
apply,
apply_catch,
apply_if,
apply_if_not_none,
assign,
assign_with,
callables,
Expand Down Expand Up @@ -462,7 +466,6 @@
"_Dash",
"chain",
"tap",
"thru",
"at",
"count_by",
"every",
Expand Down Expand Up @@ -544,6 +547,10 @@
"transpose",
"variance",
"zscore",
"apply",
"apply_catch",
"apply_if",
"apply_if_not_none",
"assign",
"assign_with",
"callables",
Expand Down
3 changes: 1 addition & 2 deletions src/pydash/chaining/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
from .chaining import _Dash, chain, tap, thru
from .chaining import _Dash, chain, tap


__all__ = (
"_Dash",
"chain",
"tap",
"thru",
)
33 changes: 30 additions & 3 deletions src/pydash/chaining/all_funcs.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -700,9 +700,6 @@ class AllFuncs:
def tap(self: "Chain[T]", interceptor: t.Callable[[T], t.Any]) -> "Chain[T]":
return self._wrap(pyd.tap)(interceptor)

def thru(self: "Chain[T]", interceptor: t.Callable[[T], T2]) -> "Chain[T2]":
return self._wrap(pyd.thru)(interceptor)

@t.overload
def at(self: "Chain[t.Mapping[T, T2]]", *paths: T) -> "Chain[t.List[t.Union[T2, None]]]": ...
@t.overload
Expand Down Expand Up @@ -2578,6 +2575,36 @@ class AllFuncs:
) -> "Chain[t.Any]":
return self._wrap(pyd.map_values_deep)(iteratee, property_path)

def apply(self: "Chain[T]", func: t.Callable[[T], T2]) -> "Chain[T2]":
return self._wrap(pyd.apply)(func)

def apply_if(
self: "Chain[T]", func: t.Callable[[T], T2], predicate: t.Callable[[T], bool]
) -> "Chain[t.Union[T, T2]]":
return self._wrap(pyd.apply_if)(func, predicate)

def apply_if_not_none(
self: "Chain[t.Optional[T]]", func: t.Callable[[T], T2]
) -> "Chain[t.Optional[T2]]":
return self._wrap(pyd.apply_if_not_none)(func)

@t.overload
def apply_catch(
self: "Chain[T]",
func: t.Callable[[T], T2],
exceptions: t.Iterable[t.Type[Exception]],
default: T3,
) -> "Chain[t.Union[T2, T3]]": ...
@t.overload
def apply_catch(
self: "Chain[T]",
func: t.Callable[[T], T2],
exceptions: t.Iterable[t.Type[Exception]],
default: Unset = UNSET,
) -> "Chain[t.Union[T, T2]]": ...
def apply_catch(self, func, exceptions, default=UNSET):
return self._wrap(pyd.apply_catch)(func, exceptions, default)

@t.overload
def merge(
self: "Chain[t.Mapping[T, T2]]", *sources: t.Mapping[T3, T4]
Expand Down
23 changes: 0 additions & 23 deletions src/pydash/chaining/chaining.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
__all__ = (
"chain",
"tap",
"thru",
)

ValueT_co = t.TypeVar("ValueT_co", covariant=True)
Expand Down Expand Up @@ -263,25 +262,3 @@ def tap(value: T, interceptor: t.Callable[[T], t.Any]) -> T:
"""
interceptor(value)
return value


def thru(value: T, interceptor: t.Callable[[T], T2]) -> T2:
"""
Returns the result of calling `interceptor` on `value`. The purpose of this method is to pass
`value` through a function during a method chain.

Args:
value: Current value of chain operation.
interceptor: Function called with `value`.

Returns:
Results of ``interceptor(value)``.

Example:

>>> chain([1, 2, 3, 4]).thru(lambda x: x * 2).value()
[1, 2, 3, 4, 1, 2, 3, 4]

.. versionadded:: 2.0.0
"""
return interceptor(value)
121 changes: 120 additions & 1 deletion src/pydash/objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

import pydash as pyd

from .helpers import UNSET, base_get, base_set, callit, getargcount, iterator, iteriteratee
from .helpers import UNSET, Unset, base_get, base_set, callit, getargcount, iterator, iteriteratee
from .types import IterateeObjT, PathT
from .utilities import PathToken, to_path, to_path_tokens

Expand All @@ -21,6 +21,10 @@
from _typeshed import SupportsRichComparisonT # pragma: no cover

__all__ = (
"apply",
"apply_catch",
"apply_if",
"apply_if_not_none",
"assign",
"assign_with",
"callables",
Expand Down Expand Up @@ -1251,6 +1255,121 @@ def deep_iteratee(value, key):
return callit(iteratee, obj, properties)


def apply(obj: T, func: t.Callable[[T], T2]) -> T2:
"""
Returns the result of calling `func` on `obj`. Particularly useful to pass
`obj` through a function during a method chain.

Args:
obj: Object to apply function to
func: Function called with `obj`.

Returns:
Results of ``func(value)``.

Example:

>>> apply(5, lambda x: x * 2)
10

.. versionadded:: 8.0.0
"""
return func(obj)


def apply_if(obj: T, func: t.Callable[[T], T2], predicate: t.Callable[[T], bool]) -> t.Union[T, T2]:
"""
Apply `func` to `obj` if `predicate` returns `True`.

Args:
obj: Object to apply `func` to.
func: Function to apply to `obj`.
predicate: Predicate applied to `obj`.

Returns:
Result of applying `func` to `obj` or `obj`.

Example:

>>> apply_if(2, lambda x: x * 2, lambda x: x > 1)
4
>>> apply_if(2, lambda x: x * 2, lambda x: x < 1)
2

.. versionadded:: 8.0.0
"""
return func(obj) if predicate(obj) else obj


def apply_if_not_none(obj: t.Optional[T], func: t.Callable[[T], T2]) -> t.Optional[T2]:
"""
Apply `func` to `obj` if `obj` is not ``None``.

Args:
obj: Object to apply `func` to.
func: Function to apply to `obj`.

Returns:
Result of applying `func` to `obj` or ``None``.

Example:

>>> apply_if_not_none(2, lambda x: x * 2)
4
>>> apply_if_not_none(None, lambda x: x * 2) is None
True

.. versionadded:: 8.0.0
"""
return apply_if(obj, func, lambda x: x is not None) # type: ignore


@t.overload
def apply_catch(
obj: T, func: t.Callable[[T], T2], exceptions: t.Iterable[t.Type[Exception]], default: T3
) -> t.Union[T2, T3]: ...


@t.overload
def apply_catch(
obj: T,
func: t.Callable[[T], T2],
exceptions: t.Iterable[t.Type[Exception]],
default: Unset = UNSET,
) -> t.Union[T, T2]: ...


def apply_catch(obj, func, exceptions, default=UNSET):
"""
Tries to apply `func` to `obj` if any of the exceptions in `excs` are raised, return `default`
or `obj` if not set.

Args:
obj: Object to apply `func` to.
func: Function to apply to `obj`.
excs: Exceptions to catch.
default: Value to return if exception is raised.

Returns:
Result of applying `func` to `obj` or ``default``.

Example:

>>> apply_catch(2, lambda x: x * 2, [ValueError])
4
>>> apply_catch(2, lambda x: x / 0, [ZeroDivisionError], "error")
'error'
>>> apply_catch(2, lambda x: x / 0, [ZeroDivisionError])
2

.. versionadded:: 8.0.0
"""
try:
return func(obj)
except tuple(exceptions):
return obj if default is UNSET else default


@t.overload
def merge(
obj: t.Mapping[T, T2], *sources: t.Mapping[T3, T4]
Expand Down
5 changes: 0 additions & 5 deletions tests/pytest_mypy_testing/test_chaining.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,3 @@ def log(value):
data.append(value)

reveal_type(_.chain([1, 2, 3, 4]).map(lambda x: x * 2).tap(log).value()) # R: builtins.list[builtins.int]


@pytest.mark.mypy_testing
def test_mypy_thru() -> None:
reveal_type(_.chain([1, 2, 3, 4]).thru(lambda x: x * 2).value()) # R: builtins.list[builtins.int]
25 changes: 25 additions & 0 deletions tests/pytest_mypy_testing/test_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,3 +345,28 @@ def test_mypy_values() -> None:
reveal_type(_.values({'a': 1, 'b': 2, 'c': 3})) # R: builtins.list[builtins.int]
reveal_type(_.values([2, 4, 6, 8])) # R: builtins.list[builtins.int]
reveal_type(_.values(MyClass())) # R: builtins.list[Any]


@pytest.mark.mypy_testing
def test_mypy_apply() -> None:
reveal_type(_.apply("1", lambda x: int(x))) # R: builtins.int
reveal_type(_.apply(1, lambda x: x + 1)) # R: builtins.int
reveal_type(_.apply("hello", lambda x: x.upper())) # R: builtins.str


@pytest.mark.mypy_testing
def test_mypy_apply_if() -> None:
reveal_type(_.apply_if("5", lambda x: int(x), lambda x: x.isdecimal())) # R: Union[builtins.str, builtins.int]


@pytest.mark.mypy_testing
def test_mypy_apply_if_not_none() -> None:
reveal_type(_.apply_if_not_none(1, lambda x: x + 1)) # R: Union[builtins.int, None]
reveal_type(_.apply_if_not_none(None, lambda x: x + 1)) # R: Union[builtins.int, None]
reveal_type(_.apply_if_not_none("hello", lambda x: x.upper())) # R: Union[builtins.str, None]


@pytest.mark.mypy_testing
def test_mypy_apply_catch() -> None:
reveal_type(_.apply_catch(5, lambda x: x / 0, [ZeroDivisionError])) # R: Union[builtins.int, builtins.float]
reveal_type(_.apply_catch(5, lambda x: x / 0, [ZeroDivisionError], "error")) # R: Union[builtins.float, builtins.str]
5 changes: 0 additions & 5 deletions tests/test_chaining.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,8 +165,3 @@ def test_chaining_value_to_string(case, expected):
def test_tap(value, interceptor, expected):
actual = _.chain(value).initial().tap(interceptor).last().value()
assert actual == expected


@parametrize("value,func,expected", [([1, 2, 3, 4, 5], lambda value: [sum(value)], 10)])
def test_thru(value, func, expected):
assert _.chain(value).initial().thru(func).last().value()
26 changes: 26 additions & 0 deletions tests/test_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -944,3 +944,29 @@ def test_unset(obj, path, expected, new_obj):
@parametrize("case,expected", [({"a": 1, "b": 2, "c": 3}, [1, 2, 3]), ([1, 2, 3], [1, 2, 3])])
def test_values(case, expected):
assert set(_.values(case)) == set(expected)


@parametrize("case,expected", [((5, lambda x: x * 2), 10)])
def test_apply(case, expected):
assert _.apply(*case) == expected


@parametrize(
"case,expected",
[((5, lambda x: x * 2, lambda x: x == 5), 10), ((5, lambda x: x * 2, lambda x: x == 10), 5)],
)
def test_apply_if(case, expected):
assert _.apply_if(*case) == expected


@parametrize("case,expected", [((5, lambda x: x * 2), 10), ((None, lambda x: x * 2), None)])
def test_apply_if_not_none(case, expected):
assert _.apply_if_not_none(*case) == expected


@parametrize(
"case,expected",
[((5, lambda x: x * 2, [ValueError]), 10), ((5, lambda x: x / 0, [ZeroDivisionError]), 5)],
)
def test_apply_catch(case, expected):
assert _.apply_catch(*case) == expected