From d52016fc11f1ad4600cada39e1c1318266527218 Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Fri, 9 Jun 2023 13:29:54 +0200 Subject: [PATCH 01/43] Experiment with child container mechanism --- src/sciline/container.py | 27 +++++++ tests/container_test.py | 170 +++++++++++++++++++++++++++++++++++++-- 2 files changed, 189 insertions(+), 8 deletions(-) diff --git a/src/sciline/container.py b/src/sciline/container.py index 60ae5341..6a49f8f0 100644 --- a/src/sciline/container.py +++ b/src/sciline/container.py @@ -1,5 +1,7 @@ # SPDX-License-Identifier: BSD-3-Clause # Copyright (c) 2023 Scipp contributors (https://github.com/scipp) +from __future__ import annotations + import typing from functools import wraps from typing import Callable, List, Type, TypeVar, Union @@ -31,6 +33,31 @@ def get(self, tp: Type[T], /) -> Union[T, Delayed]: raise UnsatisfiedRequirement(e) from e return task if self._lazy else task.compute() + def make_child_container(self, funcs: List[Callable], /) -> Container: + """ + Create a child container from a list of functions. + + The child container inherits all bindings from the parent container, but + can override them with new bindings. + + Warning + ------- + + Note that it is not possible to override transitive dependencies, i.e., if the + parent container provides A, and A depends on B, then the child container + cannot override the B that is used by A. It can only override the B that is + used by the child container. + + Parameters + ---------- + funcs: + List of functions to be injected. Must be annotated with type hints. 
+ """ + return Container( + self._injector.create_child_injector([_injectable(f) for f in funcs]), + lazy=self._lazy, + ) + def _delayed(func: Callable) -> Callable: """ diff --git a/tests/container_test.py b/tests/container_test.py index 09fc2896..6e9bd11a 100644 --- a/tests/container_test.py +++ b/tests/container_test.py @@ -1,5 +1,7 @@ # SPDX-License-Identifier: BSD-3-Clause # Copyright (c) 2023 Scipp contributors (https://github.com/scipp) +from typing import NewType + import dask import pytest @@ -9,26 +11,26 @@ dask.config.set(scheduler='synchronous') -def f(x: int) -> float: +def int_to_float(x: int) -> float: return 0.5 * x -def g() -> int: +def make_int() -> int: return 3 -def h(x: int, y: float) -> str: +def int_float_to_str(x: int, y: float) -> str: return f"{x};{y}" def test_make_container_sets_up_working_container(): - container = sl.make_container([f, g]) + container = sl.make_container([int_to_float, make_int]) assert container.get(float) == 1.5 assert container.get(int) == 3 def test_make_container_does_not_autobind(): - container = sl.make_container([f]) + container = sl.make_container([int_to_float]) with pytest.raises(sl.UnsatisfiedRequirement): container.get(float) @@ -41,13 +43,15 @@ def provide_int() -> int: ncall += 1 return 3 - container = sl.make_container([f, provide_int, h], lazy=False) + container = sl.make_container( + [int_to_float, provide_int, int_float_to_str], lazy=False + ) assert container.get(str) == "3;1.5" assert ncall == 1 def test_make_container_lazy_returns_task_that_computes_result(): - container = sl.make_container([f, g], lazy=True) + container = sl.make_container([int_to_float, make_int], lazy=True) task = container.get(float) assert hasattr(task, 'compute') assert task.compute() == 1.5 @@ -61,8 +65,158 @@ def provide_int() -> int: ncall += 1 return 3 - container = sl.make_container([f, provide_int, h], lazy=True) + container = sl.make_container( + [int_to_float, provide_int, int_float_to_str], lazy=True + ) task1 = 
container.get(float) task2 = container.get(str) assert dask.compute(task1, task2) == (1.5, '3;1.5') assert ncall == 1 + + +def test_make_child_container_inherits_bindings_from_parent(): + container = sl.make_container([int_to_float, make_int]) + child = container.make_child_container([int_float_to_str]) + assert child.get(str) == "3;1.5" + + +def test_make_child_container_override_parent_binding(): + def other_int() -> int: + return 4 + + container = sl.make_container([make_int]) + child = container.make_child_container([other_int, int_to_float, int_float_to_str]) + assert child.get(str) == "4;2.0" + assert child.get(int) == 4 + + +def test_make_child_container_override_does_not_affect_transitive_dependency(): + def other_int() -> int: + return 4 + + container = sl.make_container([int_to_float, make_int]) + child = container.make_child_container([other_int, int_float_to_str]) + assert child.get(int) == 4 + # Float computed within parent from int=3, not from other_int=4 + assert child.get(str) == "4;1.5" + + +def test_make_child_container_can_provide_transitive_dependency(): + # child injector must contain copy, we cannot use parent injector + # to provide "template" for child injector + # should we support overriding at all? 
+ + # We had the following example in the design document: + + # import injector + # from typing import NewType + # + # BackgroundContainer = NewType('BackgroundContainer', injector.Injector) + # + # class BackgroundModule(injector.Module): + # def __init__(self, container: injector.Injector): + # self._container = container + # + # @injector.provide + # def get_background_iofq(self) -> BackgroundIofQ: + # return BackgroundIofQ(self._container.get(IofQ)) + # + # container = injector.Injector() # Configure with common reduction modules + # background = container.create_child_injector(background_config) + # background_module = BackgroundModule(background) + # container.binder.install(background_module) + + # However, it seems like this would not work? We need to create the child injector + # with all the relevant function, the parent may only provide inputs to those funcs. + # Should `Container` support a registry of templates that can be used in children? + # Or do we require manual handling? 
+ + container = sl.make_container([int_to_float]) + child = container.make_child_container([make_int]) + assert child.get(float) == 1.5 + + +def test_make_container_with_callable_that_uses_child(): + parent = sl.make_container([int_to_float, make_int]) + child = parent.make_child_container([int_float_to_str]) + + MyStr = NewType('MyStr', str) + + def use_child() -> MyStr: + return MyStr(child.get(str)) + + container = sl.make_container([use_child]) + + assert container.get(MyStr) == "3;1.5" + + +def test_make_container_with_multiple_children(): + parent = sl.make_container([make_int]) + + def float1() -> float: + return 1.5 + + def float2() -> float: + return 2.5 + + child1 = parent.make_child_container([float1, int_float_to_str]) + child2 = parent.make_child_container([float2, int_float_to_str]) + Str1 = NewType('Str1', str) + Str2 = NewType('Str2', str) + + def get_str1() -> Str1: + return Str1(child1.get(str)) + + def get_str2() -> Str2: + return Str2(child2.get(str)) + + def use_strings(s1: Str1, s2: Str2) -> str: + return f"{s1};{s2}" + + container = sl.make_container([get_str1, get_str2, use_strings]) + assert container.get(Str1) == "3;1.5" + assert container.get(Str2) == "3;2.5" + assert container.get(str) == "3;1.5;3;2.5" + + +def test_make_container_with_multiple_children_does_not_repeat_calls(): + ncall = 0 + + def provide_int() -> int: + nonlocal ncall + ncall += 1 + return 3 + + parent = sl.make_container([provide_int], lazy=True) + + def float1() -> float: + return 1.5 + + def float2() -> float: + return 2.5 + + child1 = parent.make_child_container([float1, int_float_to_str]) + child2 = parent.make_child_container([float2, int_float_to_str]) + Str1 = NewType('Str1', str) + Str2 = NewType('Str2', str) + + def get_str1() -> Str1: + # Would need to call compute() here, but then we would not get a single graph + # Otherwise we delay a Delayed + # If we use lazy=False, then this computes too early. 
+ return Str1(child1.get(str)) + + def get_str2() -> Str2: + # Only works by coincidence, since Str2(Delayed) works. If we used + # a function that does not take a Delayed, this would fail, e.g., sc.sin. + return Str2(child2.get(str)) + + def use_strings(s1: Str1, s2: Str2) -> str: + return f"{s1};{s2}" + + container = sl.make_container([use_strings], lazy=True) + # If we wrap with _injectable, things won't work. Need special handling. How? + container._injector.binder.bind(Str1, get_str1) + container._injector.binder.bind(Str2, get_str2) + assert container.get(str).compute() == "3;1.5;3;2.5" + assert ncall == 1 From ab65e6dd4dfcdae6e9ae740de2d3f94155603ab5 Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Thu, 6 Jul 2023 13:17:25 +0200 Subject: [PATCH 02/43] Experiments --- src/sciline/__init__.py | 1 + src/sciline/container.py | 19 +++++++- src/sciline/domain.py | 22 ++++++++++ tests/container_test.py | 94 +++++++++++++++++++++++++++++++++++++++- 4 files changed, 134 insertions(+), 2 deletions(-) create mode 100644 src/sciline/domain.py diff --git a/src/sciline/__init__.py b/src/sciline/__init__.py index 85b4db57..a8a2dbb5 100644 --- a/src/sciline/__init__.py +++ b/src/sciline/__init__.py @@ -10,3 +10,4 @@ __version__ = "0.0.0" from .container import Container, UnsatisfiedRequirement, make_container +from .domain import domain_type, parametrized_domain_type diff --git a/src/sciline/container.py b/src/sciline/container.py index 6a49f8f0..19ed9694 100644 --- a/src/sciline/container.py +++ b/src/sciline/container.py @@ -4,7 +4,7 @@ import typing from functools import wraps -from typing import Callable, List, Type, TypeVar, Union +from typing import Callable, Generic, List, Type, TypeVar, Union import injector from dask.delayed import Delayed @@ -119,3 +119,20 @@ def make_container(funcs: List[Callable], /, *, lazy: bool = False) -> Container injector.Injector([_injectable(f) for f in funcs], auto_bind=False), lazy=lazy, ) + + +def specialize(func: Callable, T: 
Type) -> Callable: + """ + Given a function + f(a: A, b: B, c: C, ...) -> R + return a new function + f(a: A[T], b: B[T], c: C[T], ...) -> R[T] + where T is a type variable that is bound to the type of the return value of f. + """ + tps = typing.get_type_hints(func) + + def wrapper(*args, **kwargs): + return func(*args, **kwargs) + + wrapper.__annotations__ = {k: tp[T] for k, tp in tps.items()} + return wrapper diff --git a/src/sciline/domain.py b/src/sciline/domain.py new file mode 100644 index 00000000..fd74485a --- /dev/null +++ b/src/sciline/domain.py @@ -0,0 +1,22 @@ +# SPDX-License-Identifier: BSD-3-Clause +# Copyright (c) 2023 Scipp contributors (https://github.com/scipp) +from __future__ import annotations + +import typing +from typing import Callable, Generic, List, Type, TypeVar, Union + + +def domain_type(name: str, base: type) -> type: + class tp(base): + pass + + return tp + + +def parametrized_domain_type(name: str, base: type) -> type: + T = TypeVar('T') + + class tp(base, Generic[T]): + pass + + return tp diff --git a/tests/container_test.py b/tests/container_test.py index 6e9bd11a..852c4fa4 100644 --- a/tests/container_test.py +++ b/tests/container_test.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: BSD-3-Clause # Copyright (c) 2023 Scipp contributors (https://github.com/scipp) -from typing import NewType +from typing import Generic, NewType, TypeVar import dask import pytest @@ -220,3 +220,95 @@ def use_strings(s1: Str1, s2: Str2) -> str: container._injector.binder.bind(Str2, get_str2) assert container.get(str).compute() == "3;1.5;3;2.5" assert ncall == 1 + + +def test_specialize(): + import typing + from typing import Any, List + + def f(x: List[Any]) -> List: + return x + + # specialized = sl.container.specialize(f, int) + # assert typing.get_type_hints(specialized) == {'return': List[int], 'x': List[int]} + print(typing.get_type_hints(f)) + assert False + + +T = TypeVar('T') + + +class A(Generic[T]): + def __init__(self, x: T) -> None: + self.x 
= x + + +import typing +from typing import Dict, List + + +def newtypes(name: str, base: type, ts: List[type]) -> Dict[type, type]: + return {t: typing.NewType(f'{name}_{t.__name__}', base) for t in ts} + + +def test_newtypes(): + Str = newtypes('Str', str, [int, float]) + s = Str[int]('3') + assert s == '3' + + def str_int() -> Str[int]: + return Str[int]('abc') + + container = sl.make_container([str_int]) + assert container.get(Str[int]) == 'abc' + + +from typing import Any, Callable + + +def test_templated_injector(): + providers = {} + provider_templates = {} + providers[int] = lambda: 3 + providers[float] = lambda: 3.5 + # Str = newtypes('Str', str, [int, float]) + # dict is not hashable, it we would not pass mypy... + Str = typing.NewType('Str', str) + provider_templates[Str] = lambda t: lambda: str(t()) + + def call(func: Callable) -> Any: + types = typing.get_type_hints(func) + args = {name: providers[tp]() for name, tp in types.items() if name != 'return'} + return func(**args) + + def f(x: int) -> str: + return str(x) + + assert call(f) == '3' + + +Str = sl.parametrized_domain_type('Str', str) + + +def templated(tp: type) -> List[Callable]: + # Could also provide option for a list of types + # How can a user extend this? Just make there own wrapper function? 
+ def f(x: tp) -> Str[tp]: + return Str[tp](x) + + return [f] + + +def test_container_from_templated(): + def make_float() -> float: + return 1.5 + + def combine(x: Str[int], y: Str[float]) -> str: + return f"{x};{y}" + + container = sl.make_container( + [make_int, make_float, combine] + templated(int) + templated(float) + ) + assert container.get(Str[int]) == '3' + assert container.get(Str[float]) == '1.5' + assert container.get(str) == '3;1.5' From f0f1941b046836918e69e40f5fdfb8836870bbb5 Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Thu, 6 Jul 2023 13:18:54 +0200 Subject: [PATCH 03/43] Remove dead-end ideas --- src/sciline/container.py | 19 +----------- tests/container_test.py | 67 +--------------------------------------- 2 files changed, 2 insertions(+), 84 deletions(-) diff --git a/src/sciline/container.py b/src/sciline/container.py index 19ed9694..6a49f8f0 100644 --- a/src/sciline/container.py +++ b/src/sciline/container.py @@ -4,7 +4,7 @@ import typing from functools import wraps -from typing import Callable, Generic, List, Type, TypeVar, Union +from typing import Callable, List, Type, TypeVar, Union import injector from dask.delayed import Delayed @@ -119,20 +119,3 @@ def make_container(funcs: List[Callable], /, *, lazy: bool = False) -> Container injector.Injector([_injectable(f) for f in funcs], auto_bind=False), lazy=lazy, ) - - -def specialize(func: Callable, T: Type) -> Callable: - """ - Given a function - f(a: A, b: B, c: C, ...) -> R - return a new function - f(a: A[T], b: B[T], c: C[T], ...) -> R[T] - where T is a type variable that is bound to the type of the return value of f. 
- """ - tps = typing.get_type_hints(func) - - def wrapper(*args, **kwargs): - return func(*args, **kwargs) - - wrapper.__annotations__ = {k: tp[T] for k, tp in tps.items()} - return wrapper diff --git a/tests/container_test.py b/tests/container_test.py index 852c4fa4..5167731a 100644 --- a/tests/container_test.py +++ b/tests/container_test.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: BSD-3-Clause # Copyright (c) 2023 Scipp contributors (https://github.com/scipp) -from typing import Generic, NewType, TypeVar +from typing import Callable, List, NewType import dask import pytest @@ -222,71 +222,6 @@ def use_strings(s1: Str1, s2: Str2) -> str: assert ncall == 1 -def test_specialize(): - import typing - from typing import Any, List - - def f(x: List[Any]) -> List: - return x - - # specialized = sl.container.specialize(f, int) - # assert typing.get_type_hints(specialized) == {'return': List[int], 'x': List[int]} - print(typing.get_type_hints(f)) - assert False - - -T = TypeVar('T') - - -class A(Generic[T]): - def __init__(self, x: T) -> None: - self.x = x - - -import typing -from typing import Dict, List - - -def newtypes(name: str, base: type, ts: List[type]) -> Dict[type, type]: - return {t: typing.NewType(f'{name}_{t.__name__}', base) for t in ts} - - -def test_newtypes(): - Str = newtypes('Str', str, [int, float]) - s = Str[int]('3') - assert s == '3' - - def str_int() -> Str[int]: - return Str[int]('abc') - - container = sl.make_container([str_int]) - assert container.get(Str[int]) == 'abc' - - -from typing import Any, Callable - - -def test_templated_injector(): - providers = {} - provider_templates = {} - providers[int] = lambda: 3 - providers[float] = lambda: 3.5 - # Str = newtypes('Str', str, [int, float]) - # dict is not hashable, it we would not pass mypy... 
- Str = typing.NewType('Str', str) - provider_templates[Str] = lambda t: lambda: str(t()) - - def call(func: Callable) -> Any: - types = typing.get_type_hints(func) - args = {name: providers[tp]() for name, tp in types.items() if name != 'return'} - return func(**args) - - def f(x: int) -> str: - return str(x) - - assert call(f) == '3' - - Str = sl.parametrized_domain_type('Str', str) From 1a960eb6a2f1d2425e1c850fe242c449afbe332a Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Fri, 7 Jul 2023 09:36:17 +0200 Subject: [PATCH 04/43] Remove 'make_child_container' --- src/sciline/container.py | 25 ------ tests/container_test.py | 176 ++++++++------------------------------- 2 files changed, 37 insertions(+), 164 deletions(-) diff --git a/src/sciline/container.py b/src/sciline/container.py index 6a49f8f0..4253be73 100644 --- a/src/sciline/container.py +++ b/src/sciline/container.py @@ -33,31 +33,6 @@ def get(self, tp: Type[T], /) -> Union[T, Delayed]: raise UnsatisfiedRequirement(e) from e return task if self._lazy else task.compute() - def make_child_container(self, funcs: List[Callable], /) -> Container: - """ - Create a child container from a list of functions. - - The child container inherits all bindings from the parent container, but - can override them with new bindings. - - Warning - ------- - - Note that it is not possible to override transitive dependencies, i.e., if the - parent container provides A, and A depends on B, then the child container - cannot override the B that is used by A. It can only override the B that is - used by the child container. - - Parameters - ---------- - funcs: - List of functions to be injected. Must be annotated with type hints. 
- """ - return Container( - self._injector.create_child_injector([_injectable(f) for f in funcs]), - lazy=self._lazy, - ) - def _delayed(func: Callable) -> Callable: """ diff --git a/tests/container_test.py b/tests/container_test.py index 5167731a..b1ce5fac 100644 --- a/tests/container_test.py +++ b/tests/container_test.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: BSD-3-Clause # Copyright (c) 2023 Scipp contributors (https://github.com/scipp) -from typing import Callable, List, NewType +from typing import Callable, List, NewType, Tuple import dask import pytest @@ -74,112 +74,7 @@ def provide_int() -> int: assert ncall == 1 -def test_make_child_container_inherits_bindings_from_parent(): - container = sl.make_container([int_to_float, make_int]) - child = container.make_child_container([int_float_to_str]) - assert child.get(str) == "3;1.5" - - -def test_make_child_container_override_parent_binding(): - def other_int() -> int: - return 4 - - container = sl.make_container([make_int]) - child = container.make_child_container([other_int, int_to_float, int_float_to_str]) - assert child.get(str) == "4;2.0" - assert child.get(int) == 4 - - -def test_make_child_container_override_does_not_affect_transitive_dependency(): - def other_int() -> int: - return 4 - - container = sl.make_container([int_to_float, make_int]) - child = container.make_child_container([other_int, int_float_to_str]) - assert child.get(int) == 4 - # Float computed within parent from int=3, not from other_int=4 - assert child.get(str) == "4;1.5" - - -def test_make_child_container_can_provide_transitive_dependency(): - # child injector must contain copy, we cannot use parent injector - # to provide "template" for child injector - # should we support overriding at all? 
- - # We had the following example in the design document: - - # import injector - # from typing import NewType - # - # BackgroundContainer = NewType('BackgroundContainer', injector.Injector) - # - # class BackgroundModule(injector.Module): - # def __init__(self, container: injector.Injector): - # self._container = container - # - # @injector.provide - # def get_background_iofq(self) -> BackgroundIofQ: - # return BackgroundIofQ(self._container.get(IofQ)) - # - # container = injector.Injector() # Configure with common reduction modules - # background = container.create_child_injector(background_config) - # background_module = BackgroundModule(background) - # container.binder.install(background_module) - - # However, it seems like this would not work? We need to create the child injector - # with all the relevant function, the parent may only provide inputs to those funcs. - # Should `Container` support a registry of templates that can be used in children? - # Or do we require manual handling? 
- - container = sl.make_container([int_to_float]) - child = container.make_child_container([make_int]) - assert child.get(float) == 1.5 - - -def test_make_container_with_callable_that_uses_child(): - parent = sl.make_container([int_to_float, make_int]) - child = parent.make_child_container([int_float_to_str]) - - MyStr = NewType('MyStr', str) - - def use_child() -> MyStr: - return MyStr(child.get(str)) - - container = sl.make_container([use_child]) - - assert container.get(MyStr) == "3;1.5" - - -def test_make_container_with_multiple_children(): - parent = sl.make_container([make_int]) - - def float1() -> float: - return 1.5 - - def float2() -> float: - return 2.5 - - child1 = parent.make_child_container([float1, int_float_to_str]) - child2 = parent.make_child_container([float2, int_float_to_str]) - Str1 = NewType('Str1', str) - Str2 = NewType('Str2', str) - - def get_str1() -> Str1: - return Str1(child1.get(str)) - - def get_str2() -> Str2: - return Str2(child2.get(str)) - - def use_strings(s1: Str1, s2: Str2) -> str: - return f"{s1};{s2}" - - container = sl.make_container([get_str1, get_str2, use_strings]) - assert container.get(Str1) == "3;1.5" - assert container.get(Str2) == "3;2.5" - assert container.get(str) == "3;1.5;3;2.5" - - -def test_make_container_with_multiple_children_does_not_repeat_calls(): +def test_make_container_with_subgraph_template(): ncall = 0 def provide_int() -> int: @@ -187,45 +82,40 @@ def provide_int() -> int: ncall += 1 return 3 - parent = sl.make_container([provide_int], lazy=True) + Float = sl.parametrized_domain_type('Float', float) + Str = sl.parametrized_domain_type('Str', str) - def float1() -> float: - return 1.5 + def child(tp: type) -> List[Callable]: + def int_float_to_str(x: int, y: Float[tp]) -> Str[tp]: + return Str[tp](f"{x};{y}") + + return [int_float_to_str] - def float2() -> float: - return 2.5 - - child1 = parent.make_child_container([float1, int_float_to_str]) - child2 = parent.make_child_container([float2, 
int_float_to_str]) - Str1 = NewType('Str1', str) - Str2 = NewType('Str2', str) - - def get_str1() -> Str1: - # Would need to call compute() here, but then we would not get a single graph - # Otherwise we delay a Delayed - # If we use lazy=False, then this computes too early. - return Str1(child1.get(str)) - - def get_str2() -> Str2: - # Only works by coincidence, since Str2(Delayed) works. If we used - # a function that does not take a Delayed, this would fail, e.g., sc.sin. - return Str2(child2.get(str)) - - def use_strings(s1: Str1, s2: Str2) -> str: - return f"{s1};{s2}" - - container = sl.make_container([use_strings], lazy=True) - # If we wrap with _injectable, things won't work. Need special handling. How? - container._injector.binder.bind(Str1, get_str1) - container._injector.binder.bind(Str2, get_str2) - assert container.get(str).compute() == "3;1.5;3;2.5" + Run1 = NewType('Run1', int) + Run2 = NewType('Run2', int) + Result = NewType('Result', str) + + def float1() -> Float[Run1]: + return Float[Run1](1.5) + + def float2() -> Float[Run2]: + return Float[Run2](2.5) + + def use_strings(s1: Str[Run1], s2: Str[Run2]) -> Result: + return Result(f"{s1};{s2}") + + container = sl.make_container( + [provide_int, float1, float2, use_strings] + child(Run1) + child(Run2), + lazy=True, + ) + assert container.get(Result).compute() == "3;1.5;3;2.5" assert ncall == 1 Str = sl.parametrized_domain_type('Str', str) -def templated(tp: type) -> List[Callable]: +def subworkflow(tp: type) -> List[Callable]: # Could also provide option for a list of types # How can a user extend this? Just make there own wrapper function? def f(x: tp) -> Str[tp]: @@ -234,6 +124,14 @@ def f(x: tp) -> Str[tp]: return [f] +def from_templates( + template: Callable[[type], List[Callable]], tps: Tuple[type, ...] 
+) -> List[Callable]: + import itertools + + return list(itertools.chain.from_iterable(template(tp) for tp in tps)) + + def test_container_from_templated(): def make_float() -> float: return 1.5 @@ -242,7 +140,7 @@ def combine(x: Str[int], y: Str[float]) -> str: return f"{x};{y}" container = sl.make_container( - [make_int, make_float, combine] + templated(int) + templated(float) + [make_int, make_float, combine] + from_templates(subworkflow, (int, float)) ) assert container.get(Str[int]) == '3' assert container.get(Str[float]) == '1.5' From f651c235fa17e0bce1e88feaffb18f5c94b47aee Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Fri, 7 Jul 2023 10:23:36 +0200 Subject: [PATCH 05/43] Cleanup --- tests/container_test.py | 14 ++------------ 1 file changed, 2 insertions(+), 12 deletions(-) diff --git a/tests/container_test.py b/tests/container_test.py index b1ce5fac..c3d83331 100644 --- a/tests/container_test.py +++ b/tests/container_test.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: BSD-3-Clause # Copyright (c) 2023 Scipp contributors (https://github.com/scipp) -from typing import Callable, List, NewType, Tuple +from typing import Callable, List, NewType import dask import pytest @@ -116,22 +116,12 @@ def use_strings(s1: Str[Run1], s2: Str[Run2]) -> Result: def subworkflow(tp: type) -> List[Callable]: - # Could also provide option for a list of types - # How can a user extend this? Just make there own wrapper function? def f(x: tp) -> Str[tp]: return Str[tp](x) return [f] -def from_templates( - template: Callable[[type], List[Callable]], tps: Tuple[type, ...] 
-) -> List[Callable]: - import itertools - - return list(itertools.chain.from_iterable(template(tp) for tp in tps)) - - def test_container_from_templated(): def make_float() -> float: return 1.5 @@ -140,7 +130,7 @@ def combine(x: Str[int], y: Str[float]) -> str: return f"{x};{y}" container = sl.make_container( - [make_int, make_float, combine] + from_templates(subworkflow, (int, float)) + [make_int, make_float, combine] + subworkflow(int) + subworkflow(float) ) assert container.get(Str[int]) == '3' assert container.get(Str[float]) == '1.5' From a232787fe779d36617227ab49d8d33df7eeee780 Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Mon, 10 Jul 2023 10:55:04 +0200 Subject: [PATCH 06/43] Properly working parametrized_domain_type --- src/sciline/__init__.py | 2 +- src/sciline/domain.py | 45 ++++++++------ tests/complex_workflow_test.py | 105 +++++++++++++++++++++++++++++++++ tests/container_test.py | 2 +- 4 files changed, 135 insertions(+), 19 deletions(-) create mode 100644 tests/complex_workflow_test.py diff --git a/src/sciline/__init__.py b/src/sciline/__init__.py index a8a2dbb5..37e6474c 100644 --- a/src/sciline/__init__.py +++ b/src/sciline/__init__.py @@ -10,4 +10,4 @@ __version__ = "0.0.0" from .container import Container, UnsatisfiedRequirement, make_container -from .domain import domain_type, parametrized_domain_type +from .domain import parametrized_domain_type diff --git a/src/sciline/domain.py b/src/sciline/domain.py index fd74485a..8c53b66c 100644 --- a/src/sciline/domain.py +++ b/src/sciline/domain.py @@ -1,22 +1,33 @@ # SPDX-License-Identifier: BSD-3-Clause # Copyright (c) 2023 Scipp contributors (https://github.com/scipp) -from __future__ import annotations - -import typing -from typing import Callable, Generic, List, Type, TypeVar, Union - - -def domain_type(name: str, base: type) -> type: - class tp(base): - pass - - return tp +from typing import Dict, NewType def parametrized_domain_type(name: str, base: type) -> type: - T = TypeVar('T') - - 
class tp(base, Generic[T]): - pass - - return tp + """ + Return a type-factory for parametrized domain types. + + The types return by the factory are created using typing.NewType. The returned + factory is used similarly to a Generic, but note that the factory itself should + not be used for annotations. + + Parameters + ---------- + name: + The name of the type. This is used as a prefix for the names of the types + returned by the factory. + base: + The base type of the types returned by the factory. + """ + + class Factory: + _subtypes: Dict[str, type] = {} + + def __class_getitem__(cls, tp: type) -> type: + key = f'{name}_{tp.__name__}' + if (t := cls._subtypes.get(key)) is None: + t = NewType(key, base) + cls._subtypes[key] = t + return t + + return Factory diff --git a/tests/complex_workflow_test.py b/tests/complex_workflow_test.py new file mode 100644 index 00000000..adec824a --- /dev/null +++ b/tests/complex_workflow_test.py @@ -0,0 +1,105 @@ +# SPDX-License-Identifier: BSD-3-Clause +# Copyright (c) 2023 Scipp contributors (https://github.com/scipp) +from dataclasses import dataclass +from typing import Callable, List, NewType + +import dask +import numpy as np + +import sciline as sl + +# We use dask with a single thread, to ensure that call counting below is correct. 
+dask.config.set(scheduler='synchronous') + + +@dataclass +class RawData: + data: np.ndarray + monitor1: float + monitor2: float + + +SampleRun = NewType('SampleRun', int) +BackgroundRun = NewType('BackgroundRun', int) +DetectorMask = NewType('DetectorMask', np.ndarray) +DirectBeam = NewType('DirectBeam', np.ndarray) +SolidAngle = NewType('SolidAngle', np.ndarray) +Raw = sl.parametrized_domain_type('Raw', RawData) +Masked = sl.parametrized_domain_type('Masked', np.ndarray) +IncidentMonitor = sl.parametrized_domain_type('IncidentMonitor', float) +TransmissionMonitor = sl.parametrized_domain_type('TransmissionMonitor', float) +TransmissionFraction = sl.parametrized_domain_type('TransmissionFraction', float) +IofQ = sl.parametrized_domain_type('IofQ', np.ndarray) +BackgroundSubtractedIofQ = NewType('BackgroundSubtractedIofQ', np.ndarray) + + +def reduction_factory(tp: type) -> List[Callable]: + def incident_monitor(x: Raw[tp]) -> IncidentMonitor[tp]: + return IncidentMonitor[tp](x.monitor1) + + def transmission_monitor(x: Raw[tp]) -> TransmissionMonitor[tp]: + return TransmissionMonitor[tp](x.monitor2) + + def mask_detector(x: Raw[tp], mask: DetectorMask) -> Masked[tp]: + return Masked[tp](x.data * mask) + + def transmission( + incident: IncidentMonitor[tp], transmission: TransmissionMonitor[tp] + ) -> TransmissionFraction[tp]: + return TransmissionFraction[tp](incident / transmission) + + def iofq( + x: Masked[tp], + solid_angle: SolidAngle, + direct_beam: DirectBeam, + transmission: TransmissionFraction[tp], + ) -> IofQ[tp]: + return IofQ[tp](x / (solid_angle * direct_beam * transmission)) + + return [incident_monitor, transmission_monitor, mask_detector, transmission, iofq] + + +def raw_sample() -> Raw[SampleRun]: + return Raw[SampleRun](RawData(data=np.ones(4), monitor1=1.0, monitor2=2.0)) + + +def raw_background() -> Raw[BackgroundRun]: + return Raw[BackgroundRun]( + RawData(data=np.ones(4) * 1.5, monitor1=1.0, monitor2=4.0) + ) + + +def detector_mask() -> 
DetectorMask: + return DetectorMask(np.array([1, 1, 0, 1])) + + +def solid_angle() -> SolidAngle: + return SolidAngle(np.array([1.0, 0.5, 0.25, 0.125])) + + +def direct_beam() -> DirectBeam: + return DirectBeam(np.array(1 / 1.5)) + + +def subtract_background( + sample: IofQ[SampleRun], background: IofQ[BackgroundRun] +) -> BackgroundSubtractedIofQ: + return BackgroundSubtractedIofQ(sample - background) + + +def test_reduction_workflow(): + container = sl.make_container( + [ + raw_sample, + raw_background, + detector_mask, + solid_angle, + direct_beam, + subtract_background, + ] + + reduction_factory(SampleRun) + + reduction_factory(BackgroundRun) + ) + assert np.array_equal(container.get(IofQ[SampleRun]), [3, 6, 0, 24]) + assert np.array_equal(container.get(IofQ[BackgroundRun]), [9, 18, 0, 72]) + assert np.array_equal(container.get(BackgroundSubtractedIofQ), [-6, -12, 0, -48]) diff --git a/tests/container_test.py b/tests/container_test.py index c3d83331..0ef96f84 100644 --- a/tests/container_test.py +++ b/tests/container_test.py @@ -117,7 +117,7 @@ def use_strings(s1: Str[Run1], s2: Str[Run2]) -> Result: def subworkflow(tp: type) -> List[Callable]: def f(x: tp) -> Str[tp]: - return Str[tp](x) + return Str[tp](f'{x}') return [f] From 7912185a26bc422a305e442ef1b0e0d4019d1663 Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Mon, 10 Jul 2023 11:10:49 +0200 Subject: [PATCH 07/43] Add NumPy to test requirements --- requirements/test.in | 1 + requirements/test.txt | 4 +++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/requirements/test.in b/requirements/test.in index 1cf404d7..9064cbf2 100644 --- a/requirements/test.in +++ b/requirements/test.in @@ -1,2 +1,3 @@ -r base.in +numpy pytest diff --git a/requirements/test.txt b/requirements/test.txt index 62c51aba..320e3ac2 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -1,4 +1,4 @@ -# SHA1:a035a60fcbac4cd7bf595dbd81ee7994505d4a95 +# SHA1:3720d9b18830e4fcedb827ec36f6808035c1ea2c # # This 
file is autogenerated by pip-compile-multi # To update, run: @@ -10,6 +10,8 @@ exceptiongroup==1.1.1 # via pytest iniconfig==2.0.0 # via pytest +numpy==1.24.4 + # via -r test.in pluggy==1.0.0 # via pytest pytest==7.3.1 From fbf47ccbcaff9f6deef7033f3f94a9f615d8b014 Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Tue, 11 Jul 2023 10:02:56 +0200 Subject: [PATCH 08/43] Exploring different solutions to mypy problems --- src/sciline/__init__.py | 2 +- src/sciline/domain.py | 202 ++++++++++++++++++++++++++++++++++++++-- tests/domain_test.py | 17 ++++ tests/mypy_test.py | 191 +++++++++++++++++++++++++++++++++++++ 4 files changed, 401 insertions(+), 11 deletions(-) create mode 100644 tests/domain_test.py create mode 100644 tests/mypy_test.py diff --git a/src/sciline/__init__.py b/src/sciline/__init__.py index 37e6474c..366cf2d5 100644 --- a/src/sciline/__init__.py +++ b/src/sciline/__init__.py @@ -10,4 +10,4 @@ __version__ = "0.0.0" from .container import Container, UnsatisfiedRequirement, make_container -from .domain import parametrized_domain_type +from .domain import parametrized_domain_type, DomainTypeFactory diff --git a/src/sciline/domain.py b/src/sciline/domain.py index 8c53b66c..89dec76b 100644 --- a/src/sciline/domain.py +++ b/src/sciline/domain.py @@ -1,9 +1,59 @@ # SPDX-License-Identifier: BSD-3-Clause # Copyright (c) 2023 Scipp contributors (https://github.com/scipp) -from typing import Dict, NewType +from typing import Dict, NewType, Any, Generic, TypeVar +import typing +import functools + + +class DomainTypeFactory: + def __init__(self, name: str, base: type) -> None: + self._name: str = name + self._base: type = base + self._subtypes: Dict[str, NewType] = {} + + def __getitem__(self, tp: type) -> type: + return self(tp) + + def __call__(self, *args: Any, **kwargs: Any) -> NewType: + key = f'{self._name}' + for arg in args: + key += f'_{arg}' + for k, v in kwargs.items(): + key += f'_{k}_{v}' + if (t := self._subtypes.get(key)) is not None: + return t + t 
= NewType(key, self._base) + self._subtypes[key] = t + return t + + +T = TypeVar("T") + + +class SingleParameterStr(str, Generic[T]): + def __new__(cls, x: str): + assert isinstance(x, str) + return x + + +class SingleParameterFloat(float, Generic[T]): + def __new__(cls, x: float): + assert isinstance(x, float) + return x def parametrized_domain_type(name: str, base: type) -> type: + if base is str: + + class DomainType(SingleParameterStr[T]): + ... + + if base is float: + + class DomainType(SingleParameterFloat[T]): + ... + + return DomainType """ Return a type-factory for parametrized domain types. @@ -19,15 +69,147 @@ def parametrized_domain_type(name: str, base: type) -> type: base: The base type of the types returned by the factory. """ + return DomainTypeFactory(name, base) - class Factory: - _subtypes: Dict[str, type] = {} + # class Factory: + # _subtypes: Dict[str, type] = {} - def __class_getitem__(cls, tp: type) -> type: - key = f'{name}_{tp.__name__}' - if (t := cls._subtypes.get(key)) is None: - t = NewType(key, base) - cls._subtypes[key] = t - return t + # def __class_getitem__(cls, tp: type) -> type: + # key = f'{name}_{tp.__name__}' + # if (t := cls._subtypes.get(key)) is None: + # t = NewType(key, base) + # cls._subtypes[key] = t + # return t + + # return Factory + + +_cleanups = [] + + +def _tp_cache(func=None, /, *, typed=False): + """Internal wrapper caching __getitem__ of generic types with a fallback to + original function for non-hashable arguments. + """ + + def decorator(func): + cached = functools.lru_cache(typed=typed)(func) + _cleanups.append(cached.cache_clear) + + @functools.wraps(func) + def inner(*args, **kwds): + try: + return cached(*args, **kwds) + except TypeError: + pass # All real errors (not unhashable args) are raised below. 
+ return func(*args, **kwds) + + return inner + + if func is not None: + return decorator(func) + + return decorator + + +class _Immutable: + """Mixin to indicate that object should not be copied.""" + + __slots__ = () + + def __copy__(self): + return self + + def __deepcopy__(self, memo): + return self + + +class NewGenericType(_Immutable): + def __init__(self, name, tp, *, _tvars=()): + self.__qualname__ = name + if '.' in name: + name = name.rpartition('.')[-1] + self.__name__ = name + self.__supertype__ = tp + self.__parameters__ = _tvars + def_mod = typing._caller() + if def_mod != 'typing': + self.__module__ = def_mod + + @_tp_cache + def __class_getitem__(cls, params): + # copied from Generic.__class_getitem__ + if not isinstance(params, tuple): + params = (params,) + if not params: + raise TypeError( + f"Parameter list to {cls.__qualname__}[...] cannot be empty" + ) + params = tuple(typing._type_convert(p) for p in params) + if not all(isinstance(p, (typing.TypeVar, typing.ParamSpec)) for p in params): + raise TypeError( + f"Parameters to {cls.__name__}[...] must all be type variables " + f"or parameter specification variables." + ) + if len(set(params)) != len(params): + raise TypeError(f"Parameters to {cls.__name__}[...] 
must all be unique") + return functools.partial(cls, _tvars=params) + + @_tp_cache + def __getitem__(self, params): + # copied from typing.Generic.__class_getitem__ + if not isinstance(params, tuple): + params = (params,) + params = tuple(typing._type_convert(p) for p in params) + if any(isinstance(t, typing.ParamSpec) for t in self.__parameters__): + params = typing._prepare_paramspec_params(self, params) + else: + typing._check_generic(self, params, len(self.__parameters__)) + return typing._GenericAlias( + self, + params, + _typevar_types=(typing.TypeVar, typing.ParamSpec), + _paramspec_tvars=True, + ) + + def __repr__(self): + return f'{self.__module__}.{self.__qualname__}' + + def __call__(self, x): + return x + + def __reduce__(self): + return self.__qualname__ + + def __or__(self, other): + return typing.Union[self, other] + + def __ror__(self, other): + return typing.Union[other, self] + + +class MyNewType: + def __init__(self, name, tp): + self.__qualname__ = name + if '.' in name: + name = name.rpartition('.')[-1] + self.__name__ = name + self.__supertype__ = tp + def_mod = typing._caller() + if def_mod != 'typing': + self.__module__ = def_mod + + def __repr__(self): + return f'{self.__module__}.{self.__qualname__}' + + def __call__(self, x): + return x + + def __reduce__(self): + return self.__qualname__ + + def __or__(self, other): + return typing.Union[self, other] - return Factory + def __ror__(self, other): + return typing.Union[other, self] diff --git a/tests/domain_test.py b/tests/domain_test.py new file mode 100644 index 00000000..45937cd0 --- /dev/null +++ b/tests/domain_test.py @@ -0,0 +1,17 @@ +# SPDX-License-Identifier: BSD-3-Clause +# Copyright (c) 2023 Scipp contributors (https://github.com/scipp) +import sciline as sl +import typing + + +def test_domain_type_factory(): + Str = sl.DomainTypeFactory('Str', str) + tp = Str('a', b=1) + assert tp is tp + assert tp == tp + + +def test_NewGenericType(): + T = typing.TypeVar('T') + Str = 
sl.domain.NewGenericType[T]('Str', str) + assert Str[int] == Str[int] diff --git a/tests/mypy_test.py b/tests/mypy_test.py new file mode 100644 index 00000000..ba4d5886 --- /dev/null +++ b/tests/mypy_test.py @@ -0,0 +1,191 @@ +# SPDX-License-Identifier: BSD-3-Clause +# Copyright (c) 2023 Scipp contributors (https://github.com/scipp) +from __future__ import annotations +import typing +import sciline as sl +import numpy as np + + +def factory() -> typing.Callable: + T = typing.TypeVar('T') + + def func(x: T) -> typing.List[T]: + return [x, x] + + return func + + +def test_factory(): + f = factory() + typing.get_type_hints(f)['return'] + assert f(1) == [1, 1] + + +def test_providers(): + providers = {} + f = factory() + for tp in (int, float): + generic = typing.get_type_hints(f)['return'] + # replace the typevar with the concrete type + concrete = generic[tp] + providers[concrete] = f + + def func(arg: typing.List[int]) -> int: + return sum(arg) + + assert func(providers.get(typing.get_type_hints(func)['arg'])(1)) == 2 + + +T = typing.TypeVar('T') + + +# Str = sl.domain.MyNewType('Str', str) +# Str = typing.NewType('Str', str) + + +class Str(typing.Generic[T]): + def __init__(self, a): + self.__a = a + + def __getattr__(self, attr): + return getattr(self.__a, attr) + + def __setattr__(self, attr, val): + if attr == '_Str__a': + object.__setattr__(self, attr, val) + + return setattr(self.__a, attr, val) + + +def str_factory() -> typing.Callable: + T = typing.TypeVar('T') + + def func(x: T) -> Str[T]: + return Str(f'{x}') + + return func + + +def test_Str(): + f = str_factory() + assert f(1) == '1' + + +def make_domain_type( + base, types: typing.Tuple[type, ...] 
+) -> typing.Dict[type, typing.NewType]: + return {tp: typing.NewType(f'IofQ_{tp.__name__}', base) for tp in types} + + +def test_make_domain_type(): + Raw = make_domain_type(list, (int, float)) + IofQ = make_domain_type(str, (int, float)) + + T = typing.TypeVar('T') + + def func(x: Raw[T]) -> IofQ[T]: + return f'{x}' + + assert func(Raw[int]([1, 2])) == '[1, 2]' + assert func(1.0) == '1.0' + + +def test_wrapping(): + T = typing.TypeVar('T') + + class Raw(typing.Generic[T]): + def __init__(self, value: list): + self.value: list = value + + class IofQ(typing.Generic[T]): + def __init__(self, value: str): + self.value: str = value + + def factory(tp: type) -> typing.Callable: + class DomainType(typing.Generic[T]): + def __init__(self, value: tp): + self.value: str = value + + def func(x: Raw[T]) -> IofQ[T]: + return IofQ[T](f'{x.value}') + + assert func(Raw[int]([1])).value == '[1]' + assert func(Raw[float]([1.0])).value == '[1.0]' + + +from typing import ( + get_args, + get_origin, + get_type_hints, + Dict, + Any, + Callable, + Optional, + Generic, + TypeVar, +) + + +class SingleParameterGeneric(np.ndarray, Generic[T]): + def __new__(cls, x: np.ndarray): + assert isinstance(x, np.ndarray) + return x + + +class DomainType(SingleParameterGeneric[T]): + ... + + +class AnotherDomainType(SingleParameterGeneric[T]): + ... 
+ + +DataType = TypeVar("DataType") + + +def foo(data: DomainType[DataType]) -> AnotherDomainType[DataType]: + return AnotherDomainType(data + 1) + + +def test_foo() -> None: + assert np.array_equal(foo(DomainType(np.array([1, 2, 3]))), [1, 2, 3]) + a = np.array([1, 2, 3]) + assert DomainType(a) is a + + Array = typing.NewType("Array", np.ndarray) + assert Array(a) is a + + +def func(x: AnotherDomainType[int]) -> int: + return x[1] + + +def make_int() -> DomainType[int]: + return DomainType(np.array([1, 2, 3])) + + +def test_injection() -> None: + providers: Dict[type, Callable[..., Any]] = { + int: func, + DomainType[int]: make_int, + AnotherDomainType: foo, + } + + Return = typing.TypeVar("Return") + + def call(func: Callable[..., Return], bound: Optional[Any] = None) -> Return: + tps = get_type_hints(func) + del tps['return'] + args: Dict[str, Any] = {} + for name, tp in tps.items(): + if (provider := providers.get(tp)) is not None: + args[name] = call(provider, bound) + elif (origin := get_origin(tp)) is not None: + if (provider := providers.get(origin)) is not None: + args[name] = call(provider, *get_args(tp)) + else: + provider = providers[origin[bound]] + args[name] = call(provider, bound) + return func(**args) + + assert call(func) == 3 From 7acfffb1ed6335fa6cfd8cb3d35093c34c4d6b3b Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Tue, 11 Jul 2023 13:02:19 +0200 Subject: [PATCH 09/43] Begin refactor --- src/sciline/__init__.py | 4 +- src/sciline/container.py | 64 +++++++++++++++++++- src/sciline/domain.py | 14 +++-- tests/complex_workflow_test.py | 104 +++++++++++++++++++++------------ tests/mypy_test.py | 25 ++++---- 5 files changed, 154 insertions(+), 57 deletions(-) diff --git a/src/sciline/__init__.py b/src/sciline/__init__.py index 366cf2d5..b1de1824 100644 --- a/src/sciline/__init__.py +++ b/src/sciline/__init__.py @@ -9,5 +9,5 @@ except importlib.metadata.PackageNotFoundError: __version__ = "0.0.0" -from .container import Container, 
UnsatisfiedRequirement, make_container -from .domain import parametrized_domain_type, DomainTypeFactory +from .container import Container, Container2, UnsatisfiedRequirement, make_container +from .domain import DomainTypeFactory, Scope, parametrized_domain_type diff --git a/src/sciline/container.py b/src/sciline/container.py index 4253be73..e0df8e94 100644 --- a/src/sciline/container.py +++ b/src/sciline/container.py @@ -4,7 +4,19 @@ import typing from functools import wraps -from typing import Callable, List, Type, TypeVar, Union +from typing import ( + Any, + Callable, + Dict, + List, + Optional, + Type, + TypeVar, + Union, + get_args, + get_origin, + get_type_hints, +) import injector from dask.delayed import Delayed @@ -16,6 +28,56 @@ class UnsatisfiedRequirement(Exception): pass +class Container2: + def __init__(self): + self._providers: Dict[type, Callable[..., Any]] = {} + + def insert(self, provider: Callable[..., Any]): + key = get_type_hints(provider)['return'] + if (origin := get_origin(key)) is not None: + args = get_args(key) + if len(args) != 1: + raise ValueError(f'Cannot handle {key} with more than 1 argument') + key = origin if isinstance(args[0], TypeVar) else key + if key in self._providers: + raise ValueError(f'Provider for {key} already exists') + self._providers[key] = provider + + Return = typing.TypeVar("Return") + + def call(self, func: Callable[..., Return], bound: Optional[Any] = None) -> Return: + print('call', func, bound) + tps = get_type_hints(func) + del tps['return'] + args: Dict[str, Any] = {} + for name, tp in tps.items(): + args[name] = self._get(tp, bound=bound) + return func(**args) + + def _get(self, tp, bound: Optional[type] = None): + print('_get', tp, bound) + if (provider := self._providers.get(tp)) is not None: + return self.call(provider, bound) + elif (origin := get_origin(tp)) is not None: + if (provider := self._providers.get(origin)) is not None: + return self.call(provider, get_args(tp)[0] if bound is None else 
bound) + else: + provider = self._providers[origin[bound]] + return self.call(provider, bound) + + def get(self, tp: Type[T], /) -> Union[T, Delayed]: + try: + # We are slightly abusing Python's type system here, by using the + # injector to get T, but actually it returns a Delayed that can + # compute T. self._injector does not know this due to how we setup the + # bindings. We'd like to use Delayed[T], but that is not supported yet: + # https://github.com/dask/dask/pull/9256 + task: Delayed = self._get(tp) # type: ignore + except injector.UnsatisfiedRequirement as e: + raise UnsatisfiedRequirement(e) from e + return task # if self._lazy else task.compute() + + class Container: def __init__(self, inj: injector.Injector, /, *, lazy: bool) -> None: self._injector = inj diff --git a/src/sciline/domain.py b/src/sciline/domain.py index 89dec76b..c4b2c81f 100644 --- a/src/sciline/domain.py +++ b/src/sciline/domain.py @@ -1,8 +1,15 @@ # SPDX-License-Identifier: BSD-3-Clause # Copyright (c) 2023 Scipp contributors (https://github.com/scipp) -from typing import Dict, NewType, Any, Generic, TypeVar -import typing import functools +import typing +from typing import Any, Dict, Generic, NewType, TypeVar + +T = TypeVar("T") + + +class Scope(Generic[T]): + def __new__(cls, x): + return x class DomainTypeFactory: @@ -27,9 +34,6 @@ def __call__(self, *args: Any, **kwargs: Any) -> NewType: return t -T = TypeVar("T") - - class SingleParameterStr(str, Generic[T]): def __new__(cls, x: str): assert isinstance(x, str) diff --git a/tests/complex_workflow_test.py b/tests/complex_workflow_test.py index adec824a..0a72d711 100644 --- a/tests/complex_workflow_test.py +++ b/tests/complex_workflow_test.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: BSD-3-Clause # Copyright (c) 2023 Scipp contributors (https://github.com/scipp) from dataclasses import dataclass -from typing import Callable, List, NewType +from typing import NewType, TypeVar import dask import numpy as np @@ -24,39 +24,65 @@ class 
RawData: DetectorMask = NewType('DetectorMask', np.ndarray) DirectBeam = NewType('DirectBeam', np.ndarray) SolidAngle = NewType('SolidAngle', np.ndarray) -Raw = sl.parametrized_domain_type('Raw', RawData) -Masked = sl.parametrized_domain_type('Masked', np.ndarray) -IncidentMonitor = sl.parametrized_domain_type('IncidentMonitor', float) -TransmissionMonitor = sl.parametrized_domain_type('TransmissionMonitor', float) -TransmissionFraction = sl.parametrized_domain_type('TransmissionFraction', float) -IofQ = sl.parametrized_domain_type('IofQ', np.ndarray) + +Run = TypeVar('Run') + + +class Raw(sl.Scope[Run], RawData): + ... + + +class Masked(sl.Scope[Run], np.ndarray): + ... + + +class IncidentMonitor(sl.Scope[Run], float): + ... + + +class TransmissionMonitor(sl.Scope[Run], float): + ... + + +class TransmissionFraction(sl.Scope[Run], float): + ... + + +class IofQ(sl.Scope[Run], np.ndarray): + ... + + BackgroundSubtractedIofQ = NewType('BackgroundSubtractedIofQ', np.ndarray) -def reduction_factory(tp: type) -> List[Callable]: - def incident_monitor(x: Raw[tp]) -> IncidentMonitor[tp]: - return IncidentMonitor[tp](x.monitor1) +def incident_monitor(x: Raw[Run]) -> IncidentMonitor[Run]: + return IncidentMonitor(x.monitor1) + + +def transmission_monitor(x: Raw[Run]) -> TransmissionMonitor[Run]: + return TransmissionMonitor(x.monitor2) - def transmission_monitor(x: Raw[tp]) -> TransmissionMonitor[tp]: - return TransmissionMonitor[tp](x.monitor2) - def mask_detector(x: Raw[tp], mask: DetectorMask) -> Masked[tp]: - return Masked[tp](x.data * mask) +def mask_detector(x: Raw[Run], mask: DetectorMask) -> Masked[Run]: + return Masked(x.data * mask) - def transmission( - incident: IncidentMonitor[tp], transmission: TransmissionMonitor[tp] - ) -> TransmissionFraction[tp]: - return TransmissionFraction[tp](incident / transmission) - def iofq( - x: Masked[tp], - solid_angle: SolidAngle, - direct_beam: DirectBeam, - transmission: TransmissionFraction[tp], - ) -> IofQ[tp]: - return 
IofQ[tp](x / (solid_angle * direct_beam * transmission)) +def transmission( + incident: IncidentMonitor[Run], transmission: TransmissionMonitor[Run] +) -> TransmissionFraction[Run]: + return TransmissionFraction(incident / transmission) - return [incident_monitor, transmission_monitor, mask_detector, transmission, iofq] + +def iofq( + x: Masked[Run], + solid_angle: SolidAngle, + direct_beam: DirectBeam, + transmission: TransmissionFraction[Run], +) -> IofQ[Run]: + return IofQ(x / (solid_angle * direct_beam * transmission)) + + +reduction = [incident_monitor, transmission_monitor, mask_detector, transmission, iofq] def raw_sample() -> Raw[SampleRun]: @@ -88,18 +114,20 @@ def subtract_background( def test_reduction_workflow(): - container = sl.make_container( - [ - raw_sample, - raw_background, - detector_mask, - solid_angle, - direct_beam, - subtract_background, - ] - + reduction_factory(SampleRun) - + reduction_factory(BackgroundRun) - ) + providers = [ + raw_sample, + raw_background, + detector_mask, + solid_angle, + direct_beam, + subtract_background, + ] + reduction + container = sl.Container2() + for p in providers: + container.insert(p) + + print(container._providers.keys()) + assert np.array_equal(container.get(IofQ[SampleRun]), [3, 6, 0, 24]) assert np.array_equal(container.get(IofQ[BackgroundRun]), [9, 18, 0, 72]) assert np.array_equal(container.get(BackgroundSubtractedIofQ), [-6, -12, 0, -48]) diff --git a/tests/mypy_test.py b/tests/mypy_test.py index ba4d5886..4cd25236 100644 --- a/tests/mypy_test.py +++ b/tests/mypy_test.py @@ -1,10 +1,13 @@ # SPDX-License-Identifier: BSD-3-Clause # Copyright (c) 2023 Scipp contributors (https://github.com/scipp) from __future__ import annotations + import typing -import sciline as sl + import numpy as np +import sciline as sl + def factory() -> typing.Callable: T = typing.TypeVar('T') @@ -114,29 +117,29 @@ def func(x: Raw[T]) -> IofQ[T]: from typing import ( - get_args, - get_origin, - get_type_hints, - Dict, Any, 
Callable, - Optional, + Dict, Generic, + Optional, TypeVar, + get_args, + get_origin, + get_type_hints, ) -class SingleParameterGeneric(np.ndarray, Generic[T]): +class SingleParameterGeneric(Generic[T]): def __new__(cls, x: np.ndarray): assert isinstance(x, np.ndarray) return x -class DomainType(SingleParameterGeneric[T]): +class DomainType(SingleParameterGeneric[T], np.ndarray): ... -class AnotherDomainType(SingleParameterGeneric[T]): +class AnotherDomainType(SingleParameterGeneric[T], np.ndarray): ... @@ -157,7 +160,7 @@ def test_foo() -> None: def func(x: AnotherDomainType[int]) -> int: - return x[1] + return np.sum(x) def make_int() -> DomainType[int]: @@ -188,4 +191,4 @@ def call(func: Callable[..., Return], bound: Optional[Any] = None) -> Return: args[name] = call(provider, bound) return func(**args) - assert call(func) == 3 + assert call(func) == 9 From b0c86a4b46facf853ea421a4a9c32130eddc576b Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Tue, 11 Jul 2023 13:08:44 +0200 Subject: [PATCH 10/43] Cleanup --- src/sciline/container.py | 46 +++++++++++++++++++++++----------------- 1 file changed, 26 insertions(+), 20 deletions(-) diff --git a/src/sciline/container.py b/src/sciline/container.py index e0df8e94..3990c10c 100644 --- a/src/sciline/container.py +++ b/src/sciline/container.py @@ -21,6 +21,23 @@ import injector from dask.delayed import Delayed + +def _delayed(func: Callable) -> Callable: + """ + Decorator to make a function return a delayed object. + + In contrast to dask.delayed, this uses functools.wraps, to preserve the + type hints, which is a prerequisite for injector to work. 
+ """ + import dask + + @wraps(func) + def wrapper(*args, **kwargs): + return dask.delayed(func)(*args, **kwargs) + + return wrapper + + T = TypeVar('T') @@ -31,6 +48,7 @@ class UnsatisfiedRequirement(Exception): class Container2: def __init__(self): self._providers: Dict[type, Callable[..., Any]] = {} + self._lazy: bool = False def insert(self, provider: Callable[..., Any]): key = get_type_hints(provider)['return'] @@ -41,7 +59,7 @@ def insert(self, provider: Callable[..., Any]): key = origin if isinstance(args[0], TypeVar) else key if key in self._providers: raise ValueError(f'Provider for {key} already exists') - self._providers[key] = provider + self._providers[key] = _delayed(provider) Return = typing.TypeVar("Return") @@ -55,15 +73,19 @@ def call(self, func: Callable[..., Return], bound: Optional[Any] = None) -> Retu return func(**args) def _get(self, tp, bound: Optional[type] = None): - print('_get', tp, bound) if (provider := self._providers.get(tp)) is not None: return self.call(provider, bound) elif (origin := get_origin(tp)) is not None: if (provider := self._providers.get(origin)) is not None: - return self.call(provider, get_args(tp)[0] if bound is None else bound) + # TODO We would really need to support multiple bound params properly + param = get_args(tp)[0] + return self.call( + provider, bound if isinstance(param, TypeVar) else param + ) else: provider = self._providers[origin[bound]] return self.call(provider, bound) + raise UnsatisfiedRequirement("No provider found for type", tp) def get(self, tp: Type[T], /) -> Union[T, Delayed]: try: @@ -75,7 +97,7 @@ def get(self, tp: Type[T], /) -> Union[T, Delayed]: task: Delayed = self._get(tp) # type: ignore except injector.UnsatisfiedRequirement as e: raise UnsatisfiedRequirement(e) from e - return task # if self._lazy else task.compute() + return task if self._lazy else task.compute() class Container: @@ -96,22 +118,6 @@ def get(self, tp: Type[T], /) -> Union[T, Delayed]: return task if self._lazy else 
task.compute() -def _delayed(func: Callable) -> Callable: - """ - Decorator to make a function return a delayed object. - - In contrast to dask.delayed, this uses functools.wraps, to preserve the - type hints, which is a prerequisite for injector to work. - """ - import dask - - @wraps(func) - def wrapper(*args, **kwargs): - return dask.delayed(func)(*args, **kwargs) - - return wrapper - - def _injectable(func: Callable) -> Callable: """ Wrap a regular function so it can be registered in an injector and have its From 7a7990e4cf4857b76a834adbb2b6b1b184f3a95c Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Tue, 11 Jul 2023 13:17:29 +0200 Subject: [PATCH 11/43] Remove old implementation --- src/sciline/__init__.py | 2 +- src/sciline/container.py | 112 +++++++++++---------------------- tests/complex_workflow_test.py | 24 ++++--- 3 files changed, 47 insertions(+), 91 deletions(-) diff --git a/src/sciline/__init__.py b/src/sciline/__init__.py index b1de1824..210c5b79 100644 --- a/src/sciline/__init__.py +++ b/src/sciline/__init__.py @@ -9,5 +9,5 @@ except importlib.metadata.PackageNotFoundError: __version__ = "0.0.0" -from .container import Container, Container2, UnsatisfiedRequirement, make_container +from .container import Container, UnsatisfiedRequirement from .domain import DomainTypeFactory, Scope, parametrized_domain_type diff --git a/src/sciline/container.py b/src/sciline/container.py index 3990c10c..1635e652 100644 --- a/src/sciline/container.py +++ b/src/sciline/container.py @@ -18,7 +18,6 @@ get_type_hints, ) -import injector from dask.delayed import Delayed @@ -45,10 +44,23 @@ class UnsatisfiedRequirement(Exception): pass -class Container2: - def __init__(self): +class Container: + def __init__(self, funcs: List[Callable], /, *, lazy: bool = False): + """ + Create a :py:class:`Container` from a list of functions. + + Parameters + ---------- + funcs: + List of functions to be injected. Must be annotated with type hints. 
+ lazy: + If True, the functions are wrapped in :py:func:`dask.delayed` before + being injected. This allows to build a dask graph from the container. + """ self._providers: Dict[type, Callable[..., Any]] = {} - self._lazy: bool = False + self._lazy: bool = lazy + for func in funcs: + self.insert(func) def insert(self, provider: Callable[..., Any]): key = get_type_hints(provider)['return'] @@ -73,6 +85,19 @@ def call(self, func: Callable[..., Return], bound: Optional[Any] = None) -> Retu return func(**args) def _get(self, tp, bound: Optional[type] = None): + # When building a workflow, there are two common problems: + # + # 1. Intermediate results are used more than once. + # 2. Intermediate results are large, so we generally do not want to keep them + # in memory longer than necessary. + # + # To address these problems, we can internally build a graph of tasks, instead + # of # directly creating dependencies between functions. Currently we use Dask + # for this. # The Container instance will automatically compute the task, + # unless it is marked as lazy. We therefore use singleton-scope (to ensure Dask + # will recognize the task as the same object) and also wrap the function in + # dask.delayed. + # TODO Add caching mechanism! if (provider := self._providers.get(tp)) is not None: return self.call(provider, bound) elif (origin := get_origin(tp)) is not None: @@ -88,77 +113,10 @@ def _get(self, tp, bound: Optional[type] = None): raise UnsatisfiedRequirement("No provider found for type", tp) def get(self, tp: Type[T], /) -> Union[T, Delayed]: - try: - # We are slightly abusing Python's type system here, by using the - # injector to get T, but actually it returns a Delayed that can - # compute T. self._injector does not know this due to how we setup the - # bindings. 
We'd like to use Delayed[T], but that is not supported yet: - # https://github.com/dask/dask/pull/9256 - task: Delayed = self._get(tp) # type: ignore - except injector.UnsatisfiedRequirement as e: - raise UnsatisfiedRequirement(e) from e - return task if self._lazy else task.compute() - - -class Container: - def __init__(self, inj: injector.Injector, /, *, lazy: bool) -> None: - self._injector = inj - self._lazy = lazy - - def get(self, tp: Type[T], /) -> Union[T, Delayed]: - try: - # We are slightly abusing Python's type system here, by using the - # injector to get T, but actually it returns a Delayed that can - # compute T. self._injector does not know this due to how we setup the - # bindings. We'd like to use Delayed[T], but that is not supported yet: - # https://github.com/dask/dask/pull/9256 - task: Delayed = self._injector.get(tp) # type: ignore - except injector.UnsatisfiedRequirement as e: - raise UnsatisfiedRequirement(e) from e + # We are slightly abusing Python's type system here, by using the + # injector to get T, but actually it returns a Delayed that can + # compute T. self._injector does not know this due to how we setup the + # bindings. We'd like to use Delayed[T], but that is not supported yet: + # https://github.com/dask/dask/pull/9256 + task: Delayed = self._get(tp) # type: ignore return task if self._lazy else task.compute() - - -def _injectable(func: Callable) -> Callable: - """ - Wrap a regular function so it can be registered in an injector and have its - parameters injected. - """ - # When building a workflow, there are two common problems: - # - # 1. Intermediate results are used more than once. - # 2. Intermediate results are large, so we generally do not want to keep them - # in memory longer than necessary. - # - # To address these problems, we can internally build a graph of tasks, instead of - # directly creating dependencies between functions. Currently we use Dask for this. 
- # The Container instance will automatically compute the task, unless it is marked - # as lazy. We therefore use singleton-scope (to ensure Dask will recognize the - # task as the same object) and also wrap the function in dask.delayed. - scope = injector.singleton - func = _delayed(func) - tps = typing.get_type_hints(func) - - def bind(binder: injector.Binder): - binder.bind(tps['return'], injector.inject(func), scope=scope) - - return bind - - -def make_container(funcs: List[Callable], /, *, lazy: bool = False) -> Container: - """ - Create a :py:class:`Container` from a list of functions. - - Parameters - ---------- - funcs: - List of functions to be injected. Must be annotated with type hints. - lazy: - If True, the functions are wrapped in :py:func:`dask.delayed` before - being injected. This allows to build a dask graph from the container. - """ - # Note that we disable auto_bind, to ensure we do not accidentally bind to - # some default values. Everything must be explicit. - return Container( - injector.Injector([_injectable(f) for f in funcs], auto_bind=False), - lazy=lazy, - ) diff --git a/tests/complex_workflow_test.py b/tests/complex_workflow_test.py index 0a72d711..768de501 100644 --- a/tests/complex_workflow_test.py +++ b/tests/complex_workflow_test.py @@ -114,19 +114,17 @@ def subtract_background( def test_reduction_workflow(): - providers = [ - raw_sample, - raw_background, - detector_mask, - solid_angle, - direct_beam, - subtract_background, - ] + reduction - container = sl.Container2() - for p in providers: - container.insert(p) - - print(container._providers.keys()) + container = sl.Container( + [ + raw_sample, + raw_background, + detector_mask, + solid_angle, + direct_beam, + subtract_background, + ] + + reduction + ) assert np.array_equal(container.get(IofQ[SampleRun]), [3, 6, 0, 24]) assert np.array_equal(container.get(IofQ[BackgroundRun]), [9, 18, 0, 72]) From c8585026d2b23015699b60636ba69251d81e752a Mon Sep 17 00:00:00 2001 From: Simon 
Heybrock Date: Tue, 11 Jul 2023 13:33:04 +0200 Subject: [PATCH 12/43] Refactor tests --- src/sciline/container.py | 23 +++++++++++++++---- tests/container_test.py | 49 +++++++++++++++++++--------------------- 2 files changed, 41 insertions(+), 31 deletions(-) diff --git a/src/sciline/container.py b/src/sciline/container.py index 1635e652..4e0096dd 100644 --- a/src/sciline/container.py +++ b/src/sciline/container.py @@ -59,6 +59,7 @@ def __init__(self, funcs: List[Callable], /, *, lazy: bool = False): """ self._providers: Dict[type, Callable[..., Any]] = {} self._lazy: bool = lazy + self._cache: Dict[type, Any] = {} for func in funcs: self.insert(func) @@ -85,6 +86,8 @@ def call(self, func: Callable[..., Return], bound: Optional[Any] = None) -> Retu return func(**args) def _get(self, tp, bound: Optional[type] = None): + if (cached := self._cache.get(tp)) is not None: + return cached # When building a workflow, there are two common problems: # # 1. Intermediate results are used more than once. @@ -97,19 +100,29 @@ def _get(self, tp, bound: Optional[type] = None): # unless it is marked as lazy. We therefore use singleton-scope (to ensure Dask # will recognize the task as the same object) and also wrap the function in # dask.delayed. - # TODO Add caching mechanism! 
if (provider := self._providers.get(tp)) is not None: - return self.call(provider, bound) - elif (origin := get_origin(tp)) is not None: + result = self.call(provider, bound) + self._cache[tp] = result + return result + elif (origin := get_origin(tp)) is None: + if (provider := self._providers.get(bound)) is not None: + result = self.call(provider) + self._cache[bound] = result + return result + else: if (provider := self._providers.get(origin)) is not None: # TODO We would really need to support multiple bound params properly param = get_args(tp)[0] - return self.call( + result = self.call( provider, bound if isinstance(param, TypeVar) else param ) + self._cache[tp] = result + return result else: provider = self._providers[origin[bound]] - return self.call(provider, bound) + result = self.call(provider, bound) + self._cache[origin[bound]] = result + return result raise UnsatisfiedRequirement("No provider found for type", tp) def get(self, tp: Type[T], /) -> Union[T, Delayed]: diff --git a/tests/container_test.py b/tests/container_test.py index 0ef96f84..9b8711e5 100644 --- a/tests/container_test.py +++ b/tests/container_test.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: BSD-3-Clause # Copyright (c) 2023 Scipp contributors (https://github.com/scipp) -from typing import Callable, List, NewType +from typing import NewType, TypeVar import dask import pytest @@ -24,13 +24,13 @@ def int_float_to_str(x: int, y: float) -> str: def test_make_container_sets_up_working_container(): - container = sl.make_container([int_to_float, make_int]) + container = sl.Container([int_to_float, make_int]) assert container.get(float) == 1.5 assert container.get(int) == 3 def test_make_container_does_not_autobind(): - container = sl.make_container([int_to_float]) + container = sl.Container([int_to_float]) with pytest.raises(sl.UnsatisfiedRequirement): container.get(float) @@ -43,15 +43,13 @@ def provide_int() -> int: ncall += 1 return 3 - container = sl.make_container( - [int_to_float, 
provide_int, int_float_to_str], lazy=False - ) + container = sl.Container([int_to_float, provide_int, int_float_to_str], lazy=False) assert container.get(str) == "3;1.5" assert ncall == 1 def test_make_container_lazy_returns_task_that_computes_result(): - container = sl.make_container([int_to_float, make_int], lazy=True) + container = sl.Container([int_to_float, make_int], lazy=True) task = container.get(float) assert hasattr(task, 'compute') assert task.compute() == 1.5 @@ -65,9 +63,7 @@ def provide_int() -> int: ncall += 1 return 3 - container = sl.make_container( - [int_to_float, provide_int, int_float_to_str], lazy=True - ) + container = sl.Container([int_to_float, provide_int, int_float_to_str], lazy=True) task1 = container.get(float) task2 = container.get(str) assert dask.compute(task1, task2) == (1.5, '3;1.5') @@ -82,14 +78,16 @@ def provide_int() -> int: ncall += 1 return 3 - Float = sl.parametrized_domain_type('Float', float) - Str = sl.parametrized_domain_type('Str', str) + Param = TypeVar('Param') + + class Float(sl.Scope[Param], float): + ... - def child(tp: type) -> List[Callable]: - def int_float_to_str(x: int, y: Float[tp]) -> Str[tp]: - return Str[tp](f"{x};{y}") + class Str(sl.Scope[Param], str): + ... 
- return [int_float_to_str] + def int_float_to_str(x: int, y: Float[Param]) -> Str[Param]: + return Str(f"{x};{y}") Run1 = NewType('Run1', int) Run2 = NewType('Run2', int) @@ -104,22 +102,23 @@ def float2() -> Float[Run2]: def use_strings(s1: Str[Run1], s2: Str[Run2]) -> Result: return Result(f"{s1};{s2}") - container = sl.make_container( - [provide_int, float1, float2, use_strings] + child(Run1) + child(Run2), + container = sl.Container( + [provide_int, float1, float2, use_strings, int_float_to_str], lazy=True, ) assert container.get(Result).compute() == "3;1.5;3;2.5" assert ncall == 1 -Str = sl.parametrized_domain_type('Str', str) +Param = TypeVar('Param') -def subworkflow(tp: type) -> List[Callable]: - def f(x: tp) -> Str[tp]: - return Str[tp](f'{x}') +class Str(sl.Scope[Param], str): + ... - return [f] + +def f(x: Param) -> Str[Param]: + return Str(f'{x}') def test_container_from_templated(): @@ -129,9 +128,7 @@ def make_float() -> float: def combine(x: Str[int], y: Str[float]) -> str: return f"{x};{y}" - container = sl.make_container( - [make_int, make_float, combine] + subworkflow(int) + subworkflow(float) - ) + container = sl.Container([make_int, make_float, combine, f]) assert container.get(Str[int]) == '3' assert container.get(Str[float]) == '1.5' assert container.get(str) == '3;1.5' From 6a8afd15bdac67b478bcf7ed15e5b5e9c699d404 Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Tue, 11 Jul 2023 13:37:41 +0200 Subject: [PATCH 13/43] Remove unused code --- src/sciline/__init__.py | 2 +- src/sciline/domain.py | 211 +--------------------------------------- 2 files changed, 2 insertions(+), 211 deletions(-) diff --git a/src/sciline/__init__.py b/src/sciline/__init__.py index 210c5b79..07ac722f 100644 --- a/src/sciline/__init__.py +++ b/src/sciline/__init__.py @@ -10,4 +10,4 @@ __version__ = "0.0.0" from .container import Container, UnsatisfiedRequirement -from .domain import DomainTypeFactory, Scope, parametrized_domain_type +from .domain import Scope diff 
--git a/src/sciline/domain.py b/src/sciline/domain.py index c4b2c81f..a79727eb 100644 --- a/src/sciline/domain.py +++ b/src/sciline/domain.py @@ -1,8 +1,6 @@ # SPDX-License-Identifier: BSD-3-Clause # Copyright (c) 2023 Scipp contributors (https://github.com/scipp) -import functools -import typing -from typing import Any, Dict, Generic, NewType, TypeVar +from typing import Generic, TypeVar T = TypeVar("T") @@ -10,210 +8,3 @@ class Scope(Generic[T]): def __new__(cls, x): return x - - -class DomainTypeFactory: - def __init__(self, name: str, base: type) -> None: - self._name: str = name - self._base: type = base - self._subtypes: Dict[str, NewType] = {} - - def __getitem__(self, tp: type) -> type: - return self(tp) - - def __call__(self, *args: Any, **kwargs: Any) -> NewType: - key = f'{self._name}' - for arg in args: - key += f'_{arg}' - for k, v in kwargs.items(): - key += f'_{k}_{v}' - if (t := self._subtypes.get(key)) is not None: - return t - t = NewType(key, self._base) - self._subtypes[key] = t - return t - - -class SingleParameterStr(str, Generic[T]): - def __new__(cls, x: str): - assert isinstance(x, str) - return x - - -class SingleParameterFloat(float, Generic[T]): - def __new__(cls, x: float): - assert isinstance(x, float) - return x - - -def parametrized_domain_type(name: str, base: type) -> type: - if base is str: - - class DomainType(SingleParameterStr[T]): - ... - - if base is float: - - class DomainType(SingleParameterFloat[T]): - ... - - return DomainType - """ - Return a type-factory for parametrized domain types. - - The types return by the factory are created using typing.NewType. The returned - factory is used similarly to a Generic, but note that the factory itself should - not be used for annotations. - - Parameters - ---------- - name: - The name of the type. This is used as a prefix for the names of the types - returned by the factory. - base: - The base type of the types returned by the factory. 
- """ - return DomainTypeFactory(name, base) - - # class Factory: - # _subtypes: Dict[str, type] = {} - - # def __class_getitem__(cls, tp: type) -> type: - # key = f'{name}_{tp.__name__}' - # if (t := cls._subtypes.get(key)) is None: - # t = NewType(key, base) - # cls._subtypes[key] = t - # return t - - # return Factory - - -_cleanups = [] - - -def _tp_cache(func=None, /, *, typed=False): - """Internal wrapper caching __getitem__ of generic types with a fallback to - original function for non-hashable arguments. - """ - - def decorator(func): - cached = functools.lru_cache(typed=typed)(func) - _cleanups.append(cached.cache_clear) - - @functools.wraps(func) - def inner(*args, **kwds): - try: - return cached(*args, **kwds) - except TypeError: - pass # All real errors (not unhashable args) are raised below. - return func(*args, **kwds) - - return inner - - if func is not None: - return decorator(func) - - return decorator - - -class _Immutable: - """Mixin to indicate that object should not be copied.""" - - __slots__ = () - - def __copy__(self): - return self - - def __deepcopy__(self, memo): - return self - - -class NewGenericType(_Immutable): - def __init__(self, name, tp, *, _tvars=()): - self.__qualname__ = name - if '.' in name: - name = name.rpartition('.')[-1] - self.__name__ = name - self.__supertype__ = tp - self.__parameters__ = _tvars - def_mod = typing._caller() - if def_mod != 'typing': - self.__module__ = def_mod - - @_tp_cache - def __class_getitem__(cls, params): - # copied from Generic.__class_getitem__ - if not isinstance(params, tuple): - params = (params,) - if not params: - raise TypeError( - f"Parameter list to {cls.__qualname__}[...] cannot be empty" - ) - params = tuple(typing._type_convert(p) for p in params) - if not all(isinstance(p, (typing.TypeVar, typing.ParamSpec)) for p in params): - raise TypeError( - f"Parameters to {cls.__name__}[...] must all be type variables " - f"or parameter specification variables." 
- ) - if len(set(params)) != len(params): - raise TypeError(f"Parameters to {cls.__name__}[...] must all be unique") - return functools.partial(cls, _tvars=params) - - @_tp_cache - def __getitem__(self, params): - # copied from typing.Generic.__class_getitem__ - if not isinstance(params, tuple): - params = (params,) - params = tuple(typing._type_convert(p) for p in params) - if any(isinstance(t, typing.ParamSpec) for t in self.__parameters__): - params = typing._prepare_paramspec_params(self, params) - else: - typing._check_generic(self, params, len(self.__parameters__)) - return typing._GenericAlias( - self, - params, - _typevar_types=(typing.TypeVar, typing.ParamSpec), - _paramspec_tvars=True, - ) - - def __repr__(self): - return f'{self.__module__}.{self.__qualname__}' - - def __call__(self, x): - return x - - def __reduce__(self): - return self.__qualname__ - - def __or__(self, other): - return typing.Union[self, other] - - def __ror__(self, other): - return typing.Union[other, self] - - -class MyNewType: - def __init__(self, name, tp): - self.__qualname__ = name - if '.' 
in name: - name = name.rpartition('.')[-1] - self.__name__ = name - self.__supertype__ = tp - def_mod = typing._caller() - if def_mod != 'typing': - self.__module__ = def_mod - - def __repr__(self): - return f'{self.__module__}.{self.__qualname__}' - - def __call__(self, x): - return x - - def __reduce__(self): - return self.__qualname__ - - def __or__(self, other): - return typing.Union[self, other] - - def __ror__(self, other): - return typing.Union[other, self] From 1ff5f26c8c39b9721a933bd9a87ce291907e8a1c Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Tue, 11 Jul 2023 13:57:29 +0200 Subject: [PATCH 14/43] Simplify --- src/sciline/container.py | 50 +++++++++++++++++----------------------- 1 file changed, 21 insertions(+), 29 deletions(-) diff --git a/src/sciline/container.py b/src/sciline/container.py index 4e0096dd..36317988 100644 --- a/src/sciline/container.py +++ b/src/sciline/container.py @@ -63,7 +63,7 @@ def __init__(self, funcs: List[Callable], /, *, lazy: bool = False): for func in funcs: self.insert(func) - def insert(self, provider: Callable[..., Any]): + def insert(self, provider: Callable[..., Any]) -> None: key = get_type_hints(provider)['return'] if (origin := get_origin(key)) is not None: args = get_args(key) @@ -77,17 +77,19 @@ def insert(self, provider: Callable[..., Any]): Return = typing.TypeVar("Return") def call(self, func: Callable[..., Return], bound: Optional[Any] = None) -> Return: - print('call', func, bound) tps = get_type_hints(func) del tps['return'] args: Dict[str, Any] = {} for name, tp in tps.items(): - args[name] = self._get(tp, bound=bound) + if isinstance(tp, TypeVar): + tp = tp if bound is None else bound + elif (origin := get_origin(tp)) is not None: + if isinstance(get_args(tp)[0], TypeVar): + tp = origin[bound] + args[name] = self._get(tp) return func(**args) - def _get(self, tp, bound: Optional[type] = None): - if (cached := self._cache.get(tp)) is not None: - return cached + def _get(self, tp: Type[T], /) -> 
Delayed: # When building a workflow, there are two common problems: # # 1. Intermediate results are used more than once. @@ -100,30 +102,20 @@ def _get(self, tp, bound: Optional[type] = None): # unless it is marked as lazy. We therefore use singleton-scope (to ensure Dask # will recognize the task as the same object) and also wrap the function in # dask.delayed. - if (provider := self._providers.get(tp)) is not None: - result = self.call(provider, bound) - self._cache[tp] = result - return result - elif (origin := get_origin(tp)) is None: - if (provider := self._providers.get(bound)) is not None: - result = self.call(provider) - self._cache[bound] = result - return result + if tp in self._providers: + key = tp + bound = None + elif (origin := get_origin(tp)) in self._providers: + key = origin + bound = get_args(tp)[0] else: - if (provider := self._providers.get(origin)) is not None: - # TODO We would really need to support multiple bound params properly - param = get_args(tp)[0] - result = self.call( - provider, bound if isinstance(param, TypeVar) else param - ) - self._cache[tp] = result - return result - else: - provider = self._providers[origin[bound]] - result = self.call(provider, bound) - self._cache[origin[bound]] = result - return result - raise UnsatisfiedRequirement("No provider found for type", tp) + raise UnsatisfiedRequirement("No provider found for type", tp) + if (cached := self._cache.get(key)) is not None: + return cached + provider = self._providers.get(key) + result = self.call(provider, bound) + self._cache[tp] = result + return result def get(self, tp: Type[T], /) -> Union[T, Delayed]: # We are slightly abusing Python's type system here, by using the From aed0367db9d88e2119fdba632f0b242342811d42 Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Tue, 11 Jul 2023 14:09:57 +0200 Subject: [PATCH 15/43] Some mypy fixes --- src/sciline/__init__.py | 2 ++ src/sciline/container.py | 2 +- src/sciline/domain.py | 2 +- tests/complex_workflow_test.py | 6 
++---- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/sciline/__init__.py b/src/sciline/__init__.py index 07ac722f..d94d8695 100644 --- a/src/sciline/__init__.py +++ b/src/sciline/__init__.py @@ -11,3 +11,5 @@ from .container import Container, UnsatisfiedRequirement from .domain import Scope + +__all__ = ["Container", "Scope", "UnsatisfiedRequirement"] diff --git a/src/sciline/container.py b/src/sciline/container.py index 36317988..acd10a6d 100644 --- a/src/sciline/container.py +++ b/src/sciline/container.py @@ -112,7 +112,7 @@ def _get(self, tp: Type[T], /) -> Delayed: raise UnsatisfiedRequirement("No provider found for type", tp) if (cached := self._cache.get(key)) is not None: return cached - provider = self._providers.get(key) + provider = self._providers[key] result = self.call(provider, bound) self._cache[tp] = result return result diff --git a/src/sciline/domain.py b/src/sciline/domain.py index a79727eb..fbd0f924 100644 --- a/src/sciline/domain.py +++ b/src/sciline/domain.py @@ -6,5 +6,5 @@ class Scope(Generic[T]): - def __new__(cls, x): + def __new__(cls, x): # type: ignore return x diff --git a/tests/complex_workflow_test.py b/tests/complex_workflow_test.py index 768de501..2fba102e 100644 --- a/tests/complex_workflow_test.py +++ b/tests/complex_workflow_test.py @@ -86,13 +86,11 @@ def iofq( def raw_sample() -> Raw[SampleRun]: - return Raw[SampleRun](RawData(data=np.ones(4), monitor1=1.0, monitor2=2.0)) + return Raw(RawData(data=np.ones(4), monitor1=1.0, monitor2=2.0)) def raw_background() -> Raw[BackgroundRun]: - return Raw[BackgroundRun]( - RawData(data=np.ones(4) * 1.5, monitor1=1.0, monitor2=4.0) - ) + return Raw(RawData(data=np.ones(4) * 1.5, monitor1=1.0, monitor2=4.0)) def detector_mask() -> DetectorMask: From d96d0222212af15afd6d464acfb3175884d8818a Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Tue, 11 Jul 2023 14:10:30 +0200 Subject: [PATCH 16/43] Remove unused tests --- tests/domain_test.py | 17 ---- tests/mypy_test.py | 
194 ------------------------------------------- 2 files changed, 211 deletions(-) delete mode 100644 tests/domain_test.py delete mode 100644 tests/mypy_test.py diff --git a/tests/domain_test.py b/tests/domain_test.py deleted file mode 100644 index 45937cd0..00000000 --- a/tests/domain_test.py +++ /dev/null @@ -1,17 +0,0 @@ -# SPDX-License-Identifier: BSD-3-Clause -# Copyright (c) 2023 Scipp contributors (https://github.com/scipp) -import sciline as sl -import typing - - -def test_domain_type_factory(): - Str = sl.DomainTypeFactory('Str', str) - tp = Str('a', b=1) - assert tp is tp - assert tp == tp - - -def test_NewGenericType(): - T = typing.TypeVar('T') - Str = sl.domain.NewGenericType[T]('Str', str) - assert Str[int] == Str[int] diff --git a/tests/mypy_test.py b/tests/mypy_test.py deleted file mode 100644 index 4cd25236..00000000 --- a/tests/mypy_test.py +++ /dev/null @@ -1,194 +0,0 @@ -# SPDX-License-Identifier: BSD-3-Clause -# Copyright (c) 2023 Scipp contributors (https://github.com/scipp) -from __future__ import annotations - -import typing - -import numpy as np - -import sciline as sl - - -def factory() -> typing.Callable: - T = typing.TypeVar('T') - - def func(x: T) -> typing.List[T]: - return [x, x] - - return func - - -def test_factory(): - f = factory() - typing.get_type_hints(f)['return'] - assert f(1) == [1, 1] - - -def test_providers(): - providers = {} - f = factory() - for tp in (int, float): - generic = typing.get_type_hints(f)['return'] - # replace the typevar with the concrete type - concrete = generic[tp] - providers[concrete] = f - - def func(arg: typing.List[int]) -> int: - return sum(arg) - - assert func(providers.get(typing.get_type_hints(func)['arg'])(1)) == 2 - - -T = typing.TypeVar('T') - - -# Str = sl.domain.MyNewType('Str', str) -# Str = typing.NewType('Str', str) - - -class Str(typing.Generic[T]): - def __init__(self, a): - self.__a = a - - def __getattr__(self, attr): - return getattr(self.__a, attr) - - def __setattr__(self, attr, 
val): - if attr == '_Str__a': - object.__setattr__(self, attr, val) - - return setattr(self.__a, attr, val) - - -def str_factory() -> typing.Callable: - T = typing.TypeVar('T') - - def func(x: T) -> Str[T]: - return Str(f'{x}') - - return func - - -def test_Str(): - f = str_factory() - assert f(1) == '1' - - -def make_domain_type( - base, types: typing.Tuple[type, ...] -) -> typing.Dict[type, typing.NewType]: - return {tp: typing.NewType(f'IofQ_{tp.__name__}', base) for tp in types} - - -def test_make_domain_type(): - Raw = make_domain_type(list, (int, float)) - IofQ = make_domain_type(str, (int, float)) - - T = typing.TypeVar('T') - - def func(x: Raw[T]) -> IofQ[T]: - return f'{x}' - - assert func(Raw[int]([1, 2])) == '[1, 2]' - assert func(1.0) == '1.0' - - -def test_wrapping(): - T = typing.TypeVar('T') - - class Raw(typing.Generic[T]): - def __init__(self, value: list): - self.value: list = value - - class IofQ(typing.Generic[T]): - def __init__(self, value: str): - self.value: str = value - - def factory(tp: type) -> typing.Callable: - class DomainType(typing.Generic[T]): - def __init__(self, value: tp): - self.value: str = value - - def func(x: Raw[T]) -> IofQ[T]: - return IofQ[T](f'{x.value}') - - assert func(Raw[int]([1])).value == '[1]' - assert func(Raw[float]([1.0])).value == '[1.0]' - - -from typing import ( - Any, - Callable, - Dict, - Generic, - Optional, - TypeVar, - get_args, - get_origin, - get_type_hints, -) - - -class SingleParameterGeneric(Generic[T]): - def __new__(cls, x: np.ndarray): - assert isinstance(x, np.ndarray) - return x - - -class DomainType(SingleParameterGeneric[T], np.ndarray): - ... - - -class AnotherDomainType(SingleParameterGeneric[T], np.ndarray): - ... 
- - -DataType = TypeVar("DataType") - - -def foo(data: DomainType[DataType]) -> AnotherDomainType[DataType]: - return AnotherDomainType(data + 1) - - -def test_foo() -> None: - assert np.array_equal(foo(DomainType(np.array([1, 2, 3]))), [1, 2, 3]) - a = np.array([1, 2, 3]) - assert DomainType(a) is a - - Array = typing.NewType("Array", np.ndarray) - assert Array(a) is a - - -def func(x: AnotherDomainType[int]) -> int: - return np.sum(x) - - -def make_int() -> DomainType[int]: - return DomainType(np.array([1, 2, 3])) - - -def test_injection() -> None: - providers: Dict[type, Callable[..., Any]] = { - int: func, - DomainType[int]: make_int, - AnotherDomainType: foo, - } - - Return = typing.TypeVar("Return") - - def call(func: Callable[..., Return], bound: Optional[Any] = None) -> Return: - tps = get_type_hints(func) - del tps['return'] - args: Dict[str, Any] = {} - for name, tp in tps.items(): - if (provider := providers.get(tp)) is not None: - args[name] = call(provider, bound) - elif (origin := get_origin(tp)) is not None: - if (provider := providers.get(origin)) is not None: - args[name] = call(provider, *get_args(tp)) - else: - provider = providers[origin[bound]] - args[name] = call(provider, bound) - return func(**args) - - assert call(func) == 9 From 4750732157f94273e82663e641b55c207966c255 Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Tue, 11 Jul 2023 14:15:55 +0200 Subject: [PATCH 17/43] Drop 'injector' dependency --- pyproject.toml | 1 - requirements/base.in | 1 - requirements/base.txt | 14 ++++------ requirements/ci.txt | 16 +++++------ requirements/docs.txt | 59 ++++++++++++++++++++++++++-------------- requirements/static.txt | 8 +++--- requirements/test.txt | 6 ++-- src/sciline/container.py | 7 ++--- 8 files changed, 61 insertions(+), 51 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index b92ea0b8..5487b956 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,7 +28,6 @@ classifiers = [ requires-python = ">=3.8" dependencies = [ 
"dask", - "injector", ] dynamic = ["version"] diff --git a/requirements/base.in b/requirements/base.in index 43292dbf..b2034ba3 100644 --- a/requirements/base.in +++ b/requirements/base.in @@ -1,2 +1 @@ dask -injector diff --git a/requirements/base.txt b/requirements/base.txt index 235c6275..7aeaacdd 100644 --- a/requirements/base.txt +++ b/requirements/base.txt @@ -1,22 +1,20 @@ -# SHA1:436d53b37d30cc5f2a62b6ed730bcf46f1762d7b +# SHA1:b153fb0209b1ab0c52d1b98a28b498e4dc303498 # # This file is autogenerated by pip-compile-multi # To update, run: # # pip-compile-multi # -click==8.1.3 +click==8.1.4 # via dask cloudpickle==2.2.1 # via dask dask==2023.5.0 # via -r base.in -fsspec==2023.5.0 +fsspec==2023.6.0 # via dask -importlib-metadata==6.6.0 +importlib-metadata==6.8.0 # via dask -injector==0.20.1 - # via -r base.in locket==1.0.0 # via partd packaging==23.1 @@ -29,7 +27,5 @@ toolz==0.12.0 # via # dask # partd -typing-extensions==4.6.3 - # via injector -zipp==3.15.0 +zipp==3.16.0 # via importlib-metadata diff --git a/requirements/ci.txt b/requirements/ci.txt index f17a6bc6..beb13154 100644 --- a/requirements/ci.txt +++ b/requirements/ci.txt @@ -11,19 +11,19 @@ certifi==2023.5.7 # via requests chardet==5.1.0 # via tox -charset-normalizer==3.1.0 +charset-normalizer==3.2.0 # via requests colorama==0.4.6 # via tox distlib==0.3.6 # via virtualenv -filelock==3.12.0 +filelock==3.12.2 # via # tox # virtualenv gitdb==4.0.10 # via gitpython -gitpython==3.1.31 +gitpython==3.1.32 # via -r ci.in idna==3.4 # via requests @@ -32,13 +32,13 @@ packaging==23.1 # -r ci.in # pyproject-api # tox -platformdirs==3.5.1 +platformdirs==3.8.1 # via # tox # virtualenv -pluggy==1.0.0 +pluggy==1.2.0 # via tox -pyproject-api==1.5.1 +pyproject-api==1.5.3 # via tox requests==2.31.0 # via -r ci.in @@ -48,9 +48,9 @@ tomli==2.0.1 # via # pyproject-api # tox -tox==4.6.0 +tox==4.6.4 # via -r ci.in urllib3==2.0.3 # via requests -virtualenv==20.23.0 +virtualenv==20.23.1 # via tox diff --git 
a/requirements/docs.txt b/requirements/docs.txt index 3b558460..0ad8c12d 100644 --- a/requirements/docs.txt +++ b/requirements/docs.txt @@ -1,4 +1,4 @@ -# SHA1:58bb49d5210e2885ee4c5c48167d217012b5b723 +# SHA1:c8d387b90560a2db42c26e33735a4f78ada867c3 # # This file is autogenerated by pip-compile-multi # To update, run: @@ -13,8 +13,10 @@ alabaster==0.7.13 asttokens==2.2.1 # via stack-data attrs==23.1.0 - # via jsonschema -autodoc-pydantic==1.8.0 + # via + # jsonschema + # referencing +autodoc-pydantic==1.9.0 # via -r docs.in babel==2.12.1 # via @@ -30,7 +32,7 @@ bleach==6.0.0 # via nbconvert certifi==2023.5.7 # via requests -charset-normalizer==3.1.0 +charset-normalizer==3.2.0 # via requests comm==0.1.3 # via ipykernel @@ -54,9 +56,11 @@ idna==3.4 # via requests imagesize==1.4.1 # via sphinx -importlib-resources==5.12.0 - # via jsonschema -ipykernel==6.23.1 +importlib-resources==6.0.0 + # via + # jsonschema + # jsonschema-specifications +ipykernel==6.24.0 # via -r docs.in ipython==8.12.2 # via @@ -70,13 +74,15 @@ jinja2==3.1.2 # nbconvert # nbsphinx # sphinx -jsonschema==4.17.3 +jsonschema==4.18.0 # via nbformat -jupyter-client==8.2.0 +jsonschema-specifications==2023.6.1 + # via jsonschema +jupyter-client==8.3.0 # via # ipykernel # nbclient -jupyter-core==5.3.0 +jupyter-core==5.3.1 # via # ipykernel # jupyter-client @@ -85,7 +91,7 @@ jupyter-core==5.3.0 # nbformat jupyterlab-pygments==0.2.2 # via nbconvert -markdown-it-py==2.2.0 +markdown-it-py==3.0.0 # via # mdit-py-plugins # myst-parser @@ -97,19 +103,19 @@ matplotlib-inline==0.1.6 # via # ipykernel # ipython -mdit-py-plugins==0.3.5 +mdit-py-plugins==0.4.0 # via myst-parser mdurl==0.1.2 # via markdown-it-py -mistune==2.0.5 +mistune==3.0.1 # via nbconvert -myst-parser==1.0.0 +myst-parser==2.0.0 # via -r docs.in nbclient==0.8.0 # via nbconvert -nbconvert==7.4.0 +nbconvert==7.6.0 # via nbsphinx -nbformat==5.9.0 +nbformat==5.9.1 # via # nbclient # nbconvert @@ -128,9 +134,9 @@ pickleshare==0.7.5 # via ipython 
pkgutil-resolve-name==1.3.10 # via jsonschema -platformdirs==3.5.1 +platformdirs==3.8.1 # via jupyter-core -prompt-toolkit==3.0.38 +prompt-toolkit==3.0.39 # via ipython psutil==5.9.5 # via ipykernel @@ -138,7 +144,7 @@ ptyprocess==0.7.0 # via pexpect pure-eval==0.2.2 # via stack-data -pydantic==1.10.9 +pydantic==1.10.11 # via autodoc-pydantic pydata-sphinx-theme==0.13.3 # via -r docs.in @@ -149,8 +155,6 @@ pygments==2.15.1 # nbconvert # pydata-sphinx-theme # sphinx -pyrsistent==0.19.3 - # via jsonschema python-dateutil==2.8.2 # via jupyter-client pytz==2023.3 @@ -159,8 +163,16 @@ pyzmq==25.1.0 # via # ipykernel # jupyter-client +referencing==0.29.1 + # via + # jsonschema + # jsonschema-specifications requests==2.31.0 # via sphinx +rpds-py==0.8.10 + # via + # jsonschema + # referencing six==1.16.0 # via # asttokens @@ -218,6 +230,11 @@ traitlets==5.9.0 # nbconvert # nbformat # nbsphinx +typing-extensions==4.7.1 + # via + # ipython + # pydantic + # pydata-sphinx-theme urllib3==2.0.3 # via requests wcwidth==0.2.6 diff --git a/requirements/static.txt b/requirements/static.txt index 1c04a6da..e587652e 100644 --- a/requirements/static.txt +++ b/requirements/static.txt @@ -9,19 +9,19 @@ cfgv==3.3.1 # via pre-commit distlib==0.3.6 # via virtualenv -filelock==3.12.0 +filelock==3.12.2 # via virtualenv identify==2.5.24 # via pre-commit nodeenv==1.8.0 # via pre-commit -platformdirs==3.5.1 +platformdirs==3.8.1 # via virtualenv -pre-commit==3.3.2 +pre-commit==3.3.3 # via -r static.in pyyaml==6.0 # via pre-commit -virtualenv==20.23.0 +virtualenv==20.23.1 # via pre-commit # The following packages are considered to be unsafe in a requirements file: diff --git a/requirements/test.txt b/requirements/test.txt index 320e3ac2..f589f618 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -6,15 +6,15 @@ # pip-compile-multi # -r base.txt -exceptiongroup==1.1.1 +exceptiongroup==1.1.2 # via pytest iniconfig==2.0.0 # via pytest numpy==1.24.4 # via -r test.in -pluggy==1.0.0 
+pluggy==1.2.0 # via pytest -pytest==7.3.1 +pytest==7.4.0 # via -r test.in tomli==2.0.1 # via pytest diff --git a/src/sciline/container.py b/src/sciline/container.py index acd10a6d..6053585f 100644 --- a/src/sciline/container.py +++ b/src/sciline/container.py @@ -26,7 +26,7 @@ def _delayed(func: Callable) -> Callable: Decorator to make a function return a delayed object. In contrast to dask.delayed, this uses functools.wraps, to preserve the - type hints, which is a prerequisite for injector to work. + type hints, which is a prerequisite for injecting args based on their type hints. """ import dask @@ -119,9 +119,8 @@ def _get(self, tp: Type[T], /) -> Delayed: def get(self, tp: Type[T], /) -> Union[T, Delayed]: # We are slightly abusing Python's type system here, by using the - # injector to get T, but actually it returns a Delayed that can - # compute T. self._injector does not know this due to how we setup the - # bindings. We'd like to use Delayed[T], but that is not supported yet: + # self._get to get T, but actually it returns a Delayed that can + # compute T. We'd like to use Delayed[T], but that is not supported yet: # https://github.com/dask/dask/pull/9256 task: Delayed = self._get(tp) # type: ignore return task if self._lazy else task.compute() From c1a9b0707c3854db6d07055911e3c6b69c1732b2 Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Tue, 11 Jul 2023 14:23:38 +0200 Subject: [PATCH 18/43] Cleanup --- src/sciline/container.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/sciline/container.py b/src/sciline/container.py index 6053585f..3b3a0ce7 100644 --- a/src/sciline/container.py +++ b/src/sciline/container.py @@ -57,17 +57,17 @@ def __init__(self, funcs: List[Callable], /, *, lazy: bool = False): If True, the functions are wrapped in :py:func:`dask.delayed` before being injected. This allows to build a dask graph from the container. 
""" - self._providers: Dict[type, Callable[..., Any]] = {} + self._providers: Dict[type, Callable] = {} self._lazy: bool = lazy self._cache: Dict[type, Any] = {} for func in funcs: self.insert(func) - def insert(self, provider: Callable[..., Any]) -> None: + def insert(self, provider: Callable) -> None: key = get_type_hints(provider)['return'] if (origin := get_origin(key)) is not None: args = get_args(key) - if len(args) != 1: + if len(args) != 1 and any(isinstance(arg, TypeVar) for arg in args): raise ValueError(f'Cannot handle {key} with more than 1 argument') key = origin if isinstance(args[0], TypeVar) else key if key in self._providers: From 35ce11be28f8dbc9c653ea7c5f442292d365df86 Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Tue, 11 Jul 2023 15:00:00 +0200 Subject: [PATCH 19/43] Remove 'lazy' argumentin favor of dedicated methods to fix mypy --- src/sciline/container.py | 22 +++++++++------------- tests/complex_workflow_test.py | 11 +++++++---- tests/container_test.py | 27 +++++++++++++-------------- 3 files changed, 29 insertions(+), 31 deletions(-) diff --git a/src/sciline/container.py b/src/sciline/container.py index 3b3a0ce7..6a3887e0 100644 --- a/src/sciline/container.py +++ b/src/sciline/container.py @@ -12,7 +12,6 @@ Optional, Type, TypeVar, - Union, get_args, get_origin, get_type_hints, @@ -45,7 +44,7 @@ class UnsatisfiedRequirement(Exception): class Container: - def __init__(self, funcs: List[Callable], /, *, lazy: bool = False): + def __init__(self, funcs: List[Callable], /): """ Create a :py:class:`Container` from a list of functions. @@ -53,12 +52,8 @@ def __init__(self, funcs: List[Callable], /, *, lazy: bool = False): ---------- funcs: List of functions to be injected. Must be annotated with type hints. - lazy: - If True, the functions are wrapped in :py:func:`dask.delayed` before - being injected. This allows to build a dask graph from the container. 
""" self._providers: Dict[type, Callable] = {} - self._lazy: bool = lazy self._cache: Dict[type, Any] = {} for func in funcs: self.insert(func) @@ -97,11 +92,9 @@ def _get(self, tp: Type[T], /) -> Delayed: # in memory longer than necessary. # # To address these problems, we can internally build a graph of tasks, instead - # of # directly creating dependencies between functions. Currently we use Dask - # for this. # The Container instance will automatically compute the task, - # unless it is marked as lazy. We therefore use singleton-scope (to ensure Dask - # will recognize the task as the same object) and also wrap the function in - # dask.delayed. + # of directly creating dependencies between functions. Currently we use Dask + # for this. We cache call results to ensure Dask will recognize the task + # as the same object) and also wrap the function in dask.delayed. if tp in self._providers: key = tp bound = None @@ -117,10 +110,13 @@ def _get(self, tp: Type[T], /) -> Delayed: self._cache[tp] = result return result - def get(self, tp: Type[T], /) -> Union[T, Delayed]: + def get(self, tp: Type[T], /) -> Delayed: # We are slightly abusing Python's type system here, by using the # self._get to get T, but actually it returns a Delayed that can # compute T. 
We'd like to use Delayed[T], but that is not supported yet: # https://github.com/dask/dask/pull/9256 task: Delayed = self._get(tp) # type: ignore - return task if self._lazy else task.compute() + return task + + def compute(self, tp: Type[T], /) -> T: + return self.get(tp).compute() diff --git a/tests/complex_workflow_test.py b/tests/complex_workflow_test.py index 2fba102e..886e3926 100644 --- a/tests/complex_workflow_test.py +++ b/tests/complex_workflow_test.py @@ -112,6 +112,7 @@ def subtract_background( def test_reduction_workflow(): + # See https://github.com/python/mypy/issues/14661 container = sl.Container( [ raw_sample, @@ -121,9 +122,11 @@ def test_reduction_workflow(): direct_beam, subtract_background, ] - + reduction + + reduction # type: ignore ) - assert np.array_equal(container.get(IofQ[SampleRun]), [3, 6, 0, 24]) - assert np.array_equal(container.get(IofQ[BackgroundRun]), [9, 18, 0, 72]) - assert np.array_equal(container.get(BackgroundSubtractedIofQ), [-6, -12, 0, -48]) + assert np.array_equal(container.compute(IofQ[SampleRun]), [3, 6, 0, 24]) + assert np.array_equal(container.compute(IofQ[BackgroundRun]), [9, 18, 0, 72]) + assert np.array_equal( + container.compute(BackgroundSubtractedIofQ), [-6, -12, 0, -48] + ) diff --git a/tests/container_test.py b/tests/container_test.py index 9b8711e5..a42149c2 100644 --- a/tests/container_test.py +++ b/tests/container_test.py @@ -25,17 +25,17 @@ def int_float_to_str(x: int, y: float) -> str: def test_make_container_sets_up_working_container(): container = sl.Container([int_to_float, make_int]) - assert container.get(float) == 1.5 - assert container.get(int) == 3 + assert container.compute(float) == 1.5 + assert container.compute(int) == 3 def test_make_container_does_not_autobind(): container = sl.Container([int_to_float]) with pytest.raises(sl.UnsatisfiedRequirement): - container.get(float) + container.compute(float) -def test_intermediate_computed_once_when_not_lazy(): +def test_intermediate_computed_once(): 
ncall = 0 def provide_int() -> int: @@ -43,19 +43,19 @@ def provide_int() -> int: ncall += 1 return 3 - container = sl.Container([int_to_float, provide_int, int_float_to_str], lazy=False) - assert container.get(str) == "3;1.5" + container = sl.Container([int_to_float, provide_int, int_float_to_str]) + assert container.compute(str) == "3;1.5" assert ncall == 1 -def test_make_container_lazy_returns_task_that_computes_result(): - container = sl.Container([int_to_float, make_int], lazy=True) +def test_get_returns_task_that_computes_result(): + container = sl.Container([int_to_float, make_int]) task = container.get(float) assert hasattr(task, 'compute') assert task.compute() == 1.5 -def test_lazy_with_multiple_outputs_computes_intermediates_once(): +def test_multiple_get_calls_can_be_computed_without_repeated_calls(): ncall = 0 def provide_int() -> int: @@ -63,7 +63,7 @@ def provide_int() -> int: ncall += 1 return 3 - container = sl.Container([int_to_float, provide_int, int_float_to_str], lazy=True) + container = sl.Container([int_to_float, provide_int, int_float_to_str]) task1 = container.get(float) task2 = container.get(str) assert dask.compute(task1, task2) == (1.5, '3;1.5') @@ -104,7 +104,6 @@ def use_strings(s1: Str[Run1], s2: Str[Run2]) -> Result: container = sl.Container( [provide_int, float1, float2, use_strings, int_float_to_str], - lazy=True, ) assert container.get(Result).compute() == "3;1.5;3;2.5" assert ncall == 1 @@ -129,6 +128,6 @@ def combine(x: Str[int], y: Str[float]) -> str: return f"{x};{y}" container = sl.Container([make_int, make_float, combine, f]) - assert container.get(Str[int]) == '3' - assert container.get(Str[float]) == '1.5' - assert container.get(str) == '3;1.5' + assert container.compute(Str[int]) == '3' + assert container.compute(Str[float]) == '1.5' + assert container.compute(str) == '3;1.5' From ca9d62bc686dec0aa631ca9de4b7be5752008b04 Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Tue, 11 Jul 2023 15:22:40 +0200 Subject: [PATCH 
20/43] Fix more mypy --- src/sciline/container.py | 34 +++++++++++++++++----------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/src/sciline/container.py b/src/sciline/container.py index 6a3887e0..d94bcc9f 100644 --- a/src/sciline/container.py +++ b/src/sciline/container.py @@ -2,7 +2,6 @@ # Copyright (c) 2023 Scipp contributors (https://github.com/scipp) from __future__ import annotations -import typing from functools import wraps from typing import ( Any, @@ -17,21 +16,21 @@ get_type_hints, ) -from dask.delayed import Delayed +from dask.delayed import Delayed, delayed -def _delayed(func: Callable) -> Callable: +def _delayed(func: Callable[..., Any]) -> Callable[..., Delayed]: """ Decorator to make a function return a delayed object. In contrast to dask.delayed, this uses functools.wraps, to preserve the type hints, which is a prerequisite for injecting args based on their type hints. """ - import dask @wraps(func) - def wrapper(*args, **kwargs): - return dask.delayed(func)(*args, **kwargs) + def wrapper(*args: Any, **kwargs: Any) -> Delayed: + task: Delayed = delayed(func)(*args, **kwargs) + return task return wrapper @@ -44,7 +43,7 @@ class UnsatisfiedRequirement(Exception): class Container: - def __init__(self, funcs: List[Callable], /): + def __init__(self, funcs: List[Callable[..., Any]], /): """ Create a :py:class:`Container` from a list of functions. @@ -53,12 +52,12 @@ def __init__(self, funcs: List[Callable], /): funcs: List of functions to be injected. Must be annotated with type hints. 
""" - self._providers: Dict[type, Callable] = {} - self._cache: Dict[type, Any] = {} + self._providers: Dict[type, Callable[..., Any]] = {} + self._cache: Dict[type, Delayed] = {} for func in funcs: self.insert(func) - def insert(self, provider: Callable) -> None: + def insert(self, provider: Callable[..., Any]) -> None: key = get_type_hints(provider)['return'] if (origin := get_origin(key)) is not None: args = get_args(key) @@ -69,9 +68,9 @@ def insert(self, provider: Callable) -> None: raise ValueError(f'Provider for {key} already exists') self._providers[key] = _delayed(provider) - Return = typing.TypeVar("Return") - - def call(self, func: Callable[..., Return], bound: Optional[Any] = None) -> Return: + def _call( + self, func: Callable[..., Delayed], bound: Optional[Any] = None + ) -> Delayed: tps = get_type_hints(func) del tps['return'] args: Dict[str, Any] = {} @@ -106,7 +105,7 @@ def _get(self, tp: Type[T], /) -> Delayed: if (cached := self._cache.get(key)) is not None: return cached provider = self._providers[key] - result = self.call(provider, bound) + result = self._call(provider, bound) self._cache[tp] = result return result @@ -115,8 +114,9 @@ def get(self, tp: Type[T], /) -> Delayed: # self._get to get T, but actually it returns a Delayed that can # compute T. 
We'd like to use Delayed[T], but that is not supported yet: # https://github.com/dask/dask/pull/9256 - task: Delayed = self._get(tp) # type: ignore - return task + return self._get(tp) def compute(self, tp: Type[T], /) -> T: - return self.get(tp).compute() + task = self.get(tp) + result: T = task.compute() # type: ignore + return result From 2cd0c62adb983135eed9fd24306e6a35f9af4552 Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Wed, 12 Jul 2023 08:01:49 +0200 Subject: [PATCH 21/43] Begin refactor to track individual TypeVar bindings --- src/sciline/container.py | 31 +++++++++++++++++----- tests/container_test.py | 56 +++++++++++++++++++++++++++++++++++++++- 2 files changed, 79 insertions(+), 8 deletions(-) diff --git a/src/sciline/container.py b/src/sciline/container.py index d94bcc9f..50fd8521 100644 --- a/src/sciline/container.py +++ b/src/sciline/container.py @@ -8,7 +8,6 @@ Callable, Dict, List, - Optional, Type, TypeVar, get_args, @@ -69,17 +68,25 @@ def insert(self, provider: Callable[..., Any]) -> None: self._providers[key] = _delayed(provider) def _call( - self, func: Callable[..., Delayed], bound: Optional[Any] = None + self, func: Callable[..., Delayed], bound: Dict[TypeVar, Any] | None = None ) -> Delayed: + bound = bound or {} tps = get_type_hints(func) del tps['return'] args: Dict[str, Any] = {} for name, tp in tps.items(): if isinstance(tp, TypeVar): - tp = tp if bound is None else bound + tp = bound[tp] elif (origin := get_origin(tp)) is not None: - if isinstance(get_args(tp)[0], TypeVar): - tp = origin[bound] + if any(isinstance(arg, TypeVar) for arg in get_args(tp)): + # replace all TypeVar with bound types + tp = origin[ + tuple( + bound[arg] if isinstance(arg, TypeVar) else arg + for arg in get_args(tp) + ) + ] + # tp = origin[bound] args[name] = self._get(tp) return func(**args) @@ -96,15 +103,25 @@ def _get(self, tp: Type[T], /) -> Delayed: # as the same object) and also wrap the function in dask.delayed. 
if tp in self._providers: key = tp - bound = None + direct = True + # bound = None elif (origin := get_origin(tp)) in self._providers: key = origin - bound = get_args(tp)[0] + direct = False + # bound = get_args(tp)[0] else: raise UnsatisfiedRequirement("No provider found for type", tp) + # TODO Is using `key` correct here? Maybe need to also use `bound`? if (cached := self._cache.get(key)) is not None: return cached provider = self._providers[key] + bound: Dict[TypeVar, Any] = {} + if not direct: + hints = get_type_hints(provider)['return'] + for requested, provided in zip(get_args(tp), get_args(hints)): + if isinstance(provided, TypeVar): + bound[provided] = requested + result = self._call(provider, bound) self._cache[tp] = result return result diff --git a/tests/container_test.py b/tests/container_test.py index a42149c2..ba6a83e7 100644 --- a/tests/container_test.py +++ b/tests/container_test.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: BSD-3-Clause # Copyright (c) 2023 Scipp contributors (https://github.com/scipp) -from typing import NewType, TypeVar +from typing import Generic, NewType, TypeVar, get_origin import dask import pytest @@ -131,3 +131,57 @@ def combine(x: Str[int], y: Str[float]) -> str: assert container.compute(Str[int]) == '3' assert container.compute(Str[float]) == '1.5' assert container.compute(str) == '3;1.5' + + +T1 = TypeVar('T1') +T2 = TypeVar('T2') + + +class SingleArg(Generic[T1]): + ... + + +class MultiArg(Generic[T1, T2]): + ... 
+ + +def test_understanding_of_Generic(): + assert get_origin(MultiArg) is None + with pytest.raises(TypeError): + MultiArg[int] # to few parameters + assert get_origin(MultiArg[int, T2]) is MultiArg + assert get_origin(MultiArg[T1, T2]) is MultiArg + + +def test_understanding_of_TypeVar(): + assert T1 != T2 + assert T1 == T1 + assert T1 is T1 + assert TypeVar('T3') != TypeVar('T3') + + +def test_TypeVars_params_are_not_associated_unless_they_match(): + T1 = TypeVar('T1') + T2 = TypeVar('T2') + + class A(Generic[T1]): + ... + + class B(Generic[T2]): + ... + + def source() -> A[int]: + return A[int]() + + def not_matching(x: A[T1]) -> B[T2]: + return B[T2]() + + def matching(x: A[T1]) -> B[T1]: + return B[T1]() + + container = sl.Container([source, not_matching]) + with pytest.raises(KeyError): + container.compute(B[int]) + + container = sl.Container([source, matching]) + container.compute(B[int]) From 0e14c130f313d583f65558f586034a6a98b59797 Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Wed, 12 Jul 2023 08:26:31 +0200 Subject: [PATCH 22/43] Cleanup and more testing --- src/sciline/container.py | 32 ++++++++-------------- tests/container_test.py | 59 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 70 insertions(+), 21 deletions(-) diff --git a/src/sciline/container.py b/src/sciline/container.py index 50fd8521..74111e53 100644 --- a/src/sciline/container.py +++ b/src/sciline/container.py @@ -60,17 +60,12 @@ def insert(self, provider: Callable[..., Any]) -> None: key = get_type_hints(provider)['return'] if (origin := get_origin(key)) is not None: args = get_args(key) - if len(args) != 1 and any(isinstance(arg, TypeVar) for arg in args): - raise ValueError(f'Cannot handle {key} with more than 1 argument') - key = origin if isinstance(args[0], TypeVar) else key + key = origin if any(isinstance(arg, TypeVar) for arg in args) else key if key in self._providers: raise ValueError(f'Provider for {key} already exists') self._providers[key] = _delayed(provider) 
- def _call( - self, func: Callable[..., Delayed], bound: Dict[TypeVar, Any] | None = None - ) -> Delayed: - bound = bound or {} + def _call(self, func: Callable[..., Delayed], bound: Dict[TypeVar, Any]) -> Delayed: tps = get_type_hints(func) del tps['return'] args: Dict[str, Any] = {} @@ -78,15 +73,12 @@ def _call( if isinstance(tp, TypeVar): tp = bound[tp] elif (origin := get_origin(tp)) is not None: - if any(isinstance(arg, TypeVar) for arg in get_args(tp)): - # replace all TypeVar with bound types - tp = origin[ - tuple( - bound[arg] if isinstance(arg, TypeVar) else arg - for arg in get_args(tp) - ) - ] - # tp = origin[bound] + tp = origin[ + tuple( + bound[arg] if isinstance(arg, TypeVar) else arg + for arg in get_args(tp) + ) + ] args[name] = self._get(tp) return func(**args) @@ -101,19 +93,17 @@ def _get(self, tp: Type[T], /) -> Delayed: # of directly creating dependencies between functions. Currently we use Dask # for this. We cache call results to ensure Dask will recognize the task # as the same object) and also wrap the function in dask.delayed. + if (cached := self._cache.get(tp)) is not None: + return cached if tp in self._providers: key = tp direct = True - # bound = None elif (origin := get_origin(tp)) in self._providers: key = origin direct = False - # bound = get_args(tp)[0] else: raise UnsatisfiedRequirement("No provider found for type", tp) - # TODO Is using `key` correct here? Maybe need to also use `bound`? 
- if (cached := self._cache.get(key)) is not None: - return cached + provider = self._providers[key] bound: Dict[TypeVar, Any] = {} if not direct: diff --git a/tests/container_test.py b/tests/container_test.py index ba6a83e7..0f540a5d 100644 --- a/tests/container_test.py +++ b/tests/container_test.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: BSD-3-Clause # Copyright (c) 2023 Scipp contributors (https://github.com/scipp) +from dataclasses import dataclass from typing import Generic, NewType, TypeVar, get_origin import dask @@ -185,3 +186,61 @@ def matching(x: A[T1]) -> B[T1]: container = sl.Container([source, matching]) container.compute(B[int]) + + +def test_multi_Generic_with_fully_bound_arguments(): + T1 = TypeVar('T1') + T2 = TypeVar('T2') + + @dataclass + class A(Generic[T1, T2]): + first: T1 + second: T2 + + def source() -> A[int, float]: + return A[int, float](1, 2.0) + + container = sl.Container([source]) + assert container.compute(A[int, float]) == A[int, float](1, 2.0) + + +def test_multi_Generic_with_partially_bound_arguments(): + T1 = TypeVar('T1') + T2 = TypeVar('T2') + + @dataclass + class A(Generic[T1, T2]): + first: T1 + second: T2 + + def source() -> float: + return 2.0 + + def partially_bound(x: T1) -> A[int, T1]: + return A[int, T1](1, x) + + container = sl.Container([source, partially_bound]) + assert container.compute(A[int, float]) == A[int, float](1, 2.0) + + +def test_multi_Generic_with_multiple_unbound(): + T1 = TypeVar('T1') + T2 = TypeVar('T2') + + @dataclass + class A(Generic[T1, T2]): + first: T1 + second: T2 + + def int_source() -> int: + return 1 + + def float_source() -> float: + return 2.0 + + def unbound(x: T1, y: T2) -> A[T1, T2]: + return A[T1, T2](x, y) + + container = sl.Container([int_source, float_source, unbound]) + assert container.compute(A[int, float]) == A[int, float](1, 2.0) + assert container.compute(A[float, int]) == A[float, int](2.0, 1) From 07e08b0858b0f2f45f75f2b945af16b4ded4ffc6 Mon Sep 17 00:00:00 2001 From: 
Simon Heybrock Date: Wed, 12 Jul 2023 10:05:51 +0200 Subject: [PATCH 23/43] Support multiple partially bound instances of same generic --- src/sciline/container.py | 41 +++++++++++++++++++++++++++++++--------- tests/container_test.py | 41 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 73 insertions(+), 9 deletions(-) diff --git a/src/sciline/container.py b/src/sciline/container.py index 74111e53..ab1cf2e6 100644 --- a/src/sciline/container.py +++ b/src/sciline/container.py @@ -41,6 +41,17 @@ class UnsatisfiedRequirement(Exception): pass +def _is_compatible_type_tuple( + requested: tuple[type, ...], provided: tuple[type | TypeVar, ...] +) -> bool: + for req, prov in zip(requested, provided): + if isinstance(prov, TypeVar): + continue + if req != prov: + return False + return True + + class Container: def __init__(self, funcs: List[Callable[..., Any]], /): """ @@ -58,12 +69,20 @@ def __init__(self, funcs: List[Callable[..., Any]], /): def insert(self, provider: Callable[..., Any]) -> None: key = get_type_hints(provider)['return'] - if (origin := get_origin(key)) is not None: - args = get_args(key) - key = origin if any(isinstance(arg, TypeVar) for arg in args) else key - if key in self._providers: + parametrized = get_origin(key) is not None + # if (origin := get_origin(key)) is not None: + + # key = origin if any(isinstance(arg, TypeVar) for arg in args) else key + if (not parametrized) and (key in self._providers): raise ValueError(f'Provider for {key} already exists') - self._providers[key] = _delayed(provider) + if parametrized: + subproviders = self._providers.setdefault(get_origin(key), {}) + args = get_args(key) + if args in subproviders: + raise ValueError(f'Provider for {key} already exists') + subproviders[args] = _delayed(provider) + else: + self._providers[key] = _delayed(provider) def _call(self, func: Callable[..., Delayed], bound: Dict[TypeVar, Any]) -> Delayed: tps = get_type_hints(func) @@ -107,10 +126,14 @@ def _get(self, tp: Type[T], /) 
-> Delayed: provider = self._providers[key] bound: Dict[TypeVar, Any] = {} if not direct: - hints = get_type_hints(provider)['return'] - for requested, provided in zip(get_args(tp), get_args(hints)): - if isinstance(provided, TypeVar): - bound[provided] = requested + requested = get_args(tp) + for args, subprovider in provider.items(): + if _is_compatible_type_tuple(requested, args): + provider = subprovider + bound = dict(zip(args, requested)) + break + else: + raise UnsatisfiedRequirement("No provider found for type", tp) result = self._call(provider, bound) self._cache[tp] = result diff --git a/tests/container_test.py b/tests/container_test.py index 0f540a5d..31a79959 100644 --- a/tests/container_test.py +++ b/tests/container_test.py @@ -244,3 +244,44 @@ def unbound(x: T1, y: T2) -> A[T1, T2]: container = sl.Container([int_source, float_source, unbound]) assert container.compute(A[int, float]) == A[int, float](1, 2.0) assert container.compute(A[float, int]) == A[float, int](2.0, 1) + + +def test_distinct_fully_bound_instances_yield_distinct_results(): + T1 = TypeVar('T1') + + @dataclass + class A(Generic[T1]): + value: T1 + + def int_source() -> A[int]: + return A[int](1) + + def float_source() -> A[float]: + return A[float](2.0) + + container = sl.Container([int_source, float_source]) + assert container.compute(A[int]) == A[int](1) + assert container.compute(A[float]) == A[float](2.0) + + +def test_distinct_partially_bound_instances_yield_distinct_results(): + T1 = TypeVar('T1') + + @dataclass + class A(Generic[T1, T2]): + first: T1 + second: T2 + + def str_source() -> str: + return 'a' + + def int_source(x: T1) -> A[int, T1]: + return A[int, T1](1, x) + + def float_source(x: T1) -> A[float, T1]: + return A[float, T1](2.0, x) + + container = sl.Container([str_source, int_source, float_source]) + print(list(container._providers)) + assert container.compute(A[int, str]) == A[int, str](1, 'a') + assert container.compute(A[float, str]) == A[float, str](2.0, 'a') 
From 6115194c146f3599f08c9e5c329c8d41cc907edd Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Wed, 12 Jul 2023 10:18:07 +0200 Subject: [PATCH 24/43] Raise if multiple providers match --- src/sciline/container.py | 14 ++++++++++---- tests/container_test.py | 28 +++++++++++++++++++++++++++- 2 files changed, 37 insertions(+), 5 deletions(-) diff --git a/src/sciline/container.py b/src/sciline/container.py index ab1cf2e6..e11b1afb 100644 --- a/src/sciline/container.py +++ b/src/sciline/container.py @@ -127,13 +127,19 @@ def _get(self, tp: Type[T], /) -> Delayed: bound: Dict[TypeVar, Any] = {} if not direct: requested = get_args(tp) + matches = [] for args, subprovider in provider.items(): if _is_compatible_type_tuple(requested, args): - provider = subprovider - bound = dict(zip(args, requested)) - break - else: + matches.append((args, subprovider)) + if len(matches) == 0: raise UnsatisfiedRequirement("No provider found for type", tp) + elif len(matches) > 1: + raise KeyError( + "Multiple providers found for type", tp, "with arguments", requested + ) + args, subprovider = matches[0] + provider = subprovider + bound = dict(zip(args, requested)) result = self._call(provider, bound) self._cache[tp] = result diff --git a/tests/container_test.py b/tests/container_test.py index 31a79959..28901c78 100644 --- a/tests/container_test.py +++ b/tests/container_test.py @@ -282,6 +282,32 @@ def float_source(x: T1) -> A[float, T1]: return A[float, T1](2.0, x) container = sl.Container([str_source, int_source, float_source]) - print(list(container._providers)) assert container.compute(A[int, str]) == A[int, str](1, 'a') assert container.compute(A[float, str]) == A[float, str](2.0, 'a') + + +def test_multiple_matching_partial_providers_raises(): + T1 = TypeVar('T1') + + @dataclass + class A(Generic[T1, T2]): + first: T1 + second: T2 + + def int_source() -> int: + return 1 + + def float_source() -> float: + return 2.0 + + def provider1(x: T1) -> A[int, T1]: + return A[int, T1](1, x) + 
+ def provider2(x: T2) -> A[T2, float]: + return A[T2, float](x, 2.0) + + container = sl.Container([int_source, float_source, provider1, provider2]) + assert container.compute(A[int, int]) == A[int, int](1, 1) + assert container.compute(A[float, float]) == A[float, float](2.0, 2.0) + with pytest.raises(KeyError): + container.compute(A[int, float]) From 54f8bc4ae501e25e49efc92ef6690d5cfc844545 Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Wed, 12 Jul 2023 10:42:48 +0200 Subject: [PATCH 25/43] Refactor for passin mypy --- src/sciline/container.py | 56 ++++++++++++++++++++++++++-------------- 1 file changed, 36 insertions(+), 20 deletions(-) diff --git a/src/sciline/container.py b/src/sciline/container.py index e11b1afb..3ff9c86c 100644 --- a/src/sciline/container.py +++ b/src/sciline/container.py @@ -8,6 +8,7 @@ Callable, Dict, List, + Tuple, Type, TypeVar, get_args, @@ -44,6 +45,11 @@ class UnsatisfiedRequirement(Exception): def _is_compatible_type_tuple( requested: tuple[type, ...], provided: tuple[type | TypeVar, ...] ) -> bool: + """ + Check if a tuple of requested types is compatible with a tuple of provided types. + + Types in the tuples must either by equal, or the provided type must be a TypeVar. + """ for req, prov in zip(requested, provided): if isinstance(prov, TypeVar): continue @@ -52,8 +58,11 @@ def _is_compatible_type_tuple( return True +Provider = Callable[..., Any] + + class Container: - def __init__(self, funcs: List[Callable[..., Any]], /): + def __init__(self, funcs: List[Provider], /): """ Create a :py:class:`Container` from a list of functions. @@ -62,21 +71,21 @@ def __init__(self, funcs: List[Callable[..., Any]], /): funcs: List of functions to be injected. Must be annotated with type hints. 
""" - self._providers: Dict[type, Callable[..., Any]] = {} + self._providers: Dict[type, Provider] = {} + self._generic_providers: Dict[ + type, Dict[Tuple[type | TypeVar, ...], Provider] + ] = {} self._cache: Dict[type, Delayed] = {} for func in funcs: self.insert(func) - def insert(self, provider: Callable[..., Any]) -> None: + def insert(self, provider: Provider) -> None: key = get_type_hints(provider)['return'] - parametrized = get_origin(key) is not None - # if (origin := get_origin(key)) is not None: - - # key = origin if any(isinstance(arg, TypeVar) for arg in args) else key - if (not parametrized) and (key in self._providers): + generic = get_origin(key) is not None + if (not generic) and (key in self._providers): raise ValueError(f'Provider for {key} already exists') - if parametrized: - subproviders = self._providers.setdefault(get_origin(key), {}) + if generic: + subproviders = self._generic_providers.setdefault(get_origin(key), {}) args = get_args(key) if args in subproviders: raise ValueError(f'Provider for {key} already exists') @@ -116,21 +125,24 @@ def _get(self, tp: Type[T], /) -> Delayed: return cached if tp in self._providers: key = tp - direct = True - elif (origin := get_origin(tp)) in self._providers: + generic = False + elif (origin := get_origin(tp)) in self._generic_providers: key = origin - direct = False + generic = True else: raise UnsatisfiedRequirement("No provider found for type", tp) - provider = self._providers[key] bound: Dict[TypeVar, Any] = {} - if not direct: + if not generic: + provider = self._providers[key] + else: + providers = self._generic_providers[key] requested = get_args(tp) - matches = [] - for args, subprovider in provider.items(): - if _is_compatible_type_tuple(requested, args): - matches.append((args, subprovider)) + matches = [ + (args, subprovider) + for args, subprovider in providers.items() + if _is_compatible_type_tuple(requested, args) + ] if len(matches) == 0: raise UnsatisfiedRequirement("No provider found 
for type", tp) elif len(matches) > 1: @@ -139,7 +151,11 @@ def _get(self, tp: Type[T], /) -> Delayed: ) args, subprovider = matches[0] provider = subprovider - bound = dict(zip(args, requested)) + bound = { + arg: req + for arg, req in zip(args, requested) + if isinstance(arg, TypeVar) + } result = self._call(provider, bound) self._cache[tp] = result From 6e1549b946f8108fb879b73b734b8aabdb86fc16 Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Wed, 12 Jul 2023 10:45:12 +0200 Subject: [PATCH 26/43] Cleanup tests --- tests/container_test.py | 31 +++---------------------------- 1 file changed, 3 insertions(+), 28 deletions(-) diff --git a/tests/container_test.py b/tests/container_test.py index 28901c78..710ea36b 100644 --- a/tests/container_test.py +++ b/tests/container_test.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: BSD-3-Clause # Copyright (c) 2023 Scipp contributors (https://github.com/scipp) from dataclasses import dataclass -from typing import Generic, NewType, TypeVar, get_origin +from typing import Generic, NewType, TypeVar import dask import pytest @@ -134,33 +134,6 @@ def combine(x: Str[int], y: Str[float]) -> str: assert container.compute(str) == '3;1.5' -T1 = TypeVar('T1') -T2 = TypeVar('T2') - - -class SingleArg(Generic[T1]): - ... - - -class MultiArg(Generic[T1, T2]): - ... 
- - -def test_understanding_of_Generic(): - assert get_origin(MultiArg) is None - with pytest.raises(TypeError): - MultiArg[int] # to few parameters - assert get_origin(MultiArg[int, T2]) is MultiArg - assert get_origin(MultiArg[T1, T2]) is MultiArg - - -def test_understanding_of_TypeVar(): - assert T1 != T2 - assert T1 == T1 - assert T1 is T1 - assert TypeVar('T3') != TypeVar('T3') - - def test_TypeVars_params_are_not_associated_unless_they_match(): T1 = TypeVar('T1') T2 = TypeVar('T2') @@ -266,6 +239,7 @@ def float_source() -> A[float]: def test_distinct_partially_bound_instances_yield_distinct_results(): T1 = TypeVar('T1') + T2 = TypeVar('T2') @dataclass class A(Generic[T1, T2]): @@ -288,6 +262,7 @@ def float_source(x: T1) -> A[float, T1]: def test_multiple_matching_partial_providers_raises(): T1 = TypeVar('T1') + T2 = TypeVar('T2') @dataclass class A(Generic[T1, T2]): From 63266638484517afb151a7296b3aaa6b4a50d54a Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Wed, 12 Jul 2023 10:58:22 +0200 Subject: [PATCH 27/43] Cleanup --- src/sciline/container.py | 35 +++++++++++++---------------------- 1 file changed, 13 insertions(+), 22 deletions(-) diff --git a/src/sciline/container.py b/src/sciline/container.py index 3ff9c86c..d194d10f 100644 --- a/src/sciline/container.py +++ b/src/sciline/container.py @@ -81,16 +81,15 @@ def __init__(self, funcs: List[Provider], /): def insert(self, provider: Provider) -> None: key = get_type_hints(provider)['return'] - generic = get_origin(key) is not None - if (not generic) and (key in self._providers): - raise ValueError(f'Provider for {key} already exists') - if generic: + if get_origin(key) is not None: subproviders = self._generic_providers.setdefault(get_origin(key), {}) args = get_args(key) if args in subproviders: raise ValueError(f'Provider for {key} already exists') subproviders[args] = _delayed(provider) else: + if key in self._providers: + raise ValueError(f'Provider for {key} already exists') self._providers[key] = 
_delayed(provider) def _call(self, func: Callable[..., Delayed], bound: Dict[TypeVar, Any]) -> Delayed: @@ -123,32 +122,22 @@ def _get(self, tp: Type[T], /) -> Delayed: # as the same object) and also wrap the function in dask.delayed. if (cached := self._cache.get(tp)) is not None: return cached - if tp in self._providers: - key = tp - generic = False - elif (origin := get_origin(tp)) in self._generic_providers: - key = origin - generic = True - else: - raise UnsatisfiedRequirement("No provider found for type", tp) - bound: Dict[TypeVar, Any] = {} - if not generic: - provider = self._providers[key] - else: - providers = self._generic_providers[key] + if (provider := self._providers.get(tp)) is not None: + result = self._call(provider, {}) + elif (origin := get_origin(tp)) is not None and ( + subproviders := self._generic_providers[origin] + ) is not None: requested = get_args(tp) matches = [ (args, subprovider) - for args, subprovider in providers.items() + for args, subprovider in subproviders.items() if _is_compatible_type_tuple(requested, args) ] if len(matches) == 0: raise UnsatisfiedRequirement("No provider found for type", tp) elif len(matches) > 1: - raise KeyError( - "Multiple providers found for type", tp, "with arguments", requested - ) + raise KeyError("Multiple providers found for type", tp) args, subprovider = matches[0] provider = subprovider bound = { @@ -156,8 +145,10 @@ def _get(self, tp: Type[T], /) -> Delayed: for arg, req in zip(args, requested) if isinstance(arg, TypeVar) } + result = self._call(provider, bound) + else: + raise UnsatisfiedRequirement("No provider found for type", tp) - result = self._call(provider, bound) self._cache[tp] = result return result From 6457edfda56a0e7a8b7afee930a3d01395bc7783 Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Wed, 12 Jul 2023 11:04:10 +0200 Subject: [PATCH 28/43] Forbit providers providing 'None' --- src/sciline/container.py | 3 +++ tests/container_test.py | 11 +++++++++++ 2 files changed, 14 
insertions(+) diff --git a/src/sciline/container.py b/src/sciline/container.py index d194d10f..82651189 100644 --- a/src/sciline/container.py +++ b/src/sciline/container.py @@ -3,6 +3,7 @@ from __future__ import annotations from functools import wraps +from types import NoneType from typing import ( Any, Callable, @@ -81,6 +82,8 @@ def __init__(self, funcs: List[Provider], /): def insert(self, provider: Provider) -> None: key = get_type_hints(provider)['return'] + if key == NoneType: + raise ValueError(f'Provider {provider} does not have a return type') if get_origin(key) is not None: subproviders = self._generic_providers.setdefault(get_origin(key), {}) args = get_args(key) diff --git a/tests/container_test.py b/tests/container_test.py index 710ea36b..bc3b95bb 100644 --- a/tests/container_test.py +++ b/tests/container_test.py @@ -134,6 +134,17 @@ def combine(x: Str[int], y: Str[float]) -> str: assert container.compute(str) == '3;1.5' +def test_inserting_provider_returning_None_raises(): + def provide_none() -> None: + return None + + with pytest.raises(ValueError): + sl.Container([provide_none]) + container = sl.Container([]) + with pytest.raises(ValueError): + container.insert(provide_none) + + def test_TypeVars_params_are_not_associated_unless_they_match(): T1 = TypeVar('T1') T2 = TypeVar('T2') From a5cea72f919810c72850a2c56c1ebb9418aba05b Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Wed, 12 Jul 2023 11:15:50 +0200 Subject: [PATCH 29/43] Check for more errors --- src/sciline/container.py | 5 +++-- tests/container_test.py | 40 +++++++++++++++++++++++++++++++++++++++- 2 files changed, 42 insertions(+), 3 deletions(-) diff --git a/src/sciline/container.py b/src/sciline/container.py index 82651189..0275531e 100644 --- a/src/sciline/container.py +++ b/src/sciline/container.py @@ -81,9 +81,10 @@ def __init__(self, funcs: List[Provider], /): self.insert(func) def insert(self, provider: Provider) -> None: - key = get_type_hints(provider)['return'] + if (key := 
get_type_hints(provider).get('return')) is None: + raise ValueError(f'Provider {provider} lacks type-hint for return value') if key == NoneType: - raise ValueError(f'Provider {provider} does not have a return type') + raise ValueError(f'Provider {provider} returning `None` is not allowed') if get_origin(key) is not None: subproviders = self._generic_providers.setdefault(get_origin(key), {}) args = get_args(key) diff --git a/tests/container_test.py b/tests/container_test.py index bc3b95bb..779ea082 100644 --- a/tests/container_test.py +++ b/tests/container_test.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: BSD-3-Clause # Copyright (c) 2023 Scipp contributors (https://github.com/scipp) from dataclasses import dataclass -from typing import Generic, NewType, TypeVar +from typing import Generic, List, NewType, TypeVar import dask import pytest @@ -145,6 +145,44 @@ def provide_none() -> None: container.insert(provide_none) +def test_inserting_provider_with_no_return_type_raises(): + def provide_none(): + return None + + with pytest.raises(ValueError): + sl.Container([provide_none]) + container = sl.Container([]) + with pytest.raises(ValueError): + container.insert(provide_none) + + +def test_typevar_requirement_of_provider_can_be_bound(): + T = TypeVar('T') + + def provider_int() -> int: + return 3 + + def provider(x: T) -> List[T]: + return [x, x] + + container = sl.Container([provider_int, provider]) + assert container.compute(List[int]) == [3, 3] + + +def test_unsatisfiable_typevar_requirement_of_provider_raises(): + T = TypeVar('T') + + def provider_int() -> int: + return 3 + + def provider(x: T) -> List[T]: + return [x, x] + + container = sl.Container([provider_int, provider]) + with pytest.raises(sl.UnsatisfiedRequirement): + container.compute(List[float]) + + def test_TypeVars_params_are_not_associated_unless_they_match(): T1 = TypeVar('T1') T2 = TypeVar('T2') From 0049206518ad53156f1b4375028eccd88ac669c5 Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Wed, 12 
Jul 2023 11:28:23 +0200 Subject: [PATCH 30/43] Test resolving multiple typevars --- tests/container_test.py | 41 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/tests/container_test.py b/tests/container_test.py index 779ea082..88e28394 100644 --- a/tests/container_test.py +++ b/tests/container_test.py @@ -335,3 +335,44 @@ def provider2(x: T2) -> A[T2, float]: assert container.compute(A[float, float]) == A[float, float](2.0, 2.0) with pytest.raises(KeyError): container.compute(A[int, float]) + + +def test_TypeVar_params_track_to_multiple_sources(): + T1 = TypeVar('T1') + T2 = TypeVar('T2') + + @dataclass + class A(Generic[T1]): + value: T1 + + @dataclass + class B(Generic[T1]): + value: T1 + + @dataclass + class C(Generic[T1, T2]): + first: T1 + second: T2 + + def provide_int() -> int: + return 1 + + def provide_float() -> float: + return 2.0 + + def provide_A(x: T1) -> A[T1]: + return A[T1](x) + + # Note that it currently does not matter which TypeVar instance we use here: + # Container tracks uses of TypeVar within a single provider, but does not carry + # the information beyond the scope of a single call. 
+ def provide_B(x: T1) -> B[T1]: + return B[T1](x) + + def provide_C(x: A[T1], y: B[T2]) -> C[T1, T2]: + return C[T1, T2](x.value, y.value) + + container = sl.Container( + [provide_int, provide_float, provide_A, provide_B, provide_C] + ) + assert container.compute(C[int, float]) == C[int, float](1, 2.0) From 91f2c911507bc443d8ec0006940d02fd29d5af6c Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Wed, 12 Jul 2023 11:37:55 +0200 Subject: [PATCH 31/43] Use better exception type --- src/sciline/__init__.py | 4 ++-- src/sciline/container.py | 6 +++++- tests/container_test.py | 2 +- 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/src/sciline/__init__.py b/src/sciline/__init__.py index d94d8695..398c6314 100644 --- a/src/sciline/__init__.py +++ b/src/sciline/__init__.py @@ -9,7 +9,7 @@ except importlib.metadata.PackageNotFoundError: __version__ = "0.0.0" -from .container import Container, UnsatisfiedRequirement +from .container import AmbiguousProvider, Container, UnsatisfiedRequirement from .domain import Scope -__all__ = ["Container", "Scope", "UnsatisfiedRequirement"] +__all__ = ["AmbiguousProvider", "Container", "Scope", "UnsatisfiedRequirement"] diff --git a/src/sciline/container.py b/src/sciline/container.py index 0275531e..39212f1f 100644 --- a/src/sciline/container.py +++ b/src/sciline/container.py @@ -43,6 +43,10 @@ class UnsatisfiedRequirement(Exception): pass +class AmbiguousProvider(Exception): + pass + + def _is_compatible_type_tuple( requested: tuple[type, ...], provided: tuple[type | TypeVar, ...] 
) -> bool: @@ -141,7 +145,7 @@ def _get(self, tp: Type[T], /) -> Delayed: if len(matches) == 0: raise UnsatisfiedRequirement("No provider found for type", tp) elif len(matches) > 1: - raise KeyError("Multiple providers found for type", tp) + raise AmbiguousProvider("Multiple providers found for type", tp) args, subprovider = matches[0] provider = subprovider bound = { diff --git a/tests/container_test.py b/tests/container_test.py index 88e28394..b46b83f7 100644 --- a/tests/container_test.py +++ b/tests/container_test.py @@ -333,7 +333,7 @@ def provider2(x: T2) -> A[T2, float]: container = sl.Container([int_source, float_source, provider1, provider2]) assert container.compute(A[int, int]) == A[int, int](1, 1) assert container.compute(A[float, float]) == A[float, float](2.0, 2.0) - with pytest.raises(KeyError): + with pytest.raises(sl.AmbiguousProvider): container.compute(A[int, float]) From 0732354147305383e0ba75917c38416a2040b474 Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Wed, 12 Jul 2023 11:45:11 +0200 Subject: [PATCH 32/43] NoneType is not available in older Python --- src/sciline/container.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/sciline/container.py b/src/sciline/container.py index 39212f1f..1e491101 100644 --- a/src/sciline/container.py +++ b/src/sciline/container.py @@ -3,7 +3,6 @@ from __future__ import annotations from functools import wraps -from types import NoneType from typing import ( Any, Callable, @@ -87,7 +86,8 @@ def __init__(self, funcs: List[Provider], /): def insert(self, provider: Provider) -> None: if (key := get_type_hints(provider).get('return')) is None: raise ValueError(f'Provider {provider} lacks type-hint for return value') - if key == NoneType: + # isinstance does not work here and types.NoneType available only in 3.10+ + if key == type(None): # noqa: E721 raise ValueError(f'Provider {provider} returning `None` is not allowed') if get_origin(key) is not None: subproviders = 
self._generic_providers.setdefault(get_origin(key), {}) From dc54e9b376d6655b13a590b71c503b147c9f8f44 Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Wed, 12 Jul 2023 13:34:34 +0200 Subject: [PATCH 33/43] Remove complication of eager wrapping in delayed --- src/sciline/container.py | 27 +++++---------------------- 1 file changed, 5 insertions(+), 22 deletions(-) diff --git a/src/sciline/container.py b/src/sciline/container.py index 1e491101..a8a755c3 100644 --- a/src/sciline/container.py +++ b/src/sciline/container.py @@ -2,7 +2,6 @@ # Copyright (c) 2023 Scipp contributors (https://github.com/scipp) from __future__ import annotations -from functools import wraps from typing import ( Any, Callable, @@ -16,24 +15,8 @@ get_type_hints, ) -from dask.delayed import Delayed, delayed - - -def _delayed(func: Callable[..., Any]) -> Callable[..., Delayed]: - """ - Decorator to make a function return a delayed object. - - In contrast to dask.delayed, this uses functools.wraps, to preserve the - type hints, which is a prerequisite for injecting args based on their type hints. 
- """ - - @wraps(func) - def wrapper(*args: Any, **kwargs: Any) -> Delayed: - task: Delayed = delayed(func)(*args, **kwargs) - return task - - return wrapper - +import dask +from dask.delayed import Delayed T = TypeVar('T') @@ -94,11 +77,11 @@ def insert(self, provider: Provider) -> None: args = get_args(key) if args in subproviders: raise ValueError(f'Provider for {key} already exists') - subproviders[args] = _delayed(provider) + subproviders[args] = provider else: if key in self._providers: raise ValueError(f'Provider for {key} already exists') - self._providers[key] = _delayed(provider) + self._providers[key] = provider def _call(self, func: Callable[..., Delayed], bound: Dict[TypeVar, Any]) -> Delayed: tps = get_type_hints(func) @@ -115,7 +98,7 @@ def _call(self, func: Callable[..., Delayed], bound: Dict[TypeVar, Any]) -> Dela ) ] args[name] = self._get(tp) - return func(**args) + return dask.delayed(func)(**args) def _get(self, tp: Type[T], /) -> Delayed: # When building a workflow, there are two common problems: From e4b576f4d9b0b2613bb8562cc4b783e9667f136c Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Wed, 12 Jul 2023 13:57:52 +0200 Subject: [PATCH 34/43] Make mypy happier --- src/sciline/container.py | 7 +++---- tests/container_test.py | 40 ++++++++++++++++++++-------------------- 2 files changed, 23 insertions(+), 24 deletions(-) diff --git a/src/sciline/container.py b/src/sciline/container.py index a8a755c3..c0a8767b 100644 --- a/src/sciline/container.py +++ b/src/sciline/container.py @@ -83,7 +83,7 @@ def insert(self, provider: Provider) -> None: raise ValueError(f'Provider for {key} already exists') self._providers[key] = provider - def _call(self, func: Callable[..., Delayed], bound: Dict[TypeVar, Any]) -> Delayed: + def _call(self, func: Callable[..., Any], bound: Dict[TypeVar, Any]) -> Delayed: tps = get_type_hints(func) del tps['return'] args: Dict[str, Any] = {} @@ -98,7 +98,7 @@ def _call(self, func: Callable[..., Delayed], bound: 
Dict[TypeVar, Any]) -> Dela ) ] args[name] = self._get(tp) - return dask.delayed(func)(**args) + return dask.delayed(func)(**args) # type: ignore def _get(self, tp: Type[T], /) -> Delayed: # When building a workflow, there are two common problems: @@ -129,8 +129,7 @@ def _get(self, tp: Type[T], /) -> Delayed: raise UnsatisfiedRequirement("No provider found for type", tp) elif len(matches) > 1: raise AmbiguousProvider("Multiple providers found for type", tp) - args, subprovider = matches[0] - provider = subprovider + args, provider = matches[0] bound = { arg: req for arg, req in zip(args, requested) diff --git a/tests/container_test.py b/tests/container_test.py index b46b83f7..15dc5c8a 100644 --- a/tests/container_test.py +++ b/tests/container_test.py @@ -24,19 +24,19 @@ def int_float_to_str(x: int, y: float) -> str: return f"{x};{y}" -def test_make_container_sets_up_working_container(): +def test_make_container_sets_up_working_container() -> None: container = sl.Container([int_to_float, make_int]) assert container.compute(float) == 1.5 assert container.compute(int) == 3 -def test_make_container_does_not_autobind(): +def test_make_container_does_not_autobind() -> None: container = sl.Container([int_to_float]) with pytest.raises(sl.UnsatisfiedRequirement): container.compute(float) -def test_intermediate_computed_once(): +def test_intermediate_computed_once() -> None: ncall = 0 def provide_int() -> int: @@ -49,14 +49,14 @@ def provide_int() -> int: assert ncall == 1 -def test_get_returns_task_that_computes_result(): +def test_get_returns_task_that_computes_result() -> None: container = sl.Container([int_to_float, make_int]) task = container.get(float) assert hasattr(task, 'compute') assert task.compute() == 1.5 -def test_multiple_get_calls_can_be_computed_without_repeated_calls(): +def test_multiple_get_calls_can_be_computed_without_repeated_calls() -> None: ncall = 0 def provide_int() -> int: @@ -71,7 +71,7 @@ def provide_int() -> int: assert ncall == 1 -def 
test_make_container_with_subgraph_template(): +def test_make_container_with_subgraph_template() -> None: ncall = 0 def provide_int() -> int: @@ -121,7 +121,7 @@ def f(x: Param) -> Str[Param]: return Str(f'{x}') -def test_container_from_templated(): +def test_container_from_templated() -> None: def make_float() -> float: return 1.5 @@ -134,7 +134,7 @@ def combine(x: Str[int], y: Str[float]) -> str: assert container.compute(str) == '3;1.5' -def test_inserting_provider_returning_None_raises(): +def test_inserting_provider_returning_None_raises() -> None: def provide_none() -> None: return None @@ -145,8 +145,8 @@ def provide_none() -> None: container.insert(provide_none) -def test_inserting_provider_with_no_return_type_raises(): - def provide_none(): +def test_inserting_provider_with_no_return_type_raises() -> None: + def provide_none(): # type: ignore return None with pytest.raises(ValueError): @@ -156,7 +156,7 @@ def provide_none(): container.insert(provide_none) -def test_typevar_requirement_of_provider_can_be_bound(): +def test_typevar_requirement_of_provider_can_be_bound() -> None: T = TypeVar('T') def provider_int() -> int: @@ -169,7 +169,7 @@ def provider(x: T) -> List[T]: assert container.compute(List[int]) == [3, 3] -def test_unsatisfiable_typevar_requirement_of_provider_raises(): +def test_unsatisfiable_typevar_requirement_of_provider_raises() -> None: T = TypeVar('T') def provider_int() -> int: @@ -183,7 +183,7 @@ def provider(x: T) -> List[T]: container.compute(List[float]) -def test_TypeVars_params_are_not_associated_unless_they_match(): +def test_TypeVars_params_are_not_associated_unless_they_match() -> None: T1 = TypeVar('T1') T2 = TypeVar('T2') @@ -210,7 +210,7 @@ def matching(x: A[T1]) -> B[T1]: container.compute(B[int]) -def test_multi_Generic_with_fully_bound_arguments(): +def test_multi_Generic_with_fully_bound_arguments() -> None: T1 = TypeVar('T1') T2 = TypeVar('T2') @@ -226,7 +226,7 @@ def source() -> A[int, float]: assert 
container.compute(A[int, float]) == A[int, float](1, 2.0) -def test_multi_Generic_with_partially_bound_arguments(): +def test_multi_Generic_with_partially_bound_arguments() -> None: T1 = TypeVar('T1') T2 = TypeVar('T2') @@ -245,7 +245,7 @@ def partially_bound(x: T1) -> A[int, T1]: assert container.compute(A[int, float]) == A[int, float](1, 2.0) -def test_multi_Generic_with_multiple_unbound(): +def test_multi_Generic_with_multiple_unbound() -> None: T1 = TypeVar('T1') T2 = TypeVar('T2') @@ -268,7 +268,7 @@ def unbound(x: T1, y: T2) -> A[T1, T2]: assert container.compute(A[float, int]) == A[float, int](2.0, 1) -def test_distinct_fully_bound_instances_yield_distinct_results(): +def test_distinct_fully_bound_instances_yield_distinct_results() -> None: T1 = TypeVar('T1') @dataclass @@ -286,7 +286,7 @@ def float_source() -> A[float]: assert container.compute(A[float]) == A[float](2.0) -def test_distinct_partially_bound_instances_yield_distinct_results(): +def test_distinct_partially_bound_instances_yield_distinct_results() -> None: T1 = TypeVar('T1') T2 = TypeVar('T2') @@ -309,7 +309,7 @@ def float_source(x: T1) -> A[float, T1]: assert container.compute(A[float, str]) == A[float, str](2.0, 'a') -def test_multiple_matching_partial_providers_raises(): +def test_multiple_matching_partial_providers_raises() -> None: T1 = TypeVar('T1') T2 = TypeVar('T2') @@ -337,7 +337,7 @@ def provider2(x: T2) -> A[T2, float]: container.compute(A[int, float]) -def test_TypeVar_params_track_to_multiple_sources(): +def test_TypeVar_params_track_to_multiple_sources() -> None: T1 = TypeVar('T1') T2 = TypeVar('T2') From 676e914e40f46210a8eadee7a1cb845ac436c963 Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Wed, 12 Jul 2023 14:15:51 +0200 Subject: [PATCH 35/43] Improve readability --- src/sciline/container.py | 26 +++++++++++++------------- tests/container_test.py | 2 +- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/src/sciline/container.py b/src/sciline/container.py index 
c0a8767b..09f2e265 100644 --- a/src/sciline/container.py +++ b/src/sciline/container.py @@ -86,19 +86,19 @@ def insert(self, provider: Provider) -> None: def _call(self, func: Callable[..., Any], bound: Dict[TypeVar, Any]) -> Delayed: tps = get_type_hints(func) del tps['return'] - args: Dict[str, Any] = {} - for name, tp in tps.items(): - if isinstance(tp, TypeVar): - tp = bound[tp] - elif (origin := get_origin(tp)) is not None: - tp = origin[ - tuple( - bound[arg] if isinstance(arg, TypeVar) else arg - for arg in get_args(tp) - ) - ] - args[name] = self._get(tp) - return dask.delayed(func)(**args) # type: ignore + args = { + name: self._get(self._bind_free_typevars(tp, bound=bound)) + for name, tp in tps.items() + } + return dask.delayed(func)(**args) + + def _bind_free_typevars(self, tp: type, bound: Dict[TypeVar, Any]) -> type: + if isinstance(tp, TypeVar): + return bound[tp] + elif (origin := get_origin(tp)) is not None: + return origin[tuple(bound.get(a, a) for a in get_args(tp))] + else: + return tp def _get(self, tp: Type[T], /) -> Delayed: # When building a workflow, there are two common problems: diff --git a/tests/container_test.py b/tests/container_test.py index 15dc5c8a..576a2bd1 100644 --- a/tests/container_test.py +++ b/tests/container_test.py @@ -203,7 +203,7 @@ def matching(x: A[T1]) -> B[T1]: return B[T1]() container = sl.Container([source, not_matching]) - with pytest.raises(KeyError): + with pytest.raises(sl.UnsatisfiedRequirement): container.compute(B[int]) container = sl.Container([source, matching]) From 3d943d4357cd997b6de7897a1ad6f6016b54fa0e Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Wed, 12 Jul 2023 14:34:19 +0200 Subject: [PATCH 36/43] Correct type hints --- src/sciline/container.py | 4 ++-- tests/container_test.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/sciline/container.py b/src/sciline/container.py index 09f2e265..08a22cb2 100644 --- a/src/sciline/container.py +++ b/src/sciline/container.py @@ 
-83,7 +83,7 @@ def insert(self, provider: Provider) -> None: raise ValueError(f'Provider for {key} already exists') self._providers[key] = provider - def _call(self, func: Callable[..., Any], bound: Dict[TypeVar, Any]) -> Delayed: + def _call(self, func: Callable[..., Any], bound: Dict[TypeVar, type]) -> Delayed: tps = get_type_hints(func) del tps['return'] args = { @@ -92,7 +92,7 @@ def _call(self, func: Callable[..., Any], bound: Dict[TypeVar, Any]) -> Delayed: } return dask.delayed(func)(**args) - def _bind_free_typevars(self, tp: type, bound: Dict[TypeVar, Any]) -> type: + def _bind_free_typevars(self, tp: type, bound: Dict[TypeVar, type]) -> type: if isinstance(tp, TypeVar): return bound[tp] elif (origin := get_origin(tp)) is not None: diff --git a/tests/container_test.py b/tests/container_test.py index 576a2bd1..2db52ac6 100644 --- a/tests/container_test.py +++ b/tests/container_test.py @@ -183,7 +183,7 @@ def provider(x: T) -> List[T]: container.compute(List[float]) -def test_TypeVars_params_are_not_associated_unless_they_match() -> None: +def test_TypeVar_params_are_not_associated_unless_they_match() -> None: T1 = TypeVar('T1') T2 = TypeVar('T2') From 7486d301a80ead660f799866b103baf3e0622ffd Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Wed, 12 Jul 2023 14:39:05 +0200 Subject: [PATCH 37/43] More readability improvements --- src/sciline/container.py | 29 +++++++++++++---------------- 1 file changed, 13 insertions(+), 16 deletions(-) diff --git a/src/sciline/container.py b/src/sciline/container.py index 08a22cb2..c8819757 100644 --- a/src/sciline/container.py +++ b/src/sciline/container.py @@ -45,6 +45,15 @@ def _is_compatible_type_tuple( return True +def _bind_free_typevars(tp: type, bound: Dict[TypeVar, type]) -> type: + if isinstance(tp, TypeVar): + return bound[tp] + elif (origin := get_origin(tp)) is not None: + return origin[tuple(bound.get(a, a) for a in get_args(tp))] + else: + return tp + + Provider = Callable[..., Any] @@ -59,9 +68,7 @@ def 
__init__(self, funcs: List[Provider], /): List of functions to be injected. Must be annotated with type hints. """ self._providers: Dict[type, Provider] = {} - self._generic_providers: Dict[ - type, Dict[Tuple[type | TypeVar, ...], Provider] - ] = {} + self._subproviders: Dict[type, Dict[Tuple[type | TypeVar, ...], Provider]] = {} self._cache: Dict[type, Delayed] = {} for func in funcs: self.insert(func) @@ -73,7 +80,7 @@ def insert(self, provider: Provider) -> None: if key == type(None): # noqa: E721 raise ValueError(f'Provider {provider} returning `None` is not allowed') if get_origin(key) is not None: - subproviders = self._generic_providers.setdefault(get_origin(key), {}) + subproviders = self._subproviders.setdefault(get_origin(key), {}) args = get_args(key) if args in subproviders: raise ValueError(f'Provider for {key} already exists') @@ -87,19 +94,11 @@ def _call(self, func: Callable[..., Any], bound: Dict[TypeVar, type]) -> Delayed tps = get_type_hints(func) del tps['return'] args = { - name: self._get(self._bind_free_typevars(tp, bound=bound)) + name: self._get(_bind_free_typevars(tp, bound=bound)) for name, tp in tps.items() } return dask.delayed(func)(**args) - def _bind_free_typevars(self, tp: type, bound: Dict[TypeVar, type]) -> type: - if isinstance(tp, TypeVar): - return bound[tp] - elif (origin := get_origin(tp)) is not None: - return origin[tuple(bound.get(a, a) for a in get_args(tp))] - else: - return tp - def _get(self, tp: Type[T], /) -> Delayed: # When building a workflow, there are two common problems: # @@ -116,9 +115,7 @@ def _get(self, tp: Type[T], /) -> Delayed: if (provider := self._providers.get(tp)) is not None: result = self._call(provider, {}) - elif (origin := get_origin(tp)) is not None and ( - subproviders := self._generic_providers[origin] - ) is not None: + elif (subproviders := self._subproviders.get(get_origin(tp))) is not None: requested = get_args(tp) matches = [ (args, subprovider) From 
8bed356ac2dcb785cfd8b14ab35dc45982c380b5 Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Wed, 12 Jul 2023 15:16:52 +0200 Subject: [PATCH 38/43] Add UnboundTypeVar exception --- src/sciline/__init__.py | 15 +++++++++++++-- src/sciline/container.py | 10 ++++++++-- tests/container_test.py | 13 ++++++++++++- 3 files changed, 33 insertions(+), 5 deletions(-) diff --git a/src/sciline/__init__.py b/src/sciline/__init__.py index 398c6314..4c52072d 100644 --- a/src/sciline/__init__.py +++ b/src/sciline/__init__.py @@ -9,7 +9,18 @@ except importlib.metadata.PackageNotFoundError: __version__ = "0.0.0" -from .container import AmbiguousProvider, Container, UnsatisfiedRequirement +from .container import ( + AmbiguousProvider, + Container, + UnboundTypeVar, + UnsatisfiedRequirement, +) from .domain import Scope -__all__ = ["AmbiguousProvider", "Container", "Scope", "UnsatisfiedRequirement"] +__all__ = [ + "AmbiguousProvider", + "Container", + "Scope", + "UnboundTypeVar", + "UnsatisfiedRequirement", +] diff --git a/src/sciline/container.py b/src/sciline/container.py index c8819757..7d54bf36 100644 --- a/src/sciline/container.py +++ b/src/sciline/container.py @@ -25,6 +25,10 @@ class UnsatisfiedRequirement(Exception): pass +class UnboundTypeVar(Exception): + pass + + class AmbiguousProvider(Exception): pass @@ -47,9 +51,11 @@ def _is_compatible_type_tuple( def _bind_free_typevars(tp: type, bound: Dict[TypeVar, type]) -> type: if isinstance(tp, TypeVar): - return bound[tp] + if (result := bound.get(tp)) is None: + raise UnboundTypeVar(f'Unbound type variable {tp}') + return result elif (origin := get_origin(tp)) is not None: - return origin[tuple(bound.get(a, a) for a in get_args(tp))] + return origin[tuple(_bind_free_typevars(arg, bound) for arg in get_args(tp))] else: return tp diff --git a/tests/container_test.py b/tests/container_test.py index 2db52ac6..dffd1e51 100644 --- a/tests/container_test.py +++ b/tests/container_test.py @@ -169,6 +169,17 @@ def provider(x: T) -> 
List[T]: assert container.compute(List[int]) == [3, 3] +def test_typevar_that_cannot_be_bound_raises_UnboundTypeVar() -> None: + T = TypeVar('T') + + def provider(_: T) -> int: + return 1 + + container = sl.Container([provider]) + with pytest.raises(sl.UnboundTypeVar): + container.compute(int) + + def test_unsatisfiable_typevar_requirement_of_provider_raises() -> None: T = TypeVar('T') @@ -203,7 +214,7 @@ def matching(x: A[T1]) -> B[T1]: return B[T1]() container = sl.Container([source, not_matching]) - with pytest.raises(sl.UnsatisfiedRequirement): + with pytest.raises(sl.UnboundTypeVar): container.compute(B[int]) container = sl.Container([source, matching]) From 5f1dc105c5485bb744f254d2381dae74918f5bc2 Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Wed, 12 Jul 2023 15:25:57 +0200 Subject: [PATCH 39/43] Fix mypy --- src/sciline/container.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/sciline/container.py b/src/sciline/container.py index 7d54bf36..17be1c2d 100644 --- a/src/sciline/container.py +++ b/src/sciline/container.py @@ -85,8 +85,8 @@ def insert(self, provider: Provider) -> None: # isinstance does not work here and types.NoneType available only in 3.10+ if key == type(None): # noqa: E721 raise ValueError(f'Provider {provider} returning `None` is not allowed') - if get_origin(key) is not None: - subproviders = self._subproviders.setdefault(get_origin(key), {}) + if (origin := get_origin(key)) is not None: + subproviders = self._subproviders.setdefault(origin, {}) args = get_args(key) if args in subproviders: raise ValueError(f'Provider for {key} already exists') @@ -121,7 +121,9 @@ def _get(self, tp: Type[T], /) -> Delayed: if (provider := self._providers.get(tp)) is not None: result = self._call(provider, {}) - elif (subproviders := self._subproviders.get(get_origin(tp))) is not None: + elif (origin := get_origin(tp)) is not None and ( + subproviders := self._subproviders.get(origin) + ) is not None: requested = 
get_args(tp) matches = [ (args, subprovider) From f0e66b61bb0e12bce632cb83426e878cd2aa6246 Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Thu, 13 Jul 2023 08:45:42 +0200 Subject: [PATCH 40/43] Pass mypy --- src/sciline/container.py | 12 ++++++++---- src/sciline/domain.py | 4 ++-- tests/complex_workflow_test.py | 19 ++++++++++--------- tests/container_test.py | 9 +++++---- tests/package_test.py | 2 +- 5 files changed, 26 insertions(+), 20 deletions(-) diff --git a/src/sciline/container.py b/src/sciline/container.py index 17be1c2d..764381a5 100644 --- a/src/sciline/container.py +++ b/src/sciline/container.py @@ -49,13 +49,16 @@ def _is_compatible_type_tuple( return True -def _bind_free_typevars(tp: type, bound: Dict[TypeVar, type]) -> type: +def _bind_free_typevars(tp: TypeVar | type, bound: Dict[TypeVar, type]) -> type: if isinstance(tp, TypeVar): if (result := bound.get(tp)) is None: raise UnboundTypeVar(f'Unbound type variable {tp}') return result elif (origin := get_origin(tp)) is not None: - return origin[tuple(_bind_free_typevars(arg, bound) for arg in get_args(tp))] + result = origin[tuple(_bind_free_typevars(arg, bound) for arg in get_args(tp))] + if not isinstance(result, type): + raise ValueError(f'Binding type variable {tp} resulted in non-type') + return result else: return tp @@ -103,7 +106,8 @@ def _call(self, func: Callable[..., Any], bound: Dict[TypeVar, type]) -> Delayed name: self._get(_bind_free_typevars(tp, bound=bound)) for name, tp in tps.items() } - return dask.delayed(func)(**args) + delayed = dask.delayed(func) # type: ignore[attr-defined] + return delayed(**args) # type: ignore[no-any-return] def _get(self, tp: Type[T], /) -> Delayed: # When building a workflow, there are two common problems: @@ -156,5 +160,5 @@ def get(self, tp: Type[T], /) -> Delayed: def compute(self, tp: Type[T], /) -> T: task = self.get(tp) - result: T = task.compute() # type: ignore + result: T = task.compute() # type: ignore[no-untyped-call] return result diff 
--git a/src/sciline/domain.py b/src/sciline/domain.py index fbd0f924..4117c26f 100644 --- a/src/sciline/domain.py +++ b/src/sciline/domain.py @@ -1,10 +1,10 @@ # SPDX-License-Identifier: BSD-3-Clause # Copyright (c) 2023 Scipp contributors (https://github.com/scipp) -from typing import Generic, TypeVar +from typing import Any, Generic, TypeVar T = TypeVar("T") class Scope(Generic[T]): - def __new__(cls, x): # type: ignore + def __new__(cls, x) -> Any: # type: ignore[no-untyped-def] return x diff --git a/tests/complex_workflow_test.py b/tests/complex_workflow_test.py index 886e3926..c576f7d6 100644 --- a/tests/complex_workflow_test.py +++ b/tests/complex_workflow_test.py @@ -5,6 +5,7 @@ import dask import numpy as np +import numpy.typing as npt import sciline as sl @@ -14,16 +15,16 @@ @dataclass class RawData: - data: np.ndarray + data: npt.NDArray[np.float64] monitor1: float monitor2: float SampleRun = NewType('SampleRun', int) BackgroundRun = NewType('BackgroundRun', int) -DetectorMask = NewType('DetectorMask', np.ndarray) -DirectBeam = NewType('DirectBeam', np.ndarray) -SolidAngle = NewType('SolidAngle', np.ndarray) +DetectorMask = NewType('DetectorMask', npt.NDArray[np.float64]) +DirectBeam = NewType('DirectBeam', npt.NDArray[np.float64]) +SolidAngle = NewType('SolidAngle', npt.NDArray[np.float64]) Run = TypeVar('Run') @@ -32,7 +33,7 @@ class Raw(sl.Scope[Run], RawData): ... -class Masked(sl.Scope[Run], np.ndarray): +class Masked(sl.Scope[Run], npt.NDArray[np.float64]): ... @@ -48,11 +49,11 @@ class TransmissionFraction(sl.Scope[Run], float): ... -class IofQ(sl.Scope[Run], np.ndarray): +class IofQ(sl.Scope[Run], npt.NDArray[np.float64]): ... 
-BackgroundSubtractedIofQ = NewType('BackgroundSubtractedIofQ', np.ndarray) +BackgroundSubtractedIofQ = NewType('BackgroundSubtractedIofQ', npt.NDArray[np.float64]) def incident_monitor(x: Raw[Run]) -> IncidentMonitor[Run]: @@ -111,7 +112,7 @@ def subtract_background( return BackgroundSubtractedIofQ(sample - background) -def test_reduction_workflow(): +def test_reduction_workflow() -> None: # See https://github.com/python/mypy/issues/14661 container = sl.Container( [ @@ -122,7 +123,7 @@ def test_reduction_workflow(): direct_beam, subtract_background, ] - + reduction # type: ignore + + reduction # type: ignore[arg-type] ) assert np.array_equal(container.compute(IofQ[SampleRun]), [3, 6, 0, 24]) diff --git a/tests/container_test.py b/tests/container_test.py index dffd1e51..82822d8a 100644 --- a/tests/container_test.py +++ b/tests/container_test.py @@ -53,7 +53,7 @@ def test_get_returns_task_that_computes_result() -> None: container = sl.Container([int_to_float, make_int]) task = container.get(float) assert hasattr(task, 'compute') - assert task.compute() == 1.5 + assert task.compute() == 1.5 # type: ignore[no-untyped-call] def test_multiple_get_calls_can_be_computed_without_repeated_calls() -> None: @@ -67,7 +67,7 @@ def provide_int() -> int: container = sl.Container([int_to_float, provide_int, int_float_to_str]) task1 = container.get(float) task2 = container.get(str) - assert dask.compute(task1, task2) == (1.5, '3;1.5') + assert dask.compute(task1, task2) == (1.5, '3;1.5') # type: ignore[attr-defined] assert ncall == 1 @@ -106,7 +106,8 @@ def use_strings(s1: Str[Run1], s2: Str[Run2]) -> Result: container = sl.Container( [provide_int, float1, float2, use_strings, int_float_to_str], ) - assert container.get(Result).compute() == "3;1.5;3;2.5" + task = container.get(Result) + assert task.compute() == "3;1.5;3;2.5" # type: ignore[no-untyped-call] assert ncall == 1 @@ -146,7 +147,7 @@ def provide_none() -> None: def test_inserting_provider_with_no_return_type_raises() -> 
None: - def provide_none(): # type: ignore + def provide_none(): # type: ignore[no-untyped-def] return None with pytest.raises(ValueError): diff --git a/tests/package_test.py b/tests/package_test.py index 426e378e..62ceb639 100644 --- a/tests/package_test.py +++ b/tests/package_test.py @@ -3,5 +3,5 @@ import sciline as pkg -def test_has_version(): +def test_has_version() -> None: assert hasattr(pkg, '__version__') From d9bf747c33d593f40bc5f2869b9469ec7153d3d1 Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Thu, 13 Jul 2023 08:52:18 +0200 Subject: [PATCH 41/43] Fix runtime --- src/sciline/container.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/sciline/container.py b/src/sciline/container.py index 764381a5..7329c880 100644 --- a/src/sciline/container.py +++ b/src/sciline/container.py @@ -56,9 +56,9 @@ def _bind_free_typevars(tp: TypeVar | type, bound: Dict[TypeVar, type]) -> type: return result elif (origin := get_origin(tp)) is not None: result = origin[tuple(_bind_free_typevars(arg, bound) for arg in get_args(tp))] - if not isinstance(result, type): - raise ValueError(f'Binding type variable {tp} resulted in non-type') - return result + # This is a hack to make mypy happy. The type of result is actually + # typing._GenericAlias. 
+ return result # type: ignore[return-value] else: return tp From 06096556cbfee40b41b38d9dfe6c880351ec5b38 Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Thu, 13 Jul 2023 09:56:41 +0200 Subject: [PATCH 42/43] Fix mypy and tests --- src/sciline/container.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/sciline/container.py b/src/sciline/container.py index 7329c880..183deda5 100644 --- a/src/sciline/container.py +++ b/src/sciline/container.py @@ -56,9 +56,9 @@ def _bind_free_typevars(tp: TypeVar | type, bound: Dict[TypeVar, type]) -> type: return result elif (origin := get_origin(tp)) is not None: result = origin[tuple(_bind_free_typevars(arg, bound) for arg in get_args(tp))] - # This is a hack to make mypy happy. The type of result is actually - # typing._GenericAlias. - return result # type: ignore[return-value] + if result is None: + raise ValueError(f'Binding type variables in {tp} resulted in `None`') + return result else: return tp From 0363c3ad96484d88b9215bb8c50782c90b0efb28 Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Thu, 13 Jul 2023 14:02:22 +0200 Subject: [PATCH 43/43] Small refactor for readability --- src/sciline/container.py | 78 ++++++++++++++++++++-------------------- 1 file changed, 39 insertions(+), 39 deletions(-) diff --git a/src/sciline/container.py b/src/sciline/container.py index 183deda5..dc1a5da5 100644 --- a/src/sciline/container.py +++ b/src/sciline/container.py @@ -99,32 +99,20 @@ def insert(self, provider: Provider) -> None: raise ValueError(f'Provider for {key} already exists') self._providers[key] = provider - def _call(self, func: Callable[..., Any], bound: Dict[TypeVar, type]) -> Delayed: - tps = get_type_hints(func) - del tps['return'] - args = { - name: self._get(_bind_free_typevars(tp, bound=bound)) - for name, tp in tps.items() + def _get_args( + self, func: Callable[..., Any], bound: Dict[TypeVar, type] + ) -> Dict[str, Delayed]: + return { + name: self.get(_bind_free_typevars(tp, 
bound=bound)) + for name, tp in get_type_hints(func).items() + if name != 'return' } - delayed = dask.delayed(func) # type: ignore[attr-defined] - return delayed(**args) # type: ignore[no-any-return] - - def _get(self, tp: Type[T], /) -> Delayed: - # When building a workflow, there are two common problems: - # - # 1. Intermediate results are used more than once. - # 2. Intermediate results are large, so we generally do not want to keep them - # in memory longer than necessary. - # - # To address these problems, we can internally build a graph of tasks, instead - # of directly creating dependencies between functions. Currently we use Dask - # for this. We cache call results to ensure Dask will recognize the task - # as the same object) and also wrap the function in dask.delayed. - if (cached := self._cache.get(tp)) is not None: - return cached + def _get_provider( + self, tp: Type[T] + ) -> Tuple[Callable[..., T], Dict[TypeVar, type]]: if (provider := self._providers.get(tp)) is not None: - result = self._call(provider, {}) + return provider, {} elif (origin := get_origin(tp)) is not None and ( subproviders := self._subproviders.get(origin) ) is not None: @@ -134,29 +122,41 @@ def _get(self, tp: Type[T], /) -> Delayed: for args, subprovider in subproviders.items() if _is_compatible_type_tuple(requested, args) ] - if len(matches) == 0: - raise UnsatisfiedRequirement("No provider found for type", tp) + if len(matches) == 1: + args, provider = matches[0] + bound = { + arg: req + for arg, req in zip(args, requested) + if isinstance(arg, TypeVar) + } + return provider, bound elif len(matches) > 1: raise AmbiguousProvider("Multiple providers found for type", tp) - args, provider = matches[0] - bound = { - arg: req - for arg, req in zip(args, requested) - if isinstance(arg, TypeVar) - } - result = self._call(provider, bound) - else: - raise UnsatisfiedRequirement("No provider found for type", tp) - - self._cache[tp] = result - return result + raise 
UnsatisfiedRequirement("No provider found for type", tp) def get(self, tp: Type[T], /) -> Delayed: # We are slightly abusing Python's type system here, by using the - # self._get to get T, but actually it returns a Delayed that can + # self.get to get T, but actually it returns a Delayed that can # compute T. We'd like to use Delayed[T], but that is not supported yet: # https://github.com/dask/dask/pull/9256 - return self._get(tp) + # + # When building a workflow, there are two common problems: + # + # 1. Intermediate results are used more than once. + # 2. Intermediate results are large, so we generally do not want to keep them + # in memory longer than necessary. + # + # To address these problems, we can internally build a graph of tasks, instead + # of directly creating dependencies between functions. Currently we use Dask + # for this. We cache call results (to ensure Dask will recognize the task + # as the same object) and also wrap the function in dask.delayed. + if (cached := self._cache.get(tp)) is not None: + return cached + + provider, bound = self._get_provider(tp) + args = self._get_args(provider, bound=bound) + delayed = dask.delayed(provider) # type: ignore[attr-defined] + return self._cache.setdefault(tp, delayed(**args)) def compute(self, tp: Type[T], /) -> T: task = self.get(tp)