Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement ADR 0001: Remove isinstance checks when setting params #160

Merged
merged 1 commit into from
Apr 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 0 additions & 31 deletions src/sciline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
from ._provider import ArgSpec, Provider, ProviderLocation, ToProvider
from ._utils import key_name
from .display import pipeline_html_repr
from .domain import Scope, ScopeTwoParams
from .handler import (
ErrorHandler,
HandleAsBuildTimeException,
Expand Down Expand Up @@ -403,36 +402,6 @@ def __setitem__(self, key: Type[T], param: T) -> None:
param:
Concrete value to provide.
"""
# TODO Switch to isinstance(key, NewType) once our minimum is Python 3.10
# Note that we cannot pass mypy in Python<3.10 since NewType is not a type.
if hasattr(key, '__supertype__'):
underlying = key.__supertype__ # type: ignore[attr-defined]
else:
underlying = key
if (origin := get_origin(underlying)) is None:
# In Python 3.8, get_origin does not work with numpy.typing.NDArray,
# but it defines __origin__
if (np_origin := getattr(underlying, '__origin__', None)) is not None:
expected = np_origin
else:
expected = underlying
elif origin == Union:
expected = underlying
elif issubclass(origin, (Scope, ScopeTwoParams)):
scope = origin.__orig_bases__[0]
while (orig := get_origin(scope)) is not None and orig not in (
Scope,
ScopeTwoParams,
):
scope = orig.__orig_bases__[0]
expected = get_args(scope)[-1]
else:
expected = origin

if not isinstance(param, expected):
raise TypeError(
f'Key {key} incompatible to value {param} of type {type(param)}'
)
self._set_provider(key, Provider.parameter(param))

def set_param_table(self, params: ParamTable) -> None:
Expand Down
17 changes: 0 additions & 17 deletions tests/pipeline_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -561,23 +561,6 @@ class A(Generic[T]):
assert pl.compute(A) == A(3)


def test_setitem_raises_TypeError_if_instance_does_not_match_key() -> None:
A = NewType('A', int)
T = TypeVar('T')

@dataclass
class B(Generic[T]):
value: T

pl = sl.Pipeline()
with pytest.raises(TypeError):
pl[int] = 1.0
with pytest.raises(TypeError):
pl[A] = 1.0
with pytest.raises(TypeError):
pl[B[int]] = 1.0


def test_setitem_can_replace_param_with_param() -> None:
pl = sl.Pipeline()
pl[int] = 1
Expand Down