diff --git a/src/sciline/pipeline.py b/src/sciline/pipeline.py
index b29bea55..3107caa0 100644
--- a/src/sciline/pipeline.py
+++ b/src/sciline/pipeline.py
@@ -5,7 +5,15 @@
from collections.abc import Callable, Hashable, Iterable, Sequence
from itertools import chain
from types import UnionType
-from typing import TYPE_CHECKING, Any, TypeVar, get_args, get_type_hints, overload
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Literal,
+ TypeVar,
+ get_args,
+ get_type_hints,
+ overload,
+)
from ._provider import Provider, ToProvider
from ._utils import key_name
@@ -91,7 +99,13 @@ def compute(self, tp: type | Iterable[type] | UnionType, **kwargs: Any) -> Any:
return self.get(tp, **kwargs).compute()
def visualize(
- self, tp: type | Iterable[type] | None = None, **kwargs: Any
+ self,
+ tp: type | Iterable[type] | None = None,
+ compact: bool = False,
+ mode: Literal['data', 'task', 'both'] = 'data',
+ cluster_generics: bool = True,
+ cluster_color: str | None = '#f0f0ff',
+ **kwargs: Any,
) -> graphviz.Digraph:
"""
Return a graphviz Digraph object representing the graph for the given keys.
@@ -103,12 +117,28 @@ def visualize(
tp:
Type to visualize the graph for.
Can be a single type or an iterable of types.
+ compact:
+ If True, parameter-table-dependent branches are collapsed into a single copy
+ of the branch. Recommended for large graphs with long parameter tables.
+ mode:
+ If 'data', only data nodes are shown. If 'task', only task nodes and input
+ data nodes are shown. If 'both', all nodes are shown.
+ cluster_generics:
+ If True, generic products are grouped into clusters.
+ cluster_color:
+ Background color of clusters. If None, clusters are dotted.
kwargs:
Keyword arguments passed to :py:class:`graphviz.Digraph`.
"""
if tp is None:
tp = self.output_keys()
- return self.get(tp, handler=HandleAsComputeTimeException()).visualize(**kwargs)
+ return self.get(tp, handler=HandleAsComputeTimeException()).visualize(
+ compact=compact,
+ mode=mode,
+ cluster_generics=cluster_generics,
+ cluster_color=cluster_color,
+ **kwargs,
+ )
def get(
self,
diff --git a/src/sciline/task_graph.py b/src/sciline/task_graph.py
index 51f51ff5..27df1c92 100644
--- a/src/sciline/task_graph.py
+++ b/src/sciline/task_graph.py
@@ -4,7 +4,7 @@
from collections.abc import Generator, Hashable, Sequence
from html import escape
-from typing import Any, TypeVar
+from typing import Any, Literal, TypeVar
from ._utils import key_name
from .scheduler import DaskScheduler, NaiveScheduler, Scheduler
@@ -126,18 +126,42 @@ def keys(self) -> Generator[Key, None, None]:
"""
yield from self._graph.keys()
- def visualize(self, **kwargs: Any) -> graphviz.Digraph: # type: ignore[name-defined] # noqa: F821
+ def visualize(
+ self,
+ compact: bool = False,
+ mode: Literal['data', 'task', 'both'] = 'data',
+ cluster_generics: bool = True,
+ cluster_color: str | None = '#f0f0ff',
+ **kwargs: Any,
+ ) -> graphviz.Digraph: # type: ignore[name-defined] # noqa: F821
"""
Return a graphviz Digraph object representing the graph.
Parameters
----------
+ compact:
+ If True, parameter-table-dependent branches are collapsed into a single copy
+ of the branch. Recommended for large graphs with long parameter tables.
+ mode:
+ If 'data', only data nodes are shown. If 'task', only task nodes and input
+ data nodes are shown. If 'both', all nodes are shown.
+ cluster_generics:
+ If True, generic products are grouped into clusters.
+ cluster_color:
+ Background color of clusters. If None, clusters are dotted.
kwargs:
Keyword arguments passed to :py:class:`graphviz.Digraph`.
"""
from .visualize import to_graphviz
- return to_graphviz(self._graph, **kwargs)
+ return to_graphviz(
+ self._graph,
+ compact=compact,
+ mode=mode,
+ cluster_generics=cluster_generics,
+ cluster_color=cluster_color,
+ **kwargs,
+ )
def serialize(self) -> dict[str, Json]:
"""Serialize the graph to JSON.
diff --git a/src/sciline/visualize.py b/src/sciline/visualize.py
index 0e78c307..8055a8ef 100644
--- a/src/sciline/visualize.py
+++ b/src/sciline/visualize.py
@@ -1,8 +1,9 @@
# SPDX-License-Identifier: BSD-3-Clause
# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
+import html
from collections.abc import Hashable
from dataclasses import dataclass
-from typing import Any, get_args, get_origin
+from typing import Any, Literal, get_args, get_origin
import cyclebane
from graphviz import Digraph
@@ -31,6 +32,7 @@ class FormattedProvider:
def to_graphviz(
graph: Graph,
compact: bool = False,
+ mode: Literal['data', 'task', 'both'] = 'data',
cluster_generics: bool = True,
cluster_color: str | None = '#f0f0ff',
**kwargs: Any,
@@ -45,6 +47,9 @@ def to_graphviz(
compact:
If True, parameter-table-dependent branches are collapsed into a single copy
of the branch. Recommended for large graphs with long parameter tables.
+ mode:
+ If 'data', only data nodes are shown. If 'task', only task nodes and input data
+ nodes are shown. If 'both', all nodes are shown.
cluster_generics:
If True, generic products are grouped into clusters.
cluster_color:
@@ -53,6 +58,28 @@ def to_graphviz(
Keyword arguments passed to :py:class:`graphviz.Digraph`.
"""
dot = Digraph(strict=True, **kwargs)
+ if dot.graph_attr.get('rankdir', 'TB') == 'LR':
+ # Significant horizontal space helps distinguish edges
+ dot.graph_attr['ranksep'] = '1'
+ # Little vertical space
+ dot.graph_attr['nodesep'] = '0.05'
+ # Avoiding edges connecting to top/bottom reduces edge clutter in larger graphs
+ dot.edge_attr['tailport'] = 'e'
+ dot.edge_attr['headport'] = 'w'
+ else:
+ dot.graph_attr['ranksep'] = '0.5'
+ dot.graph_attr['nodesep'] = '0.1'
+ # With tailport='s' we get more curved edges, so we omit it. In larger graphs
+ # this still seems to happen though, may need revisiting.
+ # Nodes are wide in west-east direction, so *not* connecting to headport='n'
+ # looks better
+ dot.node_attr.update({'height': '0', 'width': '0'})
+ # Ensure user can override defaults
+ dot.node_attr.update(kwargs.get('node_attr', {}))
+ dot.edge_attr.update(kwargs.get('edge_attr', {}))
+ dot.graph_attr.update(kwargs.get('graph_attr', {}))
+ # Compound is required for connecting edges to clusters
+ dot.graph_attr['compound'] = 'true'
formatted_graph = _format_graph(graph, compact=compact)
ordered_graph = dict(
sorted(formatted_graph.items(), key=lambda item: item[1].ret.name)
@@ -69,7 +96,13 @@ def to_graphviz(
dot_subgraph.attr(style='dotted')
else:
dot_subgraph.attr(style='filled', color=cluster_color)
- _add_subgraph(subgraph, dot, dot_subgraph)
+ # For keys such as MyType[int] we show MyType only once as the cluster
+ # label. The nodes within the cluster will only show the bit inside [].
+ # This saves a lot of horizontal space in the graph in LR mode and
+ # duplication and clutter in general.
+ origin = next(iter(subgraph.values())).ret.name.split('[')[0]
+ dot_subgraph.attr(label=f'{origin}')
+ _add_subgraph(subgraph, dot, dot_subgraph, mode=mode)
return dot
@@ -82,28 +115,73 @@ def _to_subgraphs(graph: FormattedGraph) -> dict[str, FormattedGraph]:
return subgraphs
-def _add_subgraph(graph: FormattedGraph, dot: Digraph, subgraph: Digraph) -> None:
+def _add_subgraph(
+ graph: FormattedGraph,
+ dot: Digraph,
+ subgraph: Digraph,
+ mode: Literal['data', 'task', 'both'],
+) -> None:
+ cluster = subgraph.name is not None
+ cluster_connected = []
+ common_provider = len(graph) > 1 and len({v.name for v in graph.values()}) == 1
for p, formatted_p in graph.items():
+ ret_name = formatted_p.ret.name
+ if cluster:
+ # Remove the origin from the name if we are in a cluster, as it is shown
+ # as the cluster label
+ split = ret_name[ret_name.index('[') :]
+ # The nodes within the cluster use slightly smaller text.
+ name = f'<{split}>'
+ else:
+ name = f'<{ret_name}>'
+ if mode == 'data' and formatted_p.kind == 'function':
+ # Show provider name in data mode
+ via_name = html.escape(formatted_p.name)
+ via = f'via:{via_name}'
+ if common_provider:
+ origin = ret_name.split('[')[0]
+ subgraph.attr(label=f'<{origin}<BR/>{via}>')
+ else:
+ name = f'{name[:-1]}<BR/>{via}>'
+ shape = 'box3d' if formatted_p.ret.collapsed else 'rectangle'
if formatted_p.kind == 'unsatisfied':
subgraph.node(
- formatted_p.ret.name,
- formatted_p.ret.name,
- shape='box3d' if formatted_p.ret.collapsed else 'rectangle',
+ ret_name,
+ name,
+ shape=shape,
color='red',
- fontcolor='red', # Set text color to red
+ fontcolor='red',
style='dashed',
)
- else:
- subgraph.node(
- formatted_p.ret.name,
- formatted_p.ret.name,
- shape='box3d' if formatted_p.ret.collapsed else 'rectangle',
- )
+ elif mode != 'task' or formatted_p.kind == 'parameter':
+ subgraph.node(ret_name, name, shape=shape)
if formatted_p.kind == 'function':
- dot.node(p, formatted_p.name, shape='ellipse')
- for arg in formatted_p.args:
- dot.edge(arg.name, p)
- dot.edge(p, formatted_p.ret.name)
+ if mode == 'both':
+ dot.node(p, formatted_p.name, shape='ellipse')
+ for arg in formatted_p.args:
+ dot.edge(arg.name, p)
+ dot.edge(p, ret_name)
+ elif mode == 'task':
+ p = ret_name
+ dot.node(p, formatted_p.name, shape='ellipse')
+ for arg in formatted_p.args:
+ dot.edge(arg.name, p)
+ elif mode == 'data':
+ for arg in formatted_p.args:
+ if cluster and common_provider and '[' not in arg.name:
+ # Avoid duplicate arrows to subnodes if all providers are the
+ # same and the argument is not a generic
+ if arg.name not in cluster_connected:
+ dot.edge(
+ arg.name,
+ ret_name,
+ lhead=subgraph.name,
+ # Thick pen to indicate multiple connections
+ penwidth='2.0',
+ )
+ cluster_connected.append(arg.name)
+ else:
+ dot.edge(arg.name, ret_name)
# else: Do not draw dummy providers created by Pipeline when setting instances