Skip to content

Commit

Permalink
Merge pull request #128 from scipp/decorated-providers
Browse files Browse the repository at this point in the history
Support decorated providers (again)
  • Loading branch information
jl-wynen authored Feb 16, 2024
2 parents f75b25d + 278e4a1 commit b8c19db
Show file tree
Hide file tree
Showing 5 changed files with 262 additions and 2 deletions.
135 changes: 135 additions & 0 deletions docs/recipes/applying-decorators.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "f7e2f743-54a3-4031-bb19-f2bd96d05de6",
"metadata": {},
"source": [
"# Applying decorators\n",
"\n",
"When using decorators with providers, care must be taken to allow Sciline to recognize the argument and return types.\n",
"This is done easiest with [functools.wraps](https://docs.python.org/3/library/functools.html#functools.wraps).\n",
"The following decorator can be safely applied to providers:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "30ef6849-606e-4a32-aa75-e493bfdb0fcd",
"metadata": {},
"outputs": [],
"source": [
"import functools\n",
"from typing import Any, Callable, TypeVar\n",
"\n",
"import sciline\n",
"\n",
"R = TypeVar('R')\n",
"\n",
"\n",
"def deco(f: Callable[..., R]) -> Callable[..., R]:\n",
" @functools.wraps(f)\n",
" def impl(*args: Any, **kwargs: Any) -> R:\n",
" return f(*args, **kwargs)\n",
"\n",
" return impl\n",
"\n",
"@deco\n",
"def to_string(x: int) -> str:\n",
" return str(x)\n",
"\n",
"\n",
"pipeline = sciline.Pipeline([to_string], params={int: 3})\n",
"pipeline.compute(str)"
]
},
{
"cell_type": "markdown",
"id": "760f13d6-05be-4df0-90e3-348baa2dee8c",
"metadata": {},
"source": [
"Omitting `functools.wraps` results in an error when computing results:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ad2adc47-635a-4fa8-b804-cdde880c33ea",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": ""
},
"tags": [
"raises-exception"
]
},
"outputs": [],
"source": [
"def bad_deco(f: Callable[..., R]) -> Callable[..., R]:\n",
" def impl(*args: Any, **kwargs: Any) -> R:\n",
" return f(*args, **kwargs)\n",
"\n",
" return impl\n",
"\n",
"@bad_deco\n",
"def to_string(x: int) -> str:\n",
" return str(x)\n",
"\n",
"pipeline = sciline.Pipeline([to_string], params={int: 3})\n",
"pipeline.compute(str)"
]
},
{
"cell_type": "markdown",
"id": "9b93319a-e54b-478d-952e-d6d0fb35704e",
"metadata": {},
"source": [
"<div class=\"alert alert-info\">\n",
"\n",
"**Hint**\n",
"\n",
"For Python 3.10+, the decorator itself can be type-annoted like this:\n",
"\n",
"```python\n",
"from typing import ParamSpec, TypeVar\n",
"\n",
"P = ParamSpec('P')\n",
"R = TypeVar('R')\n",
"\n",
"def deco(f: Callable[P R]) -> Callable[P, R]:\n",
" @functools.wraps(f)\n",
" def impl(*args: P.args, **kwargs: P.kwargs) -> R:\n",
" return f(*args, **kwargs)\n",
"\n",
" return impl\n",
"```\n",
"\n",
"This is good practice but not required by Sciline.\n",
"\n",
"</div>"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.18"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
3 changes: 2 additions & 1 deletion docs/recipes/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
maxdepth: 2
---
side-effects-and-file-writing
applying-decorators
continue-from-intermediate-results
replacing-providers
side-effects-and-file-writing
```
8 changes: 7 additions & 1 deletion src/sciline/_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def __init__(self, *, args: dict[str, Key], kwargs: dict[str, Key]) -> None:
def from_function(cls, provider: ToProvider) -> ArgSpec:
"""Parse the argument spec of a provider."""
hints = get_type_hints(provider)
signature = inspect.getfullargspec(provider)
signature = _get_provider_args(provider)
try:
args = {name: hints[name] for name in signature.args}
kwargs = {name: hints[name] for name in signature.kwonlyargs}
Expand Down Expand Up @@ -257,3 +257,9 @@ def _bind_free_typevars(tp: Union[TypeVar, Key], bound: dict[TypeVar, Key]) -> K
def _module_name(x: Any) -> str:
# getmodule might return None
return getattr(inspect.getmodule(x), '__name__', '')


def _get_provider_args(func: ToProvider) -> inspect.FullArgSpec:
if (wrapped := getattr(func, '__wrapped__', None)) is not None:
return _get_provider_args(wrapped)
return inspect.getfullargspec(func)
83 changes: 83 additions & 0 deletions tests/_provider_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
# SPDX-License-Identifier: BSD-3-Clause
# Copyright (c) 2024 Scipp contributors (https://github.com/scipp)
# Tests on the hidden provider module

from typing import List, Tuple, TypeVar

from sciline._provider import ArgSpec


def test_arg_spec() -> None:
arg_spec = ArgSpec(args={'a': int}, kwargs={'b': str})
assert list(arg_spec.args) == [int]
assert dict(arg_spec.kwargs) == {'b': str}


def combine_numbers(a: int, *, b: float) -> str:
return f"{a} and {b}"


T = TypeVar('T', int, float)


def complicated_append(a: T, *, b: List[T]) -> Tuple[T, ...]:
b.append(a)

return tuple(b)


def test_arg_spec_from_function_simple() -> None:
arg_spec = ArgSpec.from_function(combine_numbers)
assert list(arg_spec.args) == [int]
assert dict(arg_spec.kwargs) == {'b': float}


def test_arg_spec_from_function_typevar() -> None:
arg_spec = ArgSpec.from_function(complicated_append)

assert list(arg_spec.args) == [T]
assert dict(arg_spec.kwargs) == {'b': List[T]} # type: ignore[valid-type]
specific_arg_spec = arg_spec.bind_type_vars(
bound={T: int} # type: ignore[dict-item]
)
assert list(specific_arg_spec.args) == [int]
assert dict(specific_arg_spec.kwargs) == {'b': list[int]}


def test_arg_spec_decorated_function_with_wraps() -> None:
from typing import Callable, Union

def decorator(func: Callable[..., str]) -> Callable[..., str]:
from functools import wraps

@wraps(func)
def wrapper(*args: Union[int, float], **kwargs: Union[int, float]) -> str:
return "Wrapped: " + func(*args, **kwargs)

return wrapper

@decorator
def decorated(a: int, *, b: float) -> str:
return f"{a} and {b}"

arg_spec = ArgSpec.from_function(decorated)
assert list(arg_spec.args) == [int]
assert dict(arg_spec.kwargs) == {'b': float}


def test_arg_spec_decorated_function_without_wraps() -> None:
from typing import Callable, Union

def decorator(func: Callable[..., str]) -> Callable[..., str]:
def wrapper(*args: Union[int, float], **kwargs: Union[int, float]) -> str:
return "Wrapped: " + func(*args, **kwargs)

return wrapper

@decorator
def decorated(a: int, *, b: float) -> str:
return f"{a} and {b}"

arg_spec = ArgSpec.from_function(decorated)
assert list(arg_spec.args) == []
assert dict(arg_spec.kwargs) == {}
35 changes: 35 additions & 0 deletions tests/pipeline_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# SPDX-License-Identifier: BSD-3-Clause
# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
import functools
from dataclasses import dataclass
from typing import Any, Callable, Generic, List, NewType, TypeVar

Expand Down Expand Up @@ -1404,3 +1405,37 @@ def p2(p: Preference) -> Person[Preference, Green]:
with pytest.raises(sl.AmbiguousProvider):
# provided by p1 and p2 with the same number of typevars
pipeline.get(Person[Likes[Green], Green])


def test_pipeline_with_decorated_provider() -> None:
R = TypeVar('R')

def deco(f: Callable[..., R]) -> Callable[..., R]:
@functools.wraps(f)
def impl(*args: Any, **kwargs: Any) -> R:
return f(*args, **kwargs)

return impl

provider = deco(int_to_float)
pipeline = sl.Pipeline([provider], params={int: 3})
assert pipeline.compute(float) == 1.5


def test_pipeline_with_decorated_provider_keyword_only_arg() -> None:
R = TypeVar('R')

def deco(f: Callable[..., R]) -> Callable[..., R]:
@functools.wraps(f)
def impl(*args: Any, **kwargs: Any) -> R:
return f(*args, **kwargs)

return impl

@deco
def foo(*, k: int) -> float:
return float(k)

provider = deco(foo)
pipeline = sl.Pipeline([provider], params={int: 3})
assert pipeline.compute(float) == 3.0

0 comments on commit b8c19db

Please sign in to comment.