diff --git a/pyproject.toml b/pyproject.toml
index b92ea0b8..5487b956 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -28,7 +28,6 @@ classifiers = [
 requires-python = ">=3.8"
 dependencies = [
     "dask",
-    "injector",
 ]
 
 dynamic = ["version"]
diff --git a/requirements/base.in b/requirements/base.in
index 43292dbf..b2034ba3 100644
--- a/requirements/base.in
+++ b/requirements/base.in
@@ -1,2 +1 @@
 dask
-injector
diff --git a/requirements/base.txt b/requirements/base.txt
index 235c6275..7aeaacdd 100644
--- a/requirements/base.txt
+++ b/requirements/base.txt
@@ -1,22 +1,20 @@
-# SHA1:436d53b37d30cc5f2a62b6ed730bcf46f1762d7b
+# SHA1:b153fb0209b1ab0c52d1b98a28b498e4dc303498
 #
 # This file is autogenerated by pip-compile-multi
 # To update, run:
 #
 #    pip-compile-multi
 #
-click==8.1.3
+click==8.1.4
     # via dask
 cloudpickle==2.2.1
     # via dask
 dask==2023.5.0
     # via -r base.in
-fsspec==2023.5.0
+fsspec==2023.6.0
    # via dask
-importlib-metadata==6.6.0
+importlib-metadata==6.8.0
    # via dask
-injector==0.20.1
-    # via -r base.in
 locket==1.0.0
     # via partd
 packaging==23.1
@@ -29,7 +27,5 @@ toolz==0.12.0
     # via
     #   dask
     #   partd
-typing-extensions==4.6.3
-    # via injector
-zipp==3.15.0
+zipp==3.16.0
     # via importlib-metadata
diff --git a/requirements/ci.txt b/requirements/ci.txt
index f17a6bc6..beb13154 100644
--- a/requirements/ci.txt
+++ b/requirements/ci.txt
@@ -11,19 +11,19 @@ certifi==2023.5.7
     # via requests
 chardet==5.1.0
     # via tox
-charset-normalizer==3.1.0
+charset-normalizer==3.2.0
     # via requests
 colorama==0.4.6
     # via tox
 distlib==0.3.6
     # via virtualenv
-filelock==3.12.0
+filelock==3.12.2
     # via
     #   tox
     #   virtualenv
 gitdb==4.0.10
     # via gitpython
-gitpython==3.1.31
+gitpython==3.1.32
     # via -r ci.in
 idna==3.4
     # via requests
@@ -32,13 +32,13 @@ packaging==23.1
     # via
     #   -r ci.in
     #   pyproject-api
     #   tox
-platformdirs==3.5.1
+platformdirs==3.8.1
     # via
     #   tox
     #   virtualenv
-pluggy==1.0.0
+pluggy==1.2.0
     # via tox
-pyproject-api==1.5.1
+pyproject-api==1.5.3
     # via tox
 requests==2.31.0
     # via -r ci.in
@@ -48,9 +48,9 @@ tomli==2.0.1
     # via
     #   pyproject-api
     #   tox
-tox==4.6.0
+tox==4.6.4
     # via -r ci.in
 urllib3==2.0.3
     # via requests
-virtualenv==20.23.0
+virtualenv==20.23.1
     # via tox
diff --git a/requirements/docs.txt b/requirements/docs.txt
index 3b558460..0ad8c12d 100644
--- a/requirements/docs.txt
+++ b/requirements/docs.txt
@@ -1,4 +1,4 @@
-# SHA1:58bb49d5210e2885ee4c5c48167d217012b5b723
+# SHA1:c8d387b90560a2db42c26e33735a4f78ada867c3
 #
 # This file is autogenerated by pip-compile-multi
 # To update, run:
@@ -13,8 +13,10 @@ alabaster==0.7.13
 asttokens==2.2.1
     # via stack-data
 attrs==23.1.0
-    # via jsonschema
-autodoc-pydantic==1.8.0
+    # via
+    #   jsonschema
+    #   referencing
+autodoc-pydantic==1.9.0
     # via -r docs.in
 babel==2.12.1
     # via
@@ -30,7 +32,7 @@ bleach==6.0.0
     # via nbconvert
 certifi==2023.5.7
     # via requests
-charset-normalizer==3.1.0
+charset-normalizer==3.2.0
     # via requests
 comm==0.1.3
     # via ipykernel
@@ -54,9 +56,11 @@ idna==3.4
     # via requests
 imagesize==1.4.1
     # via sphinx
-importlib-resources==5.12.0
-    # via jsonschema
-ipykernel==6.23.1
+importlib-resources==6.0.0
+    # via
+    #   jsonschema
+    #   jsonschema-specifications
+ipykernel==6.24.0
     # via -r docs.in
 ipython==8.12.2
     # via
@@ -70,13 +74,15 @@ jinja2==3.1.2
     #   nbconvert
     #   nbsphinx
     #   sphinx
-jsonschema==4.17.3
+jsonschema==4.18.0
     # via nbformat
-jupyter-client==8.2.0
+jsonschema-specifications==2023.6.1
+    # via jsonschema
+jupyter-client==8.3.0
     # via
     #   ipykernel
     #   nbclient
-jupyter-core==5.3.0
+jupyter-core==5.3.1
     # via
     #   ipykernel
     #   jupyter-client
@@ -85,7 +91,7 @@ jupyter-core==5.3.0
    #
    #   nbformat
 jupyterlab-pygments==0.2.2
     # via nbconvert
-markdown-it-py==2.2.0
+markdown-it-py==3.0.0
     # via
     #   mdit-py-plugins
     #   myst-parser
@@ -97,19 +103,19 @@ matplotlib-inline==0.1.6
     # via
     #   ipykernel
     #   ipython
-mdit-py-plugins==0.3.5
+mdit-py-plugins==0.4.0
     # via myst-parser
 mdurl==0.1.2
     # via markdown-it-py
-mistune==2.0.5
+mistune==3.0.1
     # via nbconvert
-myst-parser==1.0.0
+myst-parser==2.0.0
     # via -r docs.in
 nbclient==0.8.0
     # via nbconvert
-nbconvert==7.4.0
+nbconvert==7.6.0
     # via nbsphinx
-nbformat==5.9.0
+nbformat==5.9.1
     # via
     #   nbclient
     #   nbconvert
@@ -128,9 +134,9 @@ pickleshare==0.7.5
     # via ipython
 pkgutil-resolve-name==1.3.10
     # via jsonschema
-platformdirs==3.5.1
+platformdirs==3.8.1
     # via jupyter-core
-prompt-toolkit==3.0.38
+prompt-toolkit==3.0.39
     # via ipython
 psutil==5.9.5
     # via ipykernel
@@ -138,7 +144,7 @@ ptyprocess==0.7.0
     # via pexpect
 pure-eval==0.2.2
     # via stack-data
-pydantic==1.10.9
+pydantic==1.10.11
     # via autodoc-pydantic
 pydata-sphinx-theme==0.13.3
     # via -r docs.in
@@ -149,8 +155,6 @@ pygments==2.15.1
     #   nbconvert
     #   pydata-sphinx-theme
     #   sphinx
-pyrsistent==0.19.3
-    # via jsonschema
 python-dateutil==2.8.2
     # via jupyter-client
 pytz==2023.3
@@ -159,8 +163,16 @@ pyzmq==25.1.0
     # via
     #   ipykernel
     #   jupyter-client
+referencing==0.29.1
+    # via
+    #   jsonschema
+    #   jsonschema-specifications
 requests==2.31.0
     # via sphinx
+rpds-py==0.8.10
+    # via
+    #   jsonschema
+    #   referencing
 six==1.16.0
     # via
     #   asttokens
@@ -218,6 +230,11 @@ traitlets==5.9.0
     #   nbconvert
     #   nbformat
     #   nbsphinx
+typing-extensions==4.7.1
+    # via
+    #   ipython
+    #   pydantic
+    #   pydata-sphinx-theme
 urllib3==2.0.3
     # via requests
 wcwidth==0.2.6
diff --git a/requirements/static.txt b/requirements/static.txt
index 1c04a6da..e587652e 100644
--- a/requirements/static.txt
+++ b/requirements/static.txt
@@ -9,19 +9,19 @@ cfgv==3.3.1
     # via pre-commit
 distlib==0.3.6
     # via virtualenv
-filelock==3.12.0
+filelock==3.12.2
     # via virtualenv
 identify==2.5.24
     # via pre-commit
 nodeenv==1.8.0
     # via pre-commit
-platformdirs==3.5.1
+platformdirs==3.8.1
     # via virtualenv
-pre-commit==3.3.2
+pre-commit==3.3.3
     # via -r static.in
 pyyaml==6.0
     # via pre-commit
-virtualenv==20.23.0
+virtualenv==20.23.1
     # via pre-commit
 
 # The following packages are considered to be unsafe in a requirements file:
diff --git a/requirements/test.in b/requirements/test.in
index 1cf404d7..9064cbf2 100644
--- a/requirements/test.in
+++ b/requirements/test.in
@@ -1,2 +1,3 @@
 -r base.in
+numpy
 pytest
diff --git a/requirements/test.txt b/requirements/test.txt
index 62c51aba..f589f618 100644
--- a/requirements/test.txt
+++ b/requirements/test.txt
@@ -1,4 +1,4 @@
-# SHA1:a035a60fcbac4cd7bf595dbd81ee7994505d4a95
+# SHA1:3720d9b18830e4fcedb827ec36f6808035c1ea2c
 #
 # This file is autogenerated by pip-compile-multi
 # To update, run:
@@ -6,13 +6,15 @@
 #    pip-compile-multi
 #
 -r base.txt
-exceptiongroup==1.1.1
+exceptiongroup==1.1.2
     # via pytest
 iniconfig==2.0.0
     # via pytest
-pluggy==1.0.0
+numpy==1.24.4
+    # via -r test.in
+pluggy==1.2.0
     # via pytest
-pytest==7.3.1
+pytest==7.4.0
     # via -r test.in
 tomli==2.0.1
     # via pytest
"AmbiguousProvider", + "Container", + "Scope", + "UnboundTypeVar", + "UnsatisfiedRequirement", +] diff --git a/src/sciline/container.py b/src/sciline/container.py index 60ae5341..dc1a5da5 100644 --- a/src/sciline/container.py +++ b/src/sciline/container.py @@ -1,10 +1,21 @@ # SPDX-License-Identifier: BSD-3-Clause # Copyright (c) 2023 Scipp contributors (https://github.com/scipp) -import typing -from functools import wraps -from typing import Callable, List, Type, TypeVar, Union - -import injector +from __future__ import annotations + +from typing import ( + Any, + Callable, + Dict, + List, + Tuple, + Type, + TypeVar, + get_args, + get_origin, + get_type_hints, +) + +import dask from dask.delayed import Delayed T = TypeVar('T') @@ -14,81 +25,140 @@ class UnsatisfiedRequirement(Exception): pass -class Container: - def __init__(self, inj: injector.Injector, /, *, lazy: bool) -> None: - self._injector = inj - self._lazy = lazy - - def get(self, tp: Type[T], /) -> Union[T, Delayed]: - try: - # We are slightly abusing Python's type system here, by using the - # injector to get T, but actually it returns a Delayed that can - # compute T. self._injector does not know this due to how we setup the - # bindings. We'd like to use Delayed[T], but that is not supported yet: - # https://github.com/dask/dask/pull/9256 - task: Delayed = self._injector.get(tp) # type: ignore - except injector.UnsatisfiedRequirement as e: - raise UnsatisfiedRequirement(e) from e - return task if self._lazy else task.compute() - - -def _delayed(func: Callable) -> Callable: +class UnboundTypeVar(Exception): + pass + + +class AmbiguousProvider(Exception): + pass + + +def _is_compatible_type_tuple( + requested: tuple[type, ...], provided: tuple[type | TypeVar, ...] +) -> bool: """ - Decorator to make a function return a delayed object. + Check if a tuple of requested types is compatible with a tuple of provided types. - In contrast to dask.delayed, this uses functools.wraps, to preserve the - type hints, which is a prerequisite for injector to work. + Types in the tuples must either by equal, or the provided type must be a TypeVar. """ - import dask + for req, prov in zip(requested, provided): + if isinstance(prov, TypeVar): + continue + if req != prov: + return False + return True - @wraps(func) - def wrapper(*args, **kwargs): - return dask.delayed(func)(*args, **kwargs) - return wrapper +def _bind_free_typevars(tp: TypeVar | type, bound: Dict[TypeVar, type]) -> type: + if isinstance(tp, TypeVar): + if (result := bound.get(tp)) is None: + raise UnboundTypeVar(f'Unbound type variable {tp}') + return result + elif (origin := get_origin(tp)) is not None: + result = origin[tuple(_bind_free_typevars(arg, bound) for arg in get_args(tp))] + if result is None: + raise ValueError(f'Binding type variables in {tp} resulted in `None`') + return result + else: + return tp -def _injectable(func: Callable) -> Callable: - """ - Wrap a regular function so it can be registered in an injector and have its - parameters injected. - """ - # When building a workflow, there are two common problems: - # - # 1. Intermediate results are used more than once. - # 2. Intermediate results are large, so we generally do not want to keep them - # in memory longer than necessary. - # - # To address these problems, we can internally build a graph of tasks, instead of - # directly creating dependencies between functions. Currently we use Dask for this. - # The Container instance will automatically compute the task, unless it is marked - # as lazy. 
diff --git a/src/sciline/container.py b/src/sciline/container.py
index 60ae5341..dc1a5da5 100644
--- a/src/sciline/container.py
+++ b/src/sciline/container.py
@@ -1,10 +1,21 @@
 # SPDX-License-Identifier: BSD-3-Clause
 # Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
-import typing
-from functools import wraps
-from typing import Callable, List, Type, TypeVar, Union
-
-import injector
+from __future__ import annotations
+
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    List,
+    Tuple,
+    Type,
+    TypeVar,
+    get_args,
+    get_origin,
+    get_type_hints,
+)
+
+import dask
 from dask.delayed import Delayed
 
 T = TypeVar('T')
@@ -14,81 +25,140 @@ class UnsatisfiedRequirement(Exception):
     pass
 
 
-class Container:
-    def __init__(self, inj: injector.Injector, /, *, lazy: bool) -> None:
-        self._injector = inj
-        self._lazy = lazy
-
-    def get(self, tp: Type[T], /) -> Union[T, Delayed]:
-        try:
-            # We are slightly abusing Python's type system here, by using the
-            # injector to get T, but actually it returns a Delayed that can
-            # compute T. self._injector does not know this due to how we setup the
-            # bindings. We'd like to use Delayed[T], but that is not supported yet:
-            # https://github.com/dask/dask/pull/9256
-            task: Delayed = self._injector.get(tp)  # type: ignore
-        except injector.UnsatisfiedRequirement as e:
-            raise UnsatisfiedRequirement(e) from e
-        return task if self._lazy else task.compute()
-
-
-def _delayed(func: Callable) -> Callable:
+class UnboundTypeVar(Exception):
+    pass
+
+
+class AmbiguousProvider(Exception):
+    pass
+
+
+def _is_compatible_type_tuple(
+    requested: tuple[type, ...], provided: tuple[type | TypeVar, ...]
+) -> bool:
     """
-    Decorator to make a function return a delayed object.
+    Check if a tuple of requested types is compatible with a tuple of provided types.
 
-    In contrast to dask.delayed, this uses functools.wraps, to preserve the
-    type hints, which is a prerequisite for injector to work.
+    Types in the tuples must either be equal, or the provided type must be a TypeVar.
     """
-    import dask
+    for req, prov in zip(requested, provided):
+        if isinstance(prov, TypeVar):
+            continue
+        if req != prov:
+            return False
+    return True
 
-    @wraps(func)
-    def wrapper(*args, **kwargs):
-        return dask.delayed(func)(*args, **kwargs)
 
-    return wrapper
+def _bind_free_typevars(tp: TypeVar | type, bound: Dict[TypeVar, type]) -> type:
+    if isinstance(tp, TypeVar):
+        if (result := bound.get(tp)) is None:
+            raise UnboundTypeVar(f'Unbound type variable {tp}')
+        return result
+    elif (origin := get_origin(tp)) is not None:
+        result = origin[tuple(_bind_free_typevars(arg, bound) for arg in get_args(tp))]
+        if result is None:
+            raise ValueError(f'Binding type variables in {tp} resulted in `None`')
+        return result
+    else:
+        return tp
 
 
-def _injectable(func: Callable) -> Callable:
-    """
-    Wrap a regular function so it can be registered in an injector and have its
-    parameters injected.
-    """
-    # When building a workflow, there are two common problems:
-    #
-    # 1. Intermediate results are used more than once.
-    # 2. Intermediate results are large, so we generally do not want to keep them
-    #    in memory longer than necessary.
-    #
-    # To address these problems, we can internally build a graph of tasks, instead of
-    # directly creating dependencies between functions. Currently we use Dask for this.
-    # The Container instance will automatically compute the task, unless it is marked
-    # as lazy. We therefore use singleton-scope (to ensure Dask will recognize the
-    # task as the same object) and also wrap the function in dask.delayed.
-    scope = injector.singleton
-    func = _delayed(func)
-    tps = typing.get_type_hints(func)
-
-    def bind(binder: injector.Binder):
-        binder.bind(tps['return'], injector.inject(func), scope=scope)
-
-    return bind
-
-
-def make_container(funcs: List[Callable], /, *, lazy: bool = False) -> Container:
-    """
-    Create a :py:class:`Container` from a list of functions.
-
-    Parameters
-    ----------
-    funcs:
-        List of functions to be injected. Must be annotated with type hints.
-    lazy:
-        If True, the functions are wrapped in :py:func:`dask.delayed` before
-        being injected. This allows to build a dask graph from the container.
-    """
-    # Note that we disable auto_bind, to ensure we do not accidentally bind to
-    # some default values. Everything must be explicit.
-    return Container(
-        injector.Injector([_injectable(f) for f in funcs], auto_bind=False),
-        lazy=lazy,
-    )
+Provider = Callable[..., Any]
+
+
+class Container:
+    def __init__(self, funcs: List[Provider], /):
+        """
+        Create a :py:class:`Container` from a list of functions.
+
+        Parameters
+        ----------
+        funcs:
+            List of functions to be injected. Must be annotated with type hints.
+        """
+        self._providers: Dict[type, Provider] = {}
+        self._subproviders: Dict[type, Dict[Tuple[type | TypeVar, ...], Provider]] = {}
+        self._cache: Dict[type, Delayed] = {}
+        for func in funcs:
+            self.insert(func)
+
+    def insert(self, provider: Provider) -> None:
+        if (key := get_type_hints(provider).get('return')) is None:
+            raise ValueError(f'Provider {provider} lacks type-hint for return value')
+        # isinstance does not work here, and types.NoneType is available only in 3.10+
+        if key == type(None):  # noqa: E721
+            raise ValueError(f'Provider {provider} returning `None` is not allowed')
+        if (origin := get_origin(key)) is not None:
+            subproviders = self._subproviders.setdefault(origin, {})
+            args = get_args(key)
+            if args in subproviders:
+                raise ValueError(f'Provider for {key} already exists')
+            subproviders[args] = provider
+        else:
+            if key in self._providers:
+                raise ValueError(f'Provider for {key} already exists')
+            self._providers[key] = provider
+
+    def _get_args(
+        self, func: Callable[..., Any], bound: Dict[TypeVar, type]
+    ) -> Dict[str, Delayed]:
+        return {
+            name: self.get(_bind_free_typevars(tp, bound=bound))
+            for name, tp in get_type_hints(func).items()
+            if name != 'return'
+        }
+
+    def _get_provider(
+        self, tp: Type[T]
+    ) -> Tuple[Callable[..., T], Dict[TypeVar, type]]:
+        if (provider := self._providers.get(tp)) is not None:
+            return provider, {}
+        elif (origin := get_origin(tp)) is not None and (
+            subproviders := self._subproviders.get(origin)
+        ) is not None:
+            requested = get_args(tp)
+            matches = [
+                (args, subprovider)
+                for args, subprovider in subproviders.items()
+                if _is_compatible_type_tuple(requested, args)
+            ]
+            if len(matches) == 1:
+                args, provider = matches[0]
+                bound = {
+                    arg: req
+                    for arg, req in zip(args, requested)
+                    if isinstance(arg, TypeVar)
+                }
+                return provider, bound
+            elif len(matches) > 1:
+                raise AmbiguousProvider("Multiple providers found for type", tp)
+        raise UnsatisfiedRequirement("No provider found for type", tp)
+
+    def get(self, tp: Type[T], /) -> Delayed:
+        # We are slightly abusing Python's type system here: self.get is called
+        # with T, but it actually returns a Delayed that can compute T. We'd
+        # like to use Delayed[T], but that is not supported yet:
+        # https://github.com/dask/dask/pull/9256
+        #
+        # When building a workflow, there are two common problems:
+        #
+        # 1. Intermediate results are used more than once.
+        # 2. Intermediate results are large, so we generally do not want to keep them
+        #    in memory longer than necessary.
+        #
+        # To address these problems, we can internally build a graph of tasks, instead
+        # of directly creating dependencies between functions. Currently we use Dask
+        # for this. We cache call results to ensure that Dask will recognize the task
+        # as the same object, and we also wrap each function in dask.delayed.
+        if (cached := self._cache.get(tp)) is not None:
+            return cached
+
+        provider, bound = self._get_provider(tp)
+        args = self._get_args(provider, bound=bound)
+        delayed = dask.delayed(provider)  # type: ignore[attr-defined]
+        return self._cache.setdefault(tp, delayed(**args))
+
+    def compute(self, tp: Type[T], /) -> T:
+        task = self.get(tp)
+        result: T = task.compute()  # type: ignore[no-untyped-call]
+        return result
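
The subprovider machinery above (insert keying generic return types on get_origin, _get_provider matching requests via _is_compatible_type_tuple, and _bind_free_typevars substituting the bound TypeVars into argument hints) is what lets a single generic provider serve many concrete types. A minimal sketch, mirroring test_typevar_requirement_of_provider_can_be_bound in the tests below:

    from typing import List, TypeVar

    import sciline as sl

    T = TypeVar('T')

    def provide_int() -> int:
        return 3

    def duplicate(x: T) -> List[T]:
        return [x, x]

    container = sl.Container([provide_int, duplicate])
    # Requesting List[int] matches duplicate with T bound to int, so its
    # argument x is in turn resolved through the int provider.
    assert container.compute(List[int]) == [3, 3]
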
diff --git a/src/sciline/domain.py b/src/sciline/domain.py
new file mode 100644
index 00000000..4117c26f
--- /dev/null
+++ b/src/sciline/domain.py
@@ -0,0 +1,10 @@
+# SPDX-License-Identifier: BSD-3-Clause
+# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
+from typing import Any, Generic, TypeVar
+
+T = TypeVar("T")
+
+
+class Scope(Generic[T]):
+    def __new__(cls, x) -> Any:  # type: ignore[no-untyped-def]
+        return x
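
Scope.__new__ returns its argument unchanged, so a Scope subclass is a zero-runtime-cost label: it exists for the type checker and for the Container's provider lookup, but no wrapper object is ever created. An illustrative sketch of the pattern used in the tests below; the Amplitude name is hypothetical:

    from typing import TypeVar

    import sciline as sl

    Run = TypeVar('Run')

    class Amplitude(sl.Scope[Run], float):
        ...

    a = Amplitude(1.5)
    # Scope.__new__ returned the plain float, so no Amplitude instance exists
    # at runtime; the subclass only matters for typing and provider lookup.
    assert type(a) is float and a == 1.5
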
diff --git a/tests/complex_workflow_test.py b/tests/complex_workflow_test.py
new file mode 100644
index 00000000..c576f7d6
--- /dev/null
+++ b/tests/complex_workflow_test.py
@@ -0,0 +1,133 @@
+# SPDX-License-Identifier: BSD-3-Clause
+# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
+from dataclasses import dataclass
+from typing import NewType, TypeVar
+
+import dask
+import numpy as np
+import numpy.typing as npt
+
+import sciline as sl
+
+# We use dask with a single thread, to ensure that call counting below is correct.
+dask.config.set(scheduler='synchronous')
+
+
+@dataclass
+class RawData:
+    data: npt.NDArray[np.float64]
+    monitor1: float
+    monitor2: float
+
+
+SampleRun = NewType('SampleRun', int)
+BackgroundRun = NewType('BackgroundRun', int)
+DetectorMask = NewType('DetectorMask', npt.NDArray[np.float64])
+DirectBeam = NewType('DirectBeam', npt.NDArray[np.float64])
+SolidAngle = NewType('SolidAngle', npt.NDArray[np.float64])
+
+Run = TypeVar('Run')
+
+
+class Raw(sl.Scope[Run], RawData):
+    ...
+
+
+class Masked(sl.Scope[Run], npt.NDArray[np.float64]):
+    ...
+
+
+class IncidentMonitor(sl.Scope[Run], float):
+    ...
+
+
+class TransmissionMonitor(sl.Scope[Run], float):
+    ...
+
+
+class TransmissionFraction(sl.Scope[Run], float):
+    ...
+
+
+class IofQ(sl.Scope[Run], npt.NDArray[np.float64]):
+    ...
+
+
+BackgroundSubtractedIofQ = NewType('BackgroundSubtractedIofQ', npt.NDArray[np.float64])
+
+
+def incident_monitor(x: Raw[Run]) -> IncidentMonitor[Run]:
+    return IncidentMonitor(x.monitor1)
+
+
+def transmission_monitor(x: Raw[Run]) -> TransmissionMonitor[Run]:
+    return TransmissionMonitor(x.monitor2)
+
+
+def mask_detector(x: Raw[Run], mask: DetectorMask) -> Masked[Run]:
+    return Masked(x.data * mask)
+
+
+def transmission(
+    incident: IncidentMonitor[Run], transmission: TransmissionMonitor[Run]
+) -> TransmissionFraction[Run]:
+    return TransmissionFraction(incident / transmission)
+
+
+def iofq(
+    x: Masked[Run],
+    solid_angle: SolidAngle,
+    direct_beam: DirectBeam,
+    transmission: TransmissionFraction[Run],
+) -> IofQ[Run]:
+    return IofQ(x / (solid_angle * direct_beam * transmission))
+
+
+reduction = [incident_monitor, transmission_monitor, mask_detector, transmission, iofq]
+
+
+def raw_sample() -> Raw[SampleRun]:
+    return Raw(RawData(data=np.ones(4), monitor1=1.0, monitor2=2.0))
+
+
+def raw_background() -> Raw[BackgroundRun]:
+    return Raw(RawData(data=np.ones(4) * 1.5, monitor1=1.0, monitor2=4.0))
+
+
+def detector_mask() -> DetectorMask:
+    return DetectorMask(np.array([1, 1, 0, 1]))
+
+
+def solid_angle() -> SolidAngle:
+    return SolidAngle(np.array([1.0, 0.5, 0.25, 0.125]))
+
+
+def direct_beam() -> DirectBeam:
+    return DirectBeam(np.array(1 / 1.5))
+
+
+def subtract_background(
+    sample: IofQ[SampleRun], background: IofQ[BackgroundRun]
+) -> BackgroundSubtractedIofQ:
+    return BackgroundSubtractedIofQ(sample - background)
+
+
+def test_reduction_workflow() -> None:
+    # See https://github.com/python/mypy/issues/14661
+    container = sl.Container(
+        [
+            raw_sample,
+            raw_background,
+            detector_mask,
+            solid_angle,
+            direct_beam,
+            subtract_background,
+        ]
+        + reduction  # type: ignore[arg-type]
+    )
+
+    assert np.array_equal(container.compute(IofQ[SampleRun]), [3, 6, 0, 24])
+    assert np.array_equal(container.compute(IofQ[BackgroundRun]), [9, 18, 0, 72])
+    assert np.array_equal(
+        container.compute(BackgroundSubtractedIofQ), [-6, -12, 0, -48]
+    )
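
The expected arrays in test_reduction_workflow can be verified by hand. For the sample run, masked = [1, 1, 0, 1], the transmission fraction is monitor1 / monitor2 = 1.0 / 2.0 = 0.5, and direct_beam = 1 / 1.5, so the denominator is solid_angle / 3 and IofQ[SampleRun] = 3 * [1, 1, 0, 1] / [1.0, 0.5, 0.25, 0.125] = [3, 6, 0, 24]. The background run has data 1.5 * ones(4) and transmission fraction 0.25, giving [9, 18, 0, 72], so the subtraction yields [-6, -12, 0, -48].
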
diff --git a/tests/container_test.py b/tests/container_test.py
index 09fc2896..82822d8a 100644
--- a/tests/container_test.py
+++ b/tests/container_test.py
@@ -1,5 +1,8 @@
 # SPDX-License-Identifier: BSD-3-Clause
 # Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
+from dataclasses import dataclass
+from typing import Generic, List, NewType, TypeVar
+
 import dask
 import pytest
 
@@ -9,31 +12,31 @@
 dask.config.set(scheduler='synchronous')
 
 
-def f(x: int) -> float:
+def int_to_float(x: int) -> float:
     return 0.5 * x
 
 
-def g() -> int:
+def make_int() -> int:
     return 3
 
 
-def h(x: int, y: float) -> str:
+def int_float_to_str(x: int, y: float) -> str:
     return f"{x};{y}"
 
 
-def test_make_container_sets_up_working_container():
-    container = sl.make_container([f, g])
-    assert container.get(float) == 1.5
-    assert container.get(int) == 3
+def test_make_container_sets_up_working_container() -> None:
+    container = sl.Container([int_to_float, make_int])
+    assert container.compute(float) == 1.5
+    assert container.compute(int) == 3
 
 
-def test_make_container_does_not_autobind():
-    container = sl.make_container([f])
+def test_make_container_does_not_autobind() -> None:
+    container = sl.Container([int_to_float])
     with pytest.raises(sl.UnsatisfiedRequirement):
-        container.get(float)
+        container.compute(float)
 
 
-def test_intermediate_computed_once_when_not_lazy():
+def test_intermediate_computed_once() -> None:
     ncall = 0
 
     def provide_int() -> int:
@@ -41,19 +44,19 @@ def provide_int() -> int:
         ncall += 1
         return 3
 
-    container = sl.make_container([f, provide_int, h], lazy=False)
-    assert container.get(str) == "3;1.5"
+    container = sl.Container([int_to_float, provide_int, int_float_to_str])
+    assert container.compute(str) == "3;1.5"
     assert ncall == 1
 
 
-def test_make_container_lazy_returns_task_that_computes_result():
-    container = sl.make_container([f, g], lazy=True)
+def test_get_returns_task_that_computes_result() -> None:
+    container = sl.Container([int_to_float, make_int])
     task = container.get(float)
     assert hasattr(task, 'compute')
-    assert task.compute() == 1.5
+    assert task.compute() == 1.5  # type: ignore[no-untyped-call]
 
 
-def test_lazy_with_multiple_outputs_computes_intermediates_once():
+def test_multiple_get_calls_can_be_computed_without_repeated_calls() -> None:
     ncall = 0
 
     def provide_int() -> int:
@@ -61,8 +64,327 @@ def provide_int() -> int:
         ncall += 1
         return 3
 
-    container = sl.make_container([f, provide_int, h], lazy=True)
+    container = sl.Container([int_to_float, provide_int, int_float_to_str])
     task1 = container.get(float)
     task2 = container.get(str)
-    assert dask.compute(task1, task2) == (1.5, '3;1.5')
+    assert dask.compute(task1, task2) == (1.5, '3;1.5')  # type: ignore[attr-defined]
+    assert ncall == 1
+
+
+def test_make_container_with_subgraph_template() -> None:
+    ncall = 0
+
+    def provide_int() -> int:
+        nonlocal ncall
+        ncall += 1
+        return 3
+
+    Param = TypeVar('Param')
+
+    class Float(sl.Scope[Param], float):
+        ...
+
+    class Str(sl.Scope[Param], str):
+        ...
+
+    def int_float_to_str(x: int, y: Float[Param]) -> Str[Param]:
+        return Str(f"{x};{y}")
+
+    Run1 = NewType('Run1', int)
+    Run2 = NewType('Run2', int)
+    Result = NewType('Result', str)
+
+    def float1() -> Float[Run1]:
+        return Float[Run1](1.5)
+
+    def float2() -> Float[Run2]:
+        return Float[Run2](2.5)
+
+    def use_strings(s1: Str[Run1], s2: Str[Run2]) -> Result:
+        return Result(f"{s1};{s2}")
+
+    container = sl.Container(
+        [provide_int, float1, float2, use_strings, int_float_to_str],
+    )
+    task = container.get(Result)
+    assert task.compute() == "3;1.5;3;2.5"  # type: ignore[no-untyped-call]
     assert ncall == 1
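
The final assert ncall == 1 in test_make_container_with_subgraph_template holds because Str[Run1] and Str[Run2] come from two instantiations of the same generic provider, yet both depend on the single cached int task, so dask computes provide_int only once when the Result graph runs.
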
+
+
+Param = TypeVar('Param')
+
+
+class Str(sl.Scope[Param], str):
+    ...
+
+
+def f(x: Param) -> Str[Param]:
+    return Str(f'{x}')
+
+
+def test_container_from_templated() -> None:
+    def make_float() -> float:
+        return 1.5
+
+    def combine(x: Str[int], y: Str[float]) -> str:
+        return f"{x};{y}"
+
+    container = sl.Container([make_int, make_float, combine, f])
+    assert container.compute(Str[int]) == '3'
+    assert container.compute(Str[float]) == '1.5'
+    assert container.compute(str) == '3;1.5'
+
+
+def test_inserting_provider_returning_None_raises() -> None:
+    def provide_none() -> None:
+        return None
+
+    with pytest.raises(ValueError):
+        sl.Container([provide_none])
+    container = sl.Container([])
+    with pytest.raises(ValueError):
+        container.insert(provide_none)
+
+
+def test_inserting_provider_with_no_return_type_raises() -> None:
+    def provide_none():  # type: ignore[no-untyped-def]
+        return None
+
+    with pytest.raises(ValueError):
+        sl.Container([provide_none])
+    container = sl.Container([])
+    with pytest.raises(ValueError):
+        container.insert(provide_none)
+
+
+def test_typevar_requirement_of_provider_can_be_bound() -> None:
+    T = TypeVar('T')
+
+    def provider_int() -> int:
+        return 3
+
+    def provider(x: T) -> List[T]:
+        return [x, x]
+
+    container = sl.Container([provider_int, provider])
+    assert container.compute(List[int]) == [3, 3]
+
+
+def test_typevar_that_cannot_be_bound_raises_UnboundTypeVar() -> None:
+    T = TypeVar('T')
+
+    def provider(_: T) -> int:
+        return 1
+
+    container = sl.Container([provider])
+    with pytest.raises(sl.UnboundTypeVar):
+        container.compute(int)
+
+
+def test_unsatisfiable_typevar_requirement_of_provider_raises() -> None:
+    T = TypeVar('T')
+
+    def provider_int() -> int:
+        return 3
+
+    def provider(x: T) -> List[T]:
+        return [x, x]
+
+    container = sl.Container([provider_int, provider])
+    with pytest.raises(sl.UnsatisfiedRequirement):
+        container.compute(List[float])
+
+
+def test_TypeVar_params_are_not_associated_unless_they_match() -> None:
+    T1 = TypeVar('T1')
+    T2 = TypeVar('T2')
+
+    class A(Generic[T1]):
+        ...
+
+    class B(Generic[T2]):
+        ...
+
+    def source() -> A[int]:
+        return A[int]()
+
+    def not_matching(x: A[T1]) -> B[T2]:
+        return B[T2]()
+
+    def matching(x: A[T1]) -> B[T1]:
+        return B[T1]()
+
+    container = sl.Container([source, not_matching])
+    with pytest.raises(sl.UnboundTypeVar):
+        container.compute(B[int])
+
+    container = sl.Container([source, matching])
+    container.compute(B[int])
+
+
+def test_multi_Generic_with_fully_bound_arguments() -> None:
+    T1 = TypeVar('T1')
+    T2 = TypeVar('T2')
+
+    @dataclass
+    class A(Generic[T1, T2]):
+        first: T1
+        second: T2
+
+    def source() -> A[int, float]:
+        return A[int, float](1, 2.0)
+
+    container = sl.Container([source])
+    assert container.compute(A[int, float]) == A[int, float](1, 2.0)
+
+
+def test_multi_Generic_with_partially_bound_arguments() -> None:
+    T1 = TypeVar('T1')
+    T2 = TypeVar('T2')
+
+    @dataclass
+    class A(Generic[T1, T2]):
+        first: T1
+        second: T2
+
+    def source() -> float:
+        return 2.0
+
+    def partially_bound(x: T1) -> A[int, T1]:
+        return A[int, T1](1, x)
+
+    container = sl.Container([source, partially_bound])
+    assert container.compute(A[int, float]) == A[int, float](1, 2.0)
+
+
+def test_multi_Generic_with_multiple_unbound() -> None:
+    T1 = TypeVar('T1')
+    T2 = TypeVar('T2')
+
+    @dataclass
+    class A(Generic[T1, T2]):
+        first: T1
+        second: T2
+
+    def int_source() -> int:
+        return 1
+
+    def float_source() -> float:
+        return 2.0
+
+    def unbound(x: T1, y: T2) -> A[T1, T2]:
+        return A[T1, T2](x, y)
+
+    container = sl.Container([int_source, float_source, unbound])
+    assert container.compute(A[int, float]) == A[int, float](1, 2.0)
+    assert container.compute(A[float, int]) == A[float, int](2.0, 1)
+
+
+def test_distinct_fully_bound_instances_yield_distinct_results() -> None:
+    T1 = TypeVar('T1')
+
+    @dataclass
+    class A(Generic[T1]):
+        value: T1
+
+    def int_source() -> A[int]:
+        return A[int](1)
+
+    def float_source() -> A[float]:
+        return A[float](2.0)
+
+    container = sl.Container([int_source, float_source])
+    assert container.compute(A[int]) == A[int](1)
+    assert container.compute(A[float]) == A[float](2.0)
+
+
+def test_distinct_partially_bound_instances_yield_distinct_results() -> None:
+    T1 = TypeVar('T1')
+    T2 = TypeVar('T2')
+
+    @dataclass
+    class A(Generic[T1, T2]):
+        first: T1
+        second: T2
+
+    def str_source() -> str:
+        return 'a'
+
+    def int_source(x: T1) -> A[int, T1]:
+        return A[int, T1](1, x)
+
+    def float_source(x: T1) -> A[float, T1]:
+        return A[float, T1](2.0, x)
+
+    container = sl.Container([str_source, int_source, float_source])
+    assert container.compute(A[int, str]) == A[int, str](1, 'a')
+    assert container.compute(A[float, str]) == A[float, str](2.0, 'a')
+
+
+def test_multiple_matching_partial_providers_raises() -> None:
+    T1 = TypeVar('T1')
+    T2 = TypeVar('T2')
+
+    @dataclass
+    class A(Generic[T1, T2]):
+        first: T1
+        second: T2
+
+    def int_source() -> int:
+        return 1
+
+    def float_source() -> float:
+        return 2.0
+
+    def provider1(x: T1) -> A[int, T1]:
+        return A[int, T1](1, x)
+
+    def provider2(x: T2) -> A[T2, float]:
+        return A[T2, float](x, 2.0)
+
+    container = sl.Container([int_source, float_source, provider1, provider2])
+    assert container.compute(A[int, int]) == A[int, int](1, 1)
+    assert container.compute(A[float, float]) == A[float, float](2.0, 2.0)
+    with pytest.raises(sl.AmbiguousProvider):
+        container.compute(A[int, float])
+
+
+def test_TypeVar_params_track_to_multiple_sources() -> None:
+    T1 = TypeVar('T1')
+    T2 = TypeVar('T2')
+
+    @dataclass
+    class A(Generic[T1]):
+        value: T1
+
+    @dataclass
+    class B(Generic[T1]):
+        value: T1
+
+    @dataclass
+    class C(Generic[T1, T2]):
+        first: T1
+        second: T2
+
+    def provide_int() -> int:
+        return 1
+
+    def provide_float() -> float:
+        return 2.0
+
+    def provide_A(x: T1) -> A[T1]:
+        return A[T1](x)
+
+    # Note that it currently does not matter which TypeVar instance we use here:
+    # Container tracks uses of TypeVar within a single provider, but does not carry
+    # the information beyond the scope of a single call.
+    def provide_B(x: T1) -> B[T1]:
+        return B[T1](x)
+
+    def provide_C(x: A[T1], y: B[T2]) -> C[T1, T2]:
+        return C[T1, T2](x.value, y.value)
+
+    container = sl.Container(
+        [provide_int, provide_float, provide_A, provide_B, provide_C]
+    )
+    assert container.compute(C[int, float]) == C[int, float](1, 2.0)
diff --git a/tests/package_test.py b/tests/package_test.py
index 426e378e..62ceb639 100644
--- a/tests/package_test.py
+++ b/tests/package_test.py
@@ -3,5 +3,5 @@
 import sciline as pkg
 
 
-def test_has_version():
+def test_has_version() -> None:
     assert hasattr(pkg, '__version__')