From 30c81f163336ee18a108d0a037fc8917ddcef890 Mon Sep 17 00:00:00 2001 From: Jan-Lukas Wynen Date: Wed, 14 Feb 2024 13:35:23 +0100 Subject: [PATCH 1/2] Add TaskGraph.keys --- src/sciline/task_graph.py | 17 +++++++++++++++-- tests/task_graph_test.py | 25 +++++++++++++++++++++++++ 2 files changed, 40 insertions(+), 2 deletions(-) diff --git a/src/sciline/task_graph.py b/src/sciline/task_graph.py index b2c2405e..bb60fb24 100644 --- a/src/sciline/task_graph.py +++ b/src/sciline/task_graph.py @@ -3,10 +3,10 @@ from __future__ import annotations from html import escape -from typing import Any, Optional, Sequence, Tuple, TypeVar, Union +from typing import Any, Generator, Optional, Sequence, Tuple, TypeVar, Union from .scheduler import DaskScheduler, NaiveScheduler, Scheduler -from .typing import Graph, Item +from .typing import Graph, Item, Key from .utils import keyname T = TypeVar("T") @@ -113,6 +113,19 @@ def compute( else: return self._scheduler.get(self._graph, [keys])[0] + def keys(self) -> Generator[Key, None, None]: + """Iterate over all keys of the graph. + + Yields all keys, i.e., the types of values that can be computed or are + provided as parameters. + + Returns + ------- + : + Iterable over keys. + """ + yield from self._graph.keys() + def visualize( self, **kwargs: Any ) -> graphviz.Digraph: # type: ignore[name-defined] # noqa: F821 diff --git a/tests/task_graph_test.py b/tests/task_graph_test.py index 05034f3a..f3959aec 100644 --- a/tests/task_graph_test.py +++ b/tests/task_graph_test.py @@ -1,11 +1,29 @@ # SPDX-License-Identifier: BSD-3-Clause # Copyright (c) 2023 Scipp contributors (https://github.com/scipp) +from typing import NewType, TypeVar + import pytest import sciline as sl from sciline.task_graph import TaskGraph from sciline.typing import Graph +A = NewType('A', int) +B = NewType('B', int) +T = TypeVar('T', A, B) + + +class Str(sl.Scope[T, str], str): + ... + + +def to_string(x: T) -> Str[T]: + return Str[T](str(x)) + + +def repeat(a: A, s: Str[B]) -> list[str]: + return [s] * a + def as_float(x: int) -> float: return 0.5 * x @@ -53,3 +71,10 @@ def test_compute_raises_when_provided_with_key_not_in_graph() -> None: tg.compute(str) with pytest.raises(KeyError): tg.compute((str, float)) + + +def test_keys_iter() -> None: + pl = sl.Pipeline([to_string, repeat], params={A: 3, B: 4}) + tg = pl.get(list[str]) + assert len(list(tg.keys())) == 4 # there are no duplicates + assert set(tg.keys()) == {A, B, Str[B], list[str]} From f7c67a3b12a9c7bbf939e0b8a6a3bcdebbf055c7 Mon Sep 17 00:00:00 2001 From: Jan-Lukas Wynen Date: Thu, 15 Feb 2024 09:23:33 +0100 Subject: [PATCH 2/2] Rename keys->targets --- src/sciline/pipeline.py | 2 +- src/sciline/task_graph.py | 28 ++++++++++++++-------------- tests/task_graph_test.py | 17 ++++++++++------- 3 files changed, 25 insertions(+), 22 deletions(-) diff --git a/src/sciline/pipeline.py b/src/sciline/pipeline.py index 0e883df3..98fff4bc 100644 --- a/src/sciline/pipeline.py +++ b/src/sciline/pipeline.py @@ -860,7 +860,7 @@ def get( else: graph = self.build(keys, handler=handler) # type: ignore[arg-type] return TaskGraph( - graph=graph, keys=keys, scheduler=scheduler # type: ignore[arg-type] + graph=graph, targets=keys, scheduler=scheduler # type: ignore[arg-type] ) @overload diff --git a/src/sciline/task_graph.py b/src/sciline/task_graph.py index bb60fb24..6d901eb5 100644 --- a/src/sciline/task_graph.py +++ b/src/sciline/task_graph.py @@ -70,11 +70,11 @@ def __init__( self, *, graph: Graph, - keys: Union[type, Tuple[type, ...], Item[T], Tuple[Item[T], ...]], + targets: Union[type, Tuple[type, ...], Item[T], Tuple[Item[T], ...]], scheduler: Optional[Scheduler] = None, ) -> None: self._graph = graph - self._keys = keys + self._keys = targets if scheduler is None: try: scheduler = DaskScheduler() @@ -84,7 +84,7 @@ def __init__( def compute( self, - keys: Optional[ + targets: Optional[ Union[type, Tuple[type, ...], Item[T], Tuple[Item[T], ...]] ] = None, ) -> Any: @@ -93,28 +93,28 @@ def compute( Parameters ---------- - keys: + targets: Optional list of keys to compute. This can be used to override the keys stored in the graph instance. Note that the keys must be present in the graph as intermediate results, otherwise KeyError is raised. Returns ------- - If ``keys`` is a single type, returns the single result that was computed. - If ``keys`` is a tuple of types, returns a dictionary with type as keys + If ``targets`` is a single type, returns the single result that was computed. + If ``targets`` is a tuple of types, returns a dictionary with type as keys and the corresponding results as values. - """ - if keys is None: - keys = self._keys - if isinstance(keys, tuple): - results = self._scheduler.get(self._graph, list(keys)) - return dict(zip(keys, results)) + if targets is None: + targets = self._keys + if isinstance(targets, tuple): + results = self._scheduler.get(self._graph, list(targets)) + return dict(zip(targets, results)) else: - return self._scheduler.get(self._graph, [keys])[0] + return self._scheduler.get(self._graph, [targets])[0] def keys(self) -> Generator[Key, None, None]: - """Iterate over all keys of the graph. + """ + Iterate over all keys of the graph. Yields all keys, i.e., the types of values that can be computed or are provided as parameters. diff --git a/tests/task_graph_test.py b/tests/task_graph_test.py index f3959aec..1c5fc8af 100644 --- a/tests/task_graph_test.py +++ b/tests/task_graph_test.py @@ -36,37 +36,40 @@ def make_task_graph() -> Graph: def test_default_scheduler_is_dask_when_dask_available() -> None: _ = pytest.importorskip("dask") - tg = TaskGraph(graph={}, keys=()) + tg = TaskGraph(graph={}, targets=()) assert isinstance(tg._scheduler, sl.scheduler.DaskScheduler) def test_compute_returns_value_when_initialized_with_single_key() -> None: graph = make_task_graph() - tg = TaskGraph(graph=graph, keys=float) + tg = TaskGraph(graph=graph, targets=float) assert tg.compute() == 0.5 def test_compute_returns_dict_when_initialized_with_key_tuple() -> None: graph = make_task_graph() - assert TaskGraph(graph=graph, keys=(float,)).compute() == {float: 0.5} - assert TaskGraph(graph=graph, keys=(float, int)).compute() == {float: 0.5, int: 1} + assert TaskGraph(graph=graph, targets=(float,)).compute() == {float: 0.5} + assert TaskGraph(graph=graph, targets=(float, int)).compute() == { + float: 0.5, + int: 1, + } def test_compute_returns_value_when_provided_with_single_key() -> None: graph = make_task_graph() - tg = TaskGraph(graph=graph, keys=float) + tg = TaskGraph(graph=graph, targets=float) assert tg.compute(int) == 1 def test_compute_returns_dict_when_provided_with_key_tuple() -> None: graph = make_task_graph() - tg = TaskGraph(graph=graph, keys=float) + tg = TaskGraph(graph=graph, targets=float) assert tg.compute((int, float)) == {int: 1, float: 0.5} def test_compute_raises_when_provided_with_key_not_in_graph() -> None: graph = make_task_graph() - tg = TaskGraph(graph=graph, keys=float) + tg = TaskGraph(graph=graph, targets=float) with pytest.raises(KeyError): tg.compute(str) with pytest.raises(KeyError):