diff --git a/src/sciline/pipeline.py b/src/sciline/pipeline.py index 4ed76328..cbd40a8e 100644 --- a/src/sciline/pipeline.py +++ b/src/sciline/pipeline.py @@ -3,6 +3,7 @@ from __future__ import annotations from collections import defaultdict +from collections.abc import Iterable from itertools import chain from typing import ( Any, @@ -10,7 +11,6 @@ Collection, Dict, Generic, - Iterable, List, Mapping, Optional, @@ -116,6 +116,22 @@ def _find_nodes_in_paths( return list(nodes) +def _is_multiple_keys( + keys: type | Iterable[type] | Item[T], +) -> bool: + # Cannot simply use isinstance(keys, Iterable) because that is True for + # generic aliases of iterable types, e.g., + # + # class Str(sl.Scope[Param, str], str): ... + # keys = Str[int] + # + # And isinstance(keys, type) does not work on its own because + # it is False for the above type. + return ( + not isinstance(keys, type) and not get_args(keys) and isinstance(keys, Iterable) + ) + + def provide_none() -> None: return None @@ -602,14 +618,14 @@ def compute(self, tp: Type[T]) -> T: ... @overload - def compute(self, tp: Tuple[Type[T], ...]) -> Dict[Type[T], T]: + def compute(self, tp: Iterable[Type[T]]) -> Dict[Type[T], T]: ... @overload def compute(self, tp: Item[T]) -> T: ... - def compute(self, tp: type | Tuple[type, ...] | Item[T]) -> Any: + def compute(self, tp: type | Iterable[type] | Item[T]) -> Any: """ Compute result for the given keys. @@ -618,12 +634,13 @@ def compute(self, tp: type | Tuple[type, ...] | Item[T]) -> Any: Parameters ---------- tp: - Type to compute the result for. Can be a single type or a tuple of types. + Type to compute the result for. + Can be a single type or an iterable of types. """ return self.get(tp).compute() def visualize( - self, tp: type | Tuple[type, ...], **kwargs: Any + self, tp: type | Iterable[type], **kwargs: Any ) -> graphviz.Digraph: # type: ignore[name-defined] # noqa: F821 """ Return a graphviz Digraph object representing the graph for the given keys. @@ -633,7 +650,8 @@ def visualize( Parameters ---------- tp: - Type to visualize the graph for. Can be a single type or a tuple of types. + Type to visualize the graph for. + Can be a single type or an iterable of types. kwargs: Keyword arguments passed to :py:class:`graphviz.Digraph`. """ @@ -641,7 +659,7 @@ def visualize( def get( self, - keys: type | Tuple[type, ...] | Item[T], + keys: type | Iterable[type] | Item[T], *, scheduler: Optional[Scheduler] = None, ) -> TaskGraph: @@ -651,19 +669,23 @@ def get( Parameters ---------- keys: - Type to compute the result for. Can be a single type or a tuple of types. + Type to compute the result for. + Can be a single type or an iterable of types. scheduler: Optional scheduler to use for computing the result. If not given, a :py:class:`NaiveScheduler` is used if `dask` is not installed, otherwise dask's threaded scheduler is used. """ - if isinstance(keys, tuple): + if _is_multiple_keys(keys): + keys = tuple(keys) # type: ignore[arg-type] graph: Graph = {} for t in keys: graph.update(self.build(t)) else: - graph = self.build(keys) - return TaskGraph(graph=graph, keys=keys, scheduler=scheduler) + graph = self.build(keys) # type: ignore[arg-type] + return TaskGraph( + graph=graph, keys=keys, scheduler=scheduler # type: ignore[arg-type] + ) @overload def bind_and_call(self, fns: Callable[..., T], /) -> T: diff --git a/tests/pipeline_test.py b/tests/pipeline_test.py index afa7dc0e..9c7155f2 100644 --- a/tests/pipeline_test.py +++ b/tests/pipeline_test.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: BSD-3-Clause # Copyright (c) 2023 Scipp contributors (https://github.com/scipp) from dataclasses import dataclass -from typing import Generic, List, NewType, TypeVar +from typing import Any, Callable, Generic, List, NewType, TypeVar import numpy as np import numpy.typing as npt @@ -638,9 +638,12 @@ def test_get_with_single_key_return_task_graph_that_computes_value() -> None: assert task.compute() == '3;1.5' -def test_get_with_key_tuple_return_task_graph_that_computes_dict_of_values() -> None: +@pytest.mark.parametrize('key_type', [tuple, list, iter]) +def test_get_with_key_iterable_return_task_graph_that_computes_dict_of_values( + key_type: Callable[[Any], Any], +) -> None: pipeline = sl.Pipeline([int_to_float, make_int]) - task = pipeline.get((float, int)) + task = pipeline.get(key_type((float, int))) assert task.compute() == {float: 1.5, int: 3}