diff --git a/src/sciline/handler.py b/src/sciline/handler.py index a9954465..a8890fbb 100644 --- a/src/sciline/handler.py +++ b/src/sciline/handler.py @@ -17,7 +17,7 @@ class UnsatisfiedRequirement(Exception): class ErrorHandler(Protocol): """Error handling protocol for pipelines.""" - def handle_unsatisfied_requirement(self, tp: Key) -> Provider: + def handle_unsatisfied_requirement(self, tp: Key, *explanation: str) -> Provider: ... @@ -29,9 +29,9 @@ class HandleAsBuildTimeException(ErrorHandler): ensuring that errors are caught early, before starting costly computation. """ - def handle_unsatisfied_requirement(self, tp: Key) -> NoReturn: + def handle_unsatisfied_requirement(self, tp: Key, *explanation: str) -> NoReturn: """Raise an exception when a type cannot be provided.""" - raise UnsatisfiedRequirement('No provider found for type', tp) + raise UnsatisfiedRequirement('No provider found for type', tp, *explanation) class HandleAsComputeTimeException(ErrorHandler): @@ -42,11 +42,11 @@ class HandleAsComputeTimeException(ErrorHandler): visualization. This is helpful when visualizing a graph that is not yet complete. """ - def handle_unsatisfied_requirement(self, tp: Key) -> Provider: + def handle_unsatisfied_requirement(self, tp: Key, *explanation: str) -> Provider: """Return a function that raises an exception when called.""" def unsatisfied_sentinel() -> NoReturn: - raise UnsatisfiedRequirement('No provider found for type', tp) + raise UnsatisfiedRequirement('No provider found for type', tp, *explanation) return Provider( func=unsatisfied_sentinel, arg_spec=ArgSpec.null(), kind='unsatisfied' diff --git a/src/sciline/pipeline.py b/src/sciline/pipeline.py index 9129043c..0e883df3 100644 --- a/src/sciline/pipeline.py +++ b/src/sciline/pipeline.py @@ -14,6 +14,7 @@ List, Mapping, Optional, + Set, Tuple, Type, TypeVar, @@ -39,6 +40,7 @@ from .scheduler import Scheduler from .series import Series from .typing import Graph, Item, Key, Label, get_optional, get_union +from .utils import keyname T = TypeVar('T') KeyType = TypeVar('KeyType', bound=Key) @@ -51,21 +53,68 @@ class AmbiguousProvider(Exception): """Raised when multiple providers are found for a type.""" -def _is_compatible_type_tuple( +def _extract_typevars_from_generic_type(t: type) -> Tuple[TypeVar, ...]: + """Returns the typevars that were used in the definition of a Generic type.""" + if not hasattr(t, '__orig_bases__'): + return () + return tuple( + chain(*(get_args(b) for b in t.__orig_bases__ if get_origin(b) == Generic)) + ) + + +def _find_all_typevars(t: Union[type, TypeVar]) -> Set[TypeVar]: + """Returns the set of all TypeVars in a type expression.""" + if isinstance(t, TypeVar): + return {t} + return set(chain(*map(_find_all_typevars, get_args(t)))) + + +def _find_bounds_to_make_compatible_type( + requested: Key, + provided: Key | TypeVar, +) -> Optional[Dict[TypeVar, Key]]: + """ + Check if a type is compatible to a provided type. + If the types are compatible, return a mapping from typevars to concrete types + that makes the provided type equal to the requested type. + """ + if provided == requested: + ret: Dict[TypeVar, Key] = {} + return ret + if isinstance(provided, TypeVar): + # If the type var has no constraints, accept anything + if not provided.__constraints__: + return {provided: requested} + for c in provided.__constraints__: + if _find_bounds_to_make_compatible_type(requested, c) is not None: + return {provided: requested} + if get_origin(provided) is not None: + if get_origin(provided) == get_origin(requested): + return _find_bounds_to_make_compatible_type_tuple( + get_args(requested), get_args(provided) + ) + return None + + +def _find_bounds_to_make_compatible_type_tuple( requested: tuple[Key, ...], provided: tuple[Key | TypeVar, ...], -) -> bool: +) -> Optional[Dict[TypeVar, Key]]: """ - Check if a tuple of requested types is compatible with a tuple of provided types. - - Types in the tuples must either by equal, or the provided type must be a TypeVar. + Check if a tuple of requested types is compatible with a tuple of provided types + and return a mapping from type vars to concrete types that makes all provided + types equal to their corresponding requested type. + If any of the types is not compatible, return None. """ - for req, prov in zip(requested, provided): - if isinstance(prov, TypeVar): - continue - if req != prov: - return False - return True + union: Dict[TypeVar, Key] = {} + for bound in map(_find_bounds_to_make_compatible_type, requested, provided): + # If no mapping from the type-var to a concrete type was found, + # or if the mapping is inconsistent, + # interrupt the search and report that no compatible types were found. + if bound is None or any(k in union and union[k] != bound[k] for k in bound): + return None + union.update(bound) + return union def _find_all_paths( @@ -494,6 +543,7 @@ def _get_provider( self, tp: Union[Type[T], Item[T]], handler: Optional[ErrorHandler] = None ) -> Tuple[Provider, Dict[TypeVar, Key]]: handler = handler or HandleAsBuildTimeException() + explanation: List[str] = [] if (provider := self._providers.get(tp)) is not None: return provider, {} elif (origin := get_origin(tp)) is not None and ( @@ -501,13 +551,14 @@ def _get_provider( ) is not None: requested = get_args(tp) matches = [ - (args, subprovider) + (subprovider, bound) for args, subprovider in subproviders.items() - if _is_compatible_type_tuple(requested, args) - ] - typevar_counts = [ - sum(1 for t in args if isinstance(t, TypeVar)) for args, _ in matches + if ( + bound := _find_bounds_to_make_compatible_type_tuple(requested, args) + ) + is not None ] + typevar_counts = [len(bound) for _, bound in matches] min_typevar_count = min(typevar_counts, default=0) matches = [ m @@ -516,20 +567,38 @@ def _get_provider( ] if len(matches) == 1: - args, provider = matches[0] - bound = { - arg: req - for arg, req in zip(args, requested) - if isinstance(arg, TypeVar) - } + provider, bound = matches[0] return provider, bound elif len(matches) > 1: - matching_providers = [m[1].location.name for m in matches] + matching_providers = [provider.location.name for provider, _ in matches] raise AmbiguousProvider( f"Multiple providers found for type {tp}." f" Matching providers are: {matching_providers}." ) - return handler.handle_unsatisfied_requirement(tp), {} + else: + typevars_in_expression = _extract_typevars_from_generic_type(origin) + if typevars_in_expression: + explanation = [ + ''.join( + map( + str, + ( + 'Note that ', + keyname(origin[typevars_in_expression]), + ' has constraints ', + ( + { + keyname(tv): tuple( + map(keyname, tv.__constraints__) + ) + for tv in typevars_in_expression + } + ), + ), + ) + ) + ] + return handler.handle_unsatisfied_requirement(tp, *explanation), {} def _get_unique_provider( self, tp: Union[Type[T], Item[T]], handler: ErrorHandler diff --git a/tests/pipeline_test.py b/tests/pipeline_test.py index f68a915b..91099ea6 100644 --- a/tests/pipeline_test.py +++ b/tests/pipeline_test.py @@ -1266,3 +1266,141 @@ def f(x: int): # type: ignore[no-untyped-def] with pytest.raises(ValueError, match='type-hint'): sl.Pipeline([f]) + + +def test_does_not_allow_type_argument_outside_of_constraints_flat() -> None: + T = TypeVar('T', int, float, str) + T2 = TypeVar('T2', int, float) + + @dataclass + class M(Generic[T]): + value: T + + def p1(value: T2) -> M[T2]: + return M(value) + + pipeline = sl.Pipeline((p1,)) + pipeline[str] = 'abc' + pipeline[int] = 123 + + pipeline.get(M[int]) + + with pytest.raises(sl.handler.UnsatisfiedRequirement): + pipeline.get(M[str]) + + +def test_does_not_allow_type_argument_outside_of_constraints_nested() -> None: + T = TypeVar('T', int, float, str) + + @dataclass + class M(Generic[T]): + value: T + + S = TypeVar('S', M[int], M[float], M[str]) + S2 = TypeVar('S2', M[int], M[float]) + + @dataclass + class N(Generic[S]): + value: S + + def p1(value: T) -> M[T]: + return M(value) + + def p2(value: S2) -> N[S2]: + return N(value) + + pipeline = sl.Pipeline((p1, p2)) + pipeline[str] = 'abc' + pipeline[int] = 123 + + pipeline.get(N[M[int]]) + + with pytest.raises(sl.handler.UnsatisfiedRequirement): + pipeline.get(N[M[str]]) + + +def test_constraints_nested_multiple_typevars() -> None: + T = TypeVar('T', int, float, str) + T2 = TypeVar('T2', int, float) + + @dataclass + class M(Generic[T]): + v: T + + S = TypeVar('S', M[int], M[float], M[str]) + S2 = TypeVar('S2', M[int], M[float]) + + @dataclass + class N(Generic[S, T]): + v1: S + v2: T + + def p1(v: T) -> M[T]: + return M(v) + + def p2(v1: S2, v2: T2) -> N[S2, T2]: + return N(v1, v2) + + pipeline = sl.Pipeline((p1, p2)) + pipeline[str] = 'abc' + pipeline[int] = 123 + pipeline[float] = 3.14 + + pipeline.get(N[M[float], int]) + pipeline.get(N[M[int], int]) + + with pytest.raises(sl.handler.UnsatisfiedRequirement): + pipeline.get(N[M[int], str]) + with pytest.raises(sl.handler.UnsatisfiedRequirement): + pipeline.get(N[M[str], float]) + + +def test_number_of_type_vars_defines_most_specialized() -> None: + Green = NewType('Green', str) + Blue = NewType('Blue', str) + Color = TypeVar('Color', Green, Blue) + + @dataclass + class Likes(Generic[Color]): + color: Color + + Preference = TypeVar('Preference') + + @dataclass + class Person(Generic[Preference, Color]): + preference: Preference + hatcolor: Color + provided_by: str + + def p(c: Color) -> Likes[Color]: + return Likes(c) + + def p0(p: Preference, c: Color) -> Person[Preference, Color]: + return Person(p, c, 'p0') + + def p1(c: Color) -> Person[Likes[Color], Color]: + return Person(Likes(c), c, 'p1') + + def p2(p: Preference) -> Person[Preference, Green]: + return Person(p, Green('g'), 'p2') + + pipeline = sl.Pipeline((p, p0, p1, p2)) + pipeline[Blue] = 'b' + pipeline[Green] = 'g' + + # only provided by p0 + assert pipeline.compute(Person[Likes[Green], Blue]) == Person( + Likes(Green('g')), Blue('b'), 'p0' + ) + # provided by p1 and p0 but p1 is preferred because it has fewer typevars + assert pipeline.compute(Person[Likes[Blue], Blue]) == Person( + Likes(Blue('b')), Blue('b'), 'p1' + ) + # provided by p2 and p0 but p2 is preferred because it has fewer typevars + assert pipeline.compute(Person[Likes[Blue], Green]) == Person( + Likes(Blue('b')), Green('g'), 'p2' + ) + + with pytest.raises(sl.AmbiguousProvider): + # provided by p1 and p2 with the same number of typevars + pipeline.get(Person[Likes[Green], Green])