-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #128 from scipp/decorated-providers
Support decorated providers (again)
- Loading branch information
Showing
5 changed files
with
262 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,135 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"id": "f7e2f743-54a3-4031-bb19-f2bd96d05de6", | ||
"metadata": {}, | ||
"source": [ | ||
"# Applying decorators\n", | ||
"\n", | ||
"When using decorators with providers, care must be taken to allow Sciline to recognize the argument and return types.\n", | ||
"This is done easiest with [functools.wraps](https://docs.python.org/3/library/functools.html#functools.wraps).\n", | ||
"The following decorator can be safely applied to providers:" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "30ef6849-606e-4a32-aa75-e493bfdb0fcd", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import functools\n", | ||
"from typing import Any, Callable, TypeVar\n", | ||
"\n", | ||
"import sciline\n", | ||
"\n", | ||
"R = TypeVar('R')\n", | ||
"\n", | ||
"\n", | ||
"def deco(f: Callable[..., R]) -> Callable[..., R]:\n", | ||
" @functools.wraps(f)\n", | ||
" def impl(*args: Any, **kwargs: Any) -> R:\n", | ||
" return f(*args, **kwargs)\n", | ||
"\n", | ||
" return impl\n", | ||
"\n", | ||
"@deco\n", | ||
"def to_string(x: int) -> str:\n", | ||
" return str(x)\n", | ||
"\n", | ||
"\n", | ||
"pipeline = sciline.Pipeline([to_string], params={int: 3})\n", | ||
"pipeline.compute(str)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "760f13d6-05be-4df0-90e3-348baa2dee8c", | ||
"metadata": {}, | ||
"source": [ | ||
"Omitting `functools.wraps` results in an error when computing results:" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "ad2adc47-635a-4fa8-b804-cdde880c33ea", | ||
"metadata": { | ||
"editable": true, | ||
"slideshow": { | ||
"slide_type": "" | ||
}, | ||
"tags": [ | ||
"raises-exception" | ||
] | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"def bad_deco(f: Callable[..., R]) -> Callable[..., R]:\n", | ||
" def impl(*args: Any, **kwargs: Any) -> R:\n", | ||
" return f(*args, **kwargs)\n", | ||
"\n", | ||
" return impl\n", | ||
"\n", | ||
"@bad_deco\n", | ||
"def to_string(x: int) -> str:\n", | ||
" return str(x)\n", | ||
"\n", | ||
"pipeline = sciline.Pipeline([to_string], params={int: 3})\n", | ||
"pipeline.compute(str)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "9b93319a-e54b-478d-952e-d6d0fb35704e", | ||
"metadata": {}, | ||
"source": [ | ||
"<div class=\"alert alert-info\">\n", | ||
"\n", | ||
"**Hint**\n", | ||
"\n", | ||
"For Python 3.10+, the decorator itself can be type-annoted like this:\n", | ||
"\n", | ||
"```python\n", | ||
"from typing import ParamSpec, TypeVar\n", | ||
"\n", | ||
"P = ParamSpec('P')\n", | ||
"R = TypeVar('R')\n", | ||
"\n", | ||
"def deco(f: Callable[P R]) -> Callable[P, R]:\n", | ||
" @functools.wraps(f)\n", | ||
" def impl(*args: P.args, **kwargs: P.kwargs) -> R:\n", | ||
" return f(*args, **kwargs)\n", | ||
"\n", | ||
" return impl\n", | ||
"```\n", | ||
"\n", | ||
"This is good practice but not required by Sciline.\n", | ||
"\n", | ||
"</div>" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3 (ipykernel)", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.9.18" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
# SPDX-License-Identifier: BSD-3-Clause | ||
# Copyright (c) 2024 Scipp contributors (https://github.com/scipp) | ||
# Tests on the hidden provider module | ||
|
||
from typing import List, Tuple, TypeVar | ||
|
||
from sciline._provider import ArgSpec | ||
|
||
|
||
def test_arg_spec() -> None: | ||
arg_spec = ArgSpec(args={'a': int}, kwargs={'b': str}) | ||
assert list(arg_spec.args) == [int] | ||
assert dict(arg_spec.kwargs) == {'b': str} | ||
|
||
|
||
def combine_numbers(a: int, *, b: float) -> str: | ||
return f"{a} and {b}" | ||
|
||
|
||
T = TypeVar('T', int, float) | ||
|
||
|
||
def complicated_append(a: T, *, b: List[T]) -> Tuple[T, ...]: | ||
b.append(a) | ||
|
||
return tuple(b) | ||
|
||
|
||
def test_arg_spec_from_function_simple() -> None: | ||
arg_spec = ArgSpec.from_function(combine_numbers) | ||
assert list(arg_spec.args) == [int] | ||
assert dict(arg_spec.kwargs) == {'b': float} | ||
|
||
|
||
def test_arg_spec_from_function_typevar() -> None: | ||
arg_spec = ArgSpec.from_function(complicated_append) | ||
|
||
assert list(arg_spec.args) == [T] | ||
assert dict(arg_spec.kwargs) == {'b': List[T]} # type: ignore[valid-type] | ||
specific_arg_spec = arg_spec.bind_type_vars( | ||
bound={T: int} # type: ignore[dict-item] | ||
) | ||
assert list(specific_arg_spec.args) == [int] | ||
assert dict(specific_arg_spec.kwargs) == {'b': list[int]} | ||
|
||
|
||
def test_arg_spec_decorated_function_with_wraps() -> None: | ||
from typing import Callable, Union | ||
|
||
def decorator(func: Callable[..., str]) -> Callable[..., str]: | ||
from functools import wraps | ||
|
||
@wraps(func) | ||
def wrapper(*args: Union[int, float], **kwargs: Union[int, float]) -> str: | ||
return "Wrapped: " + func(*args, **kwargs) | ||
|
||
return wrapper | ||
|
||
@decorator | ||
def decorated(a: int, *, b: float) -> str: | ||
return f"{a} and {b}" | ||
|
||
arg_spec = ArgSpec.from_function(decorated) | ||
assert list(arg_spec.args) == [int] | ||
assert dict(arg_spec.kwargs) == {'b': float} | ||
|
||
|
||
def test_arg_spec_decorated_function_without_wraps() -> None: | ||
from typing import Callable, Union | ||
|
||
def decorator(func: Callable[..., str]) -> Callable[..., str]: | ||
def wrapper(*args: Union[int, float], **kwargs: Union[int, float]) -> str: | ||
return "Wrapped: " + func(*args, **kwargs) | ||
|
||
return wrapper | ||
|
||
@decorator | ||
def decorated(a: int, *, b: float) -> str: | ||
return f"{a} and {b}" | ||
|
||
arg_spec = ArgSpec.from_function(decorated) | ||
assert list(arg_spec.args) == [] | ||
assert dict(arg_spec.kwargs) == {} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters