Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Better parametrized domain type #22

Merged
merged 8 commits into from
Jul 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 29 additions & 4 deletions src/sciline/domain.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,35 @@
# SPDX-License-Identifier: BSD-3-Clause
# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
from typing import Any, Generic, TypeVar
from typing import Any, Generic, TypeVar, get_args, get_origin

T = TypeVar("T")
PARAM = TypeVar("PARAM")
SUPER = TypeVar("SUPER")


class Scope(Generic[T]):
def __new__(cls, x) -> Any: # type: ignore[no-untyped-def]
class Scope(Generic[PARAM, SUPER]):
YooSunYoung marked this conversation as resolved.
Show resolved Hide resolved
def __init_subclass__(cls, **kwargs: Any) -> None:
# Mypy does not support __orig_bases__ yet(?)
# See also https://stackoverflow.com/a/73746554 for useful info
scope = cls.__orig_bases__[0] # type: ignore[attr-defined]
# Only check direct subclasses
if get_origin(scope) is Scope:
supertype = get_args(scope)[1]
# Remove potential generic params
# In Python 3.8, get_origin does not work with numpy.typing.NDArray,
# but it defines __origin__
supertype = getattr(supertype, '__origin__', None) or supertype
if supertype not in cls.__bases__:
raise TypeError(
f"Missing or wrong interface for {cls}, "
f"should inherit {supertype}.\n"
"Example:\n"
"\n"
" Param = TypeVar('Param')\n"
" \n"
" class A(sl.Scope[Param, float], float):\n"
" ...\n"
)
return super().__init_subclass__(**kwargs)

def __new__(cls, x: SUPER) -> SUPER: # type: ignore[misc]
return x
8 changes: 6 additions & 2 deletions src/sciline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,9 +135,13 @@ def __setitem__(self, key: Type[T], param: T) -> None:
expected = np_origin
else:
expected = underlying
elif issubclass(origin, Scope):
scope = origin.__orig_bases__[0]
while (orig := get_origin(scope)) is not None and orig is not Scope:
scope = orig.__orig_bases__[0]
expected = get_args(scope)[1]
else:
# TODO This is probably quite brittle, maybe we can find a better way?
expected = origin.__bases__[1] if issubclass(origin, Scope) else origin
expected = origin

if not isinstance(param, expected):
raise TypeError(
Expand Down
13 changes: 7 additions & 6 deletions tests/complex_workflow_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,27 +25,28 @@ class RawData:
Run = TypeVar('Run')


class Raw(sl.Scope[Run], RawData):
# TODO Giving the base twice works with mypy, how can we avoid typing it twice?
class Raw(sl.Scope[Run, RawData], RawData):
...


class Masked(sl.Scope[Run], npt.NDArray[np.float64]):
class Masked(sl.Scope[Run, npt.NDArray[np.float64]], npt.NDArray[np.float64]):
...


class IncidentMonitor(sl.Scope[Run], float):
class IncidentMonitor(sl.Scope[Run, float], float):
...


class TransmissionMonitor(sl.Scope[Run], float):
class TransmissionMonitor(sl.Scope[Run, float], float):
...


class TransmissionFraction(sl.Scope[Run], float):
class TransmissionFraction(sl.Scope[Run, float], float):
...


class IofQ(sl.Scope[Run], npt.NDArray[np.float64]):
class IofQ(sl.Scope[Run, npt.NDArray[np.float64]], npt.NDArray[np.float64]):
...


Expand Down
49 changes: 49 additions & 0 deletions tests/domain_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# SPDX-License-Identifier: BSD-3-Clause
# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
from typing import NewType, TypeVar

import pytest

import sciline as sl

T = TypeVar("T")


def test_mypy_detects_wrong_arg_type_of_Scope_subclass() -> None:
Param = TypeVar('Param')
Param1 = NewType('Param1', int)

class A(sl.Scope[Param, float], float):
...

A[Param1](1.5)
A[Param1]('abc') # type: ignore[arg-type]
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that all tests in this file pass at runtime. mypy complains if we remove any of the type: ignore comments, but also complains if we put too many, since our config ensures that unused ignore comments are not allowed.



def test_missing_interface_of_scope_subclass_raises() -> None:
param = TypeVar('param')

with pytest.raises(TypeError, match="Missing or wrong interface for"):

class A(sl.Scope[param, float]):
...


def test_mypy_accepts_interface_of_scope_sibling_class() -> None:
param = TypeVar('param')
param1 = NewType('param1', int)

class A(sl.Scope[param, float], float):
...

a = A[param1](1.5)
a + a


def test_inconsistent_type_and_interface_raises() -> None:
param = TypeVar('param')

with pytest.raises(TypeError, match="Missing or wrong interface for"):

class A(sl.Scope[param, str], float):
...
79 changes: 73 additions & 6 deletions tests/pipeline_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from dataclasses import dataclass
from typing import Generic, List, NewType, TypeVar

import numpy as np
import numpy.typing as npt
import pytest

import sciline as sl
Expand Down Expand Up @@ -81,7 +83,7 @@ def func2(x: int) -> str:
def test_generic_providers_produce_use_dependencies_based_on_bound_typevar() -> None:
Param = TypeVar('Param')

class Str(sl.Scope[Param], str):
class Str(sl.Scope[Param, str], str):
...

def parametrized(x: Param) -> Str[Param]:
Expand All @@ -94,8 +96,8 @@ def combine(x: Str[int], y: Str[float]) -> str:
return f"{x};{y}"

pipeline = sl.Pipeline([make_int, make_float, combine, parametrized])
assert pipeline.compute(Str[int]) == '3'
assert pipeline.compute(Str[float]) == '1.5'
assert pipeline.compute(Str[int]) == Str[int]('3')
assert pipeline.compute(Str[float]) == Str[float]('1.5')
assert pipeline.compute(str) == '3;1.5'


Expand All @@ -109,10 +111,10 @@ def provide_int() -> int:

Param = TypeVar('Param')

class Float(sl.Scope[Param], float):
class Float(sl.Scope[Param, float], float):
...

class Str(sl.Scope[Param], str):
class Str(sl.Scope[Param, str], str):
...

def int_float_to_str(x: int, y: Float[Param]) -> Str[Param]:
Expand All @@ -138,6 +140,71 @@ def use_strings(s1: Str[Run1], s2: Str[Run2]) -> Result:
assert ncall == 1


def test_subclasses_of_generic_provider_defined_with_Scope_work() -> None:
Param = TypeVar('Param')

class StrT(sl.Scope[Param, str], str):
...

class Str1(StrT[Param]):
...

class Str2(StrT[Param]):
...

class Str3(StrT[Param]):
...

class Str4(Str3[Param]):
...

def make_str1() -> Str1[Param]:
return Str1('1')

def make_str2() -> Str2[Param]:
return Str2('2')

# Note that mypy cannot detect if when setting params, the type of the
# parameter does not match the key. Same problem as with NewType.
pipeline = sl.Pipeline(
[make_str1, make_str2],
params={
Str3[int]: Str3[int]('int3'),
Str3[float]: Str3[float]('float3'),
Str4[int]: Str2[int]('int4'),
},
)
assert pipeline.compute(Str1[float]) == Str1[float]('1')
assert pipeline.compute(Str2[float]) == Str2[float]('2')
assert pipeline.compute(Str3[int]) == Str3[int]('int3')
assert pipeline.compute(Str3[float]) == Str3[float]('float3')
assert pipeline.compute(Str4[int]) == Str4[int]('int4')


def test_subclasses_of_generic_array_provider_defined_with_Scope_work() -> None:
Param = TypeVar('Param')

class ArrayT(sl.Scope[Param, npt.NDArray[np.int64]], npt.NDArray[np.int64]):
...

class Array1(ArrayT[Param]):
...

class Array2(ArrayT[Param]):
...

def make_array1() -> Array1[Param]:
return Array1(np.array([1, 2, 3]))

def make_array2() -> Array2[Param]:
return Array2(np.array([4, 5, 6]))

pipeline = sl.Pipeline([make_array1, make_array2])
# Note that the param is not the dtype
assert np.all(pipeline.compute(Array1[str]) == np.array([1, 2, 3]))
assert np.all(pipeline.compute(Array2[str]) == np.array([4, 5, 6]))


def test_inserting_provider_returning_None_raises() -> None:
def provide_none() -> None:
return None
Expand Down Expand Up @@ -483,7 +550,7 @@ def func(x: int, y: float) -> str:
def test_init_with_sciline_Scope_subclass_param_works() -> None:
T = TypeVar('T')

class A(sl.Scope[T], int):
class A(sl.Scope[T, int], int):
...

pl = sl.Pipeline(params={A[float]: A(1), A[str]: A(2)})
Expand Down