diff --git a/src/sciline/pipeline.py b/src/sciline/pipeline.py index b29bea55..3107caa0 100644 --- a/src/sciline/pipeline.py +++ b/src/sciline/pipeline.py @@ -5,7 +5,15 @@ from collections.abc import Callable, Hashable, Iterable, Sequence from itertools import chain from types import UnionType -from typing import TYPE_CHECKING, Any, TypeVar, get_args, get_type_hints, overload +from typing import ( + TYPE_CHECKING, + Any, + Literal, + TypeVar, + get_args, + get_type_hints, + overload, +) from ._provider import Provider, ToProvider from ._utils import key_name @@ -91,7 +99,13 @@ def compute(self, tp: type | Iterable[type] | UnionType, **kwargs: Any) -> Any: return self.get(tp, **kwargs).compute() def visualize( - self, tp: type | Iterable[type] | None = None, **kwargs: Any + self, + tp: type | Iterable[type] | None = None, + compact: bool = False, + mode: Literal['data', 'task', 'both'] = 'data', + cluster_generics: bool = True, + cluster_color: str | None = '#f0f0ff', + **kwargs: Any, ) -> graphviz.Digraph: """ Return a graphviz Digraph object representing the graph for the given keys. @@ -103,12 +117,28 @@ def visualize( tp: Type to visualize the graph for. Can be a single type or an iterable of types. + compact: + If True, parameter-table-dependent branches are collapsed into a single copy + of the branch. Recommended for large graphs with long parameter tables. + mode: + If 'data', only data nodes are shown. If 'task', only task nodes and input + data nodes are shown. If 'both', all nodes are shown. + cluster_generics: + If True, generic products are grouped into clusters. + cluster_color: + Background color of clusters. If None, clusters are dotted. kwargs: Keyword arguments passed to :py:class:`graphviz.Digraph`. 
""" if tp is None: tp = self.output_keys() - return self.get(tp, handler=HandleAsComputeTimeException()).visualize(**kwargs) + return self.get(tp, handler=HandleAsComputeTimeException()).visualize( + compact=compact, + mode=mode, + cluster_generics=cluster_generics, + cluster_color=cluster_color, + **kwargs, + ) def get( self, diff --git a/src/sciline/task_graph.py b/src/sciline/task_graph.py index 51f51ff5..27df1c92 100644 --- a/src/sciline/task_graph.py +++ b/src/sciline/task_graph.py @@ -4,7 +4,7 @@ from collections.abc import Generator, Hashable, Sequence from html import escape -from typing import Any, TypeVar +from typing import Any, Literal, TypeVar from ._utils import key_name from .scheduler import DaskScheduler, NaiveScheduler, Scheduler @@ -126,18 +126,42 @@ def keys(self) -> Generator[Key, None, None]: """ yield from self._graph.keys() - def visualize(self, **kwargs: Any) -> graphviz.Digraph: # type: ignore[name-defined] # noqa: F821 + def visualize( + self, + compact: bool = False, + mode: Literal['data', 'task', 'both'] = 'data', + cluster_generics: bool = True, + cluster_color: str | None = '#f0f0ff', + **kwargs: Any, + ) -> graphviz.Digraph: # type: ignore[name-defined] # noqa: F821 """ Return a graphviz Digraph object representing the graph. Parameters ---------- + compact: + If True, parameter-table-dependent branches are collapsed into a single copy + of the branch. Recommended for large graphs with long parameter tables. + mode: + If 'data', only data nodes are shown. If 'task', only task nodes and input + data nodes are shown. If 'both', all nodes are shown. + cluster_generics: + If True, generic products are grouped into clusters. + cluster_color: + Background color of clusters. If None, clusters are dotted. kwargs: Keyword arguments passed to :py:class:`graphviz.Digraph`. 
""" from .visualize import to_graphviz - return to_graphviz(self._graph, **kwargs) + return to_graphviz( + self._graph, + compact=compact, + mode=mode, + cluster_generics=cluster_generics, + cluster_color=cluster_color, + **kwargs, + ) def serialize(self) -> dict[str, Json]: """Serialize the graph to JSON. diff --git a/src/sciline/visualize.py b/src/sciline/visualize.py index 0e78c307..8055a8ef 100644 --- a/src/sciline/visualize.py +++ b/src/sciline/visualize.py @@ -1,8 +1,9 @@ # SPDX-License-Identifier: BSD-3-Clause # Copyright (c) 2023 Scipp contributors (https://github.com/scipp) +import html from collections.abc import Hashable from dataclasses import dataclass -from typing import Any, get_args, get_origin +from typing import Any, Literal, get_args, get_origin import cyclebane from graphviz import Digraph @@ -31,6 +32,7 @@ class FormattedProvider: def to_graphviz( graph: Graph, compact: bool = False, + mode: Literal['data', 'task', 'both'] = 'data', cluster_generics: bool = True, cluster_color: str | None = '#f0f0ff', **kwargs: Any, @@ -45,6 +47,9 @@ def to_graphviz( compact: If True, parameter-table-dependent branches are collapsed into a single copy of the branch. Recommended for large graphs with long parameter tables. + mode: + If 'data', only data nodes are shown. If 'task', only task nodes and input data + nodes are shown. If 'both', all nodes are shown. cluster_generics: If True, generic products are grouped into clusters. cluster_color: @@ -53,6 +58,28 @@ def to_graphviz( Keyword arguments passed to :py:class:`graphviz.Digraph`. 
""" dot = Digraph(strict=True, **kwargs) + if dot.graph_attr.get('rankdir', 'TB') == 'LR': + # Significant horizontal space helps distinguish edges + dot.graph_attr['ranksep'] = '1' + # Little vertical space + dot.graph_attr['nodesep'] = '0.05' + # Avoiding edges connecting to top/bottom reduces edge clutter in larger graphs + dot.edge_attr['tailport'] = 'e' + dot.edge_attr['headport'] = 'w' + else: + dot.graph_attr['ranksep'] = '0.5' + dot.graph_attr['nodesep'] = '0.1' + # With tailport='s' we get more curved edges, so we omit it. In larger graphs + # this still seems to happen though, may need revisiting. + # Nodes are wide in west-east direction, so *not* connecting to headport='n' + # looks better + dot.node_attr.update({'height': '0', 'width': '0'}) + # Ensure user can override defaults + dot.node_attr.update(kwargs.get('node_attr', {})) + dot.edge_attr.update(kwargs.get('edge_attr', {})) + dot.graph_attr.update(kwargs.get('graph_attr', {})) + # Compound is required for connecting edges to clusters + dot.graph_attr['compound'] = 'true' formatted_graph = _format_graph(graph, compact=compact) ordered_graph = dict( sorted(formatted_graph.items(), key=lambda item: item[1].ret.name) @@ -69,7 +96,13 @@ def to_graphviz( dot_subgraph.attr(style='dotted') else: dot_subgraph.attr(style='filled', color=cluster_color) - _add_subgraph(subgraph, dot, dot_subgraph) + # For keys such as MyType[int] we show MyType only once as the cluster + # label. The nodes within the cluster will only show the bit inside []. + # This saves a lot of horizontal space in the graph in LR mode and + # duplication and clutter in general.
+ origin = next(iter(subgraph.values())).ret.name.split('[')[0] + dot_subgraph.attr(label=f'{origin}') + _add_subgraph(subgraph, dot, dot_subgraph, mode=mode) return dot @@ -82,28 +115,73 @@ def _to_subgraphs(graph: FormattedGraph) -> dict[str, FormattedGraph]: return subgraphs -def _add_subgraph(graph: FormattedGraph, dot: Digraph, subgraph: Digraph) -> None: +def _add_subgraph( + graph: FormattedGraph, + dot: Digraph, + subgraph: Digraph, + mode: Literal['data', 'task', 'both'], +) -> None: + cluster = subgraph.name is not None + cluster_connected = [] + common_provider = len(graph) > 1 and len({v.name for v in graph.values()}) == 1 for p, formatted_p in graph.items(): + ret_name = formatted_p.ret.name + if cluster: + # Remove the origin from the name if we are in a cluster, as it is shown + # as the cluster label + split = ret_name[ret_name.index('[') :] + # The nodes within the cluster use slightly smaller text. + name = f'<{split}>' + else: + name = f'<{ret_name}>' + if mode == 'data' and formatted_p.kind == 'function': + # Show provider name in data mode + via_name = html.escape(formatted_p.name) + via = f'via:{via_name}' + if common_provider: + origin = ret_name.split('[')[0] + subgraph.attr(label=f'<{origin}
{via}>') + else: + name = f'{name[:-1]}
{via}>' + shape = 'box3d' if formatted_p.ret.collapsed else 'rectangle' if formatted_p.kind == 'unsatisfied': subgraph.node( - formatted_p.ret.name, - formatted_p.ret.name, - shape='box3d' if formatted_p.ret.collapsed else 'rectangle', + ret_name, + name, + shape=shape, color='red', - fontcolor='red', # Set text color to red + fontcolor='red', style='dashed', ) - else: - subgraph.node( - formatted_p.ret.name, - formatted_p.ret.name, - shape='box3d' if formatted_p.ret.collapsed else 'rectangle', - ) + elif mode != 'task' or formatted_p.kind == 'parameter': + subgraph.node(ret_name, name, shape=shape) if formatted_p.kind == 'function': - dot.node(p, formatted_p.name, shape='ellipse') - for arg in formatted_p.args: - dot.edge(arg.name, p) - dot.edge(p, formatted_p.ret.name) + if mode == 'both': + dot.node(p, formatted_p.name, shape='ellipse') + for arg in formatted_p.args: + dot.edge(arg.name, p) + dot.edge(p, ret_name) + elif mode == 'task': + p = ret_name + dot.node(p, formatted_p.name, shape='ellipse') + for arg in formatted_p.args: + dot.edge(arg.name, p) + elif mode == 'data': + for arg in formatted_p.args: + if cluster and common_provider and '[' not in arg.name: + # Avoid duplicate arrows to subnodes if all providers are the + # same and the argument is not a generic + if arg.name not in cluster_connected: + dot.edge( + arg.name, + ret_name, + lhead=subgraph.name, + # Thick pen to indicate multiple connections + penwidth='2.0', + ) + cluster_connected.append(arg.name) + else: + dot.edge(arg.name, ret_name) # else: Do not draw dummy providers created by Pipeline when setting instances