From 714d4fa207f13988033c90427b3570d13d43a819 Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Fri, 15 Sep 2023 08:57:05 +0200 Subject: [PATCH 1/2] Add tests --- tests/pipeline_test.py | 37 +++++++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/tests/pipeline_test.py b/tests/pipeline_test.py index cf1ef0f5..762a3796 100644 --- a/tests/pipeline_test.py +++ b/tests/pipeline_test.py @@ -81,6 +81,43 @@ def func2(x: int) -> str: assert ncall == 1 +def test_Scope_subclass_can_be_set_as_param(): + Param = TypeVar('Param') + + class Str(sl.Scope[Param, str], str): + ... + + pipeline = sl.Pipeline(params={Str[int]: Str[int]('1')}) + pipeline[Str[float]] = Str[float]('2.0') + assert pipeline.compute(Str[int]) == Str[int]('1') + assert pipeline.compute(Str[float]) == Str[float]('2.0') + + +def test_Scope_subclass_can_be_set_as_param_with_unbound_typevar(): + Param = TypeVar('Param') + + class Str(sl.Scope[Param, str], str): + ... + + pipeline = sl.Pipeline() + pipeline[Str[Param]] = Str[Param]('1') + assert pipeline.compute(Str[int]) == Str[int]('1') + assert pipeline.compute(Str[float]) == Str[float]('1') + + +def test_ScopeTwoParam_subclass_can_be_set_as_param(): + Param1 = TypeVar('Param1') + Param2 = TypeVar('Param2') + + class Str(sl.ScopeTwoParams[Param1, Param2, str], str): + ... + + pipeline = sl.Pipeline(params={Str[int, float]: Str[int, float]('1')}) + pipeline[Str[float, int]] = Str[float, int]('2.0') + assert pipeline.compute(Str[int, float]) == Str[int, float]('1') + assert pipeline.compute(Str[float, int]) == Str[float, int]('2.0') + + def test_generic_providers_produce_use_dependencies_based_on_bound_typevar() -> None: Param = TypeVar('Param') From 70fc8a1109f3821d37857ec79007604aac916f51 Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Fri, 15 Sep 2023 09:03:47 +0200 Subject: [PATCH 2/2] Fix check preventing ScopeTwoParam subclass used as param --- src/sciline/pipeline.py | 11 +++++++---- tests/pipeline_test.py | 21 +++++++++++++++++---- 2 files changed, 24 insertions(+), 8 deletions(-) diff --git a/src/sciline/pipeline.py b/src/sciline/pipeline.py index 648f92d4..4ed76328 100644 --- a/src/sciline/pipeline.py +++ b/src/sciline/pipeline.py @@ -26,7 +26,7 @@ from sciline.task_graph import TaskGraph -from .domain import Scope +from .domain import Scope, ScopeTwoParams from .param_table import ParamTable from .scheduler import Scheduler from .series import Series @@ -353,11 +353,14 @@ def __setitem__(self, key: Type[T], param: T) -> None: expected = np_origin else: expected = underlying - elif issubclass(origin, Scope): + elif issubclass(origin, (Scope, ScopeTwoParams)): scope = origin.__orig_bases__[0] - while (orig := get_origin(scope)) is not None and orig is not Scope: + while (orig := get_origin(scope)) is not None and orig not in ( + Scope, + ScopeTwoParams, + ): scope = orig.__orig_bases__[0] - expected = get_args(scope)[1] + expected = get_args(scope)[-1] else: expected = origin diff --git a/tests/pipeline_test.py b/tests/pipeline_test.py index 762a3796..afa7dc0e 100644 --- a/tests/pipeline_test.py +++ b/tests/pipeline_test.py @@ -81,7 +81,7 @@ def func2(x: int) -> str: assert ncall == 1 -def test_Scope_subclass_can_be_set_as_param(): +def test_Scope_subclass_can_be_set_as_param() -> None: Param = TypeVar('Param') class Str(sl.Scope[Param, str], str): @@ -93,19 +93,19 @@ class Str(sl.Scope[Param, str], str): assert pipeline.compute(Str[float]) == Str[float]('2.0') -def test_Scope_subclass_can_be_set_as_param_with_unbound_typevar(): +def test_Scope_subclass_can_be_set_as_param_with_unbound_typevar() -> None: Param = TypeVar('Param') class Str(sl.Scope[Param, str], str): ... pipeline = sl.Pipeline() - pipeline[Str[Param]] = Str[Param]('1') + pipeline[Str[Param]] = Str[Param]('1') # type: ignore[valid-type] assert pipeline.compute(Str[int]) == Str[int]('1') assert pipeline.compute(Str[float]) == Str[float]('1') -def test_ScopeTwoParam_subclass_can_be_set_as_param(): +def test_ScopeTwoParam_subclass_can_be_set_as_param() -> None: Param1 = TypeVar('Param1') Param2 = TypeVar('Param2') @@ -118,6 +118,19 @@ class Str(sl.ScopeTwoParams[Param1, Param2, str], str): assert pipeline.compute(Str[float, int]) == Str[float, int]('2.0') +def test_ScopeTwoParam_subclass_can_be_set_as_param_with_unbound_typevar() -> None: + Param1 = TypeVar('Param1') + Param2 = TypeVar('Param2') + + class Str(sl.ScopeTwoParams[Param1, Param2, str], str): + ... + + pipeline = sl.Pipeline() + pipeline[Str[Param1, Param2]] = Str[Param1, Param2]('1') # type: ignore[valid-type] + assert pipeline.compute(Str[int, float]) == Str[int, float]('1') + assert pipeline.compute(Str[float, int]) == Str[float, int]('1') + + def test_generic_providers_produce_use_dependencies_based_on_bound_typevar() -> None: Param = TypeVar('Param')