Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add mode argument to visualize for more compact data or task graph display #187

Merged
merged 11 commits into from
Nov 1, 2024
36 changes: 33 additions & 3 deletions src/sciline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,15 @@
from collections.abc import Callable, Hashable, Iterable, Sequence
from itertools import chain
from types import UnionType
from typing import TYPE_CHECKING, Any, TypeVar, get_args, get_type_hints, overload
from typing import (
TYPE_CHECKING,
Any,
Literal,
TypeVar,
get_args,
get_type_hints,
overload,
)

from ._provider import Provider, ToProvider
from ._utils import key_name
Expand Down Expand Up @@ -91,7 +99,13 @@ def compute(self, tp: type | Iterable[type] | UnionType, **kwargs: Any) -> Any:
return self.get(tp, **kwargs).compute()

def visualize(
self, tp: type | Iterable[type] | None = None, **kwargs: Any
self,
tp: type | Iterable[type] | None = None,
compact: bool = False,
mode: Literal['data', 'task', 'both'] = 'data',
cluster_generics: bool = True,
cluster_color: str | None = '#f0f0ff',
**kwargs: Any,
) -> graphviz.Digraph:
"""
Return a graphviz Digraph object representing the graph for the given keys.
Expand All @@ -103,12 +117,28 @@ def visualize(
tp:
Type to visualize the graph for.
Can be a single type or an iterable of types.
compact:
If True, parameter-table-dependent branches are collapsed into a single copy
of the branch. Recommended for large graphs with long parameter tables.
mode:
If 'data', only data nodes are shown. If 'task', only task nodes and input
data nodes are shown. If 'both', all nodes are shown.
cluster_generics:
If True, generic products are grouped into clusters.
cluster_color:
Background color of clusters. If None, clusters are dotted.
kwargs:
Keyword arguments passed to :py:class:`graphviz.Digraph`.
"""
if tp is None:
tp = self.output_keys()
return self.get(tp, handler=HandleAsComputeTimeException()).visualize(**kwargs)
return self.get(tp, handler=HandleAsComputeTimeException()).visualize(
compact=compact,
mode=mode,
cluster_generics=cluster_generics,
cluster_color=cluster_color,
**kwargs,
)

def get(
self,
Expand Down
30 changes: 27 additions & 3 deletions src/sciline/task_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from collections.abc import Generator, Hashable, Sequence
from html import escape
from typing import Any, TypeVar
from typing import Any, Literal, TypeVar

from ._utils import key_name
from .scheduler import DaskScheduler, NaiveScheduler, Scheduler
Expand Down Expand Up @@ -126,18 +126,42 @@ def keys(self) -> Generator[Key, None, None]:
"""
yield from self._graph.keys()

def visualize(self, **kwargs: Any) -> graphviz.Digraph: # type: ignore[name-defined] # noqa: F821
def visualize(
self,
compact: bool = False,
mode: Literal['data', 'task', 'both'] = 'data',
cluster_generics: bool = True,
cluster_color: str | None = '#f0f0ff',
**kwargs: Any,
) -> graphviz.Digraph: # type: ignore[name-defined] # noqa: F821
"""
Return a graphviz Digraph object representing the graph.

Parameters
----------
compact:
If True, parameter-table-dependent branches are collapsed into a single copy
of the branch. Recommended for large graphs with long parameter tables.
mode:
If 'data', only data nodes are shown. If 'task', only task nodes and input
data nodes are shown. If 'both', all nodes are shown.
cluster_generics:
If True, generic products are grouped into clusters.
cluster_color:
Background color of clusters. If None, clusters are dotted.
kwargs:
Keyword arguments passed to :py:class:`graphviz.Digraph`.
"""
from .visualize import to_graphviz

return to_graphviz(self._graph, **kwargs)
return to_graphviz(
self._graph,
compact=compact,
mode=mode,
cluster_generics=cluster_generics,
cluster_color=cluster_color,
**kwargs,
)

def serialize(self) -> dict[str, Json]:
"""Serialize the graph to JSON.
Expand Down
112 changes: 95 additions & 17 deletions src/sciline/visualize.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
# SPDX-License-Identifier: BSD-3-Clause
# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
import html
from collections.abc import Hashable
from dataclasses import dataclass
from typing import Any, get_args, get_origin
from typing import Any, Literal, get_args, get_origin

import cyclebane
from graphviz import Digraph
Expand Down Expand Up @@ -31,6 +32,7 @@ class FormattedProvider:
def to_graphviz(
graph: Graph,
compact: bool = False,
mode: Literal['data', 'task', 'both'] = 'data',
cluster_generics: bool = True,
cluster_color: str | None = '#f0f0ff',
**kwargs: Any,
Expand All @@ -45,6 +47,9 @@ def to_graphviz(
compact:
If True, parameter-table-dependent branches are collapsed into a single copy
of the branch. Recommended for large graphs with long parameter tables.
mode:
If 'data', only data nodes are shown. If 'task', only task nodes and input data
nodes are shown. If 'both', all nodes are shown.
cluster_generics:
If True, generic products are grouped into clusters.
cluster_color:
Expand All @@ -53,6 +58,28 @@ def to_graphviz(
Keyword arguments passed to :py:class:`graphviz.Digraph`.
"""
dot = Digraph(strict=True, **kwargs)
if dot.graph_attr.get('rankdir', 'TB') == 'LR':
# Significant horizontal space helps distinguishing edges
dot.graph_attr['ranksep'] = '1'
# Little vertical space
dot.graph_attr['nodesep'] = '0.05'
# Avoiding edges connecting to top/bottom reduces edge clutter in larger graphs
dot.edge_attr['tailport'] = 'e'
dot.edge_attr['headport'] = 'w'
else:
dot.graph_attr['ranksep'] = '0.5'
dot.graph_attr['nodesep'] = '0.1'
# With tailport='s' we get more curved edges, so we omit it. In larger graphs
# this still seems to happen though, may need revisiting.
# Nodes are wide in west-east direction, so *not* connecting to headport='n'
# looks better
dot.node_attr.update({'height': '0', 'width': '0'})
# Ensure user can override defaults
dot.node_attr.update(kwargs.get('node_attr', {}))
dot.edge_attr.update(kwargs.get('edge_attr', {}))
dot.graph_attr.update(kwargs.get('graph_attr', {}))
# Compound is required for connecting edges to clusters
dot.graph_attr['compound'] = 'true'
formatted_graph = _format_graph(graph, compact=compact)
ordered_graph = dict(
sorted(formatted_graph.items(), key=lambda item: item[1].ret.name)
Expand All @@ -69,7 +96,13 @@ def to_graphviz(
dot_subgraph.attr(style='dotted')
else:
dot_subgraph.attr(style='filled', color=cluster_color)
_add_subgraph(subgraph, dot, dot_subgraph)
# For keys such as MyType[int] we show MyType only once as the cluster
# label. The nodes within the cluster will only show to bit inside [].
# This save a lot of horizontal space in the graph in LR mode and
# duplication and clutter in general.
origin = next(iter(subgraph.values())).ret.name.split('[')[0]
dot_subgraph.attr(label=f'{origin}')
_add_subgraph(subgraph, dot, dot_subgraph, mode=mode)
return dot


Expand All @@ -82,28 +115,73 @@ def _to_subgraphs(graph: FormattedGraph) -> dict[str, FormattedGraph]:
return subgraphs


def _add_subgraph(graph: FormattedGraph, dot: Digraph, subgraph: Digraph) -> None:
def _add_subgraph(
graph: FormattedGraph,
dot: Digraph,
subgraph: Digraph,
mode: Literal['data', 'task', 'both'],
) -> None:
cluster = subgraph.name is not None
cluster_connected = []
common_provider = len(graph) > 1 and len({v.name for v in graph.values()}) == 1
for p, formatted_p in graph.items():
ret_name = formatted_p.ret.name
if cluster:
# Remove the origin from the name if we are in a cluster, as it is shown
# as the cluster label
split = ret_name[ret_name.index('[') :]
# The nodes within the cluster use slightly smaller text.
name = f'<<font point-size="12">{split}</font>>'
else:
name = f'<{ret_name}>'
if mode == 'data' and formatted_p.kind == 'function':
# Show provider name in data mode
via_name = html.escape(formatted_p.name)
via = f'<font point-size="11">via:<i>{via_name}</i></font>'
if common_provider:
origin = ret_name.split('[')[0]
subgraph.attr(label=f'<{origin}<br/>{via}>')
else:
name = f'{name[:-1]}<br/>{via}>'
shape = 'box3d' if formatted_p.ret.collapsed else 'rectangle'
if formatted_p.kind == 'unsatisfied':
subgraph.node(
formatted_p.ret.name,
formatted_p.ret.name,
shape='box3d' if formatted_p.ret.collapsed else 'rectangle',
ret_name,
name,
shape=shape,
color='red',
fontcolor='red', # Set text color to red
fontcolor='red',
style='dashed',
)
else:
subgraph.node(
formatted_p.ret.name,
formatted_p.ret.name,
shape='box3d' if formatted_p.ret.collapsed else 'rectangle',
)
elif mode != 'task' or formatted_p.kind == 'parameter':
subgraph.node(ret_name, name, shape=shape)
if formatted_p.kind == 'function':
dot.node(p, formatted_p.name, shape='ellipse')
for arg in formatted_p.args:
dot.edge(arg.name, p)
dot.edge(p, formatted_p.ret.name)
if mode == 'both':
dot.node(p, formatted_p.name, shape='ellipse')
for arg in formatted_p.args:
dot.edge(arg.name, p)
dot.edge(p, ret_name)
elif mode == 'task':
p = ret_name
dot.node(p, formatted_p.name, shape='ellipse')
for arg in formatted_p.args:
dot.edge(arg.name, p)
elif mode == 'data':
for arg in formatted_p.args:
if cluster and common_provider and '[' not in arg.name:
# Avoid duplicate arrows to subnodes if all providers are the
# same and the argument is not a generic
if arg.name not in cluster_connected:
dot.edge(
arg.name,
ret_name,
lhead=subgraph.name,
# Thick pen to indicate multiple connections
penwidth='2.0',
)
cluster_connected.append(arg.name)
else:
dot.edge(arg.name, ret_name)
# else: Do not draw dummy providers created by Pipeline when setting instances
SimonHeybrock marked this conversation as resolved.
Show resolved Hide resolved

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think the curved arrows are always better. Comparing the new to old figure, while there are less boxes and duplication in the new figure, I find it harder to follow the arrows.

New:
Screenshot at 2024-10-30 09-25-48

Old:
Screenshot at 2024-10-30 09-26-06

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, why are they curved now?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Avoided some, see update.


SimonHeybrock marked this conversation as resolved.
Show resolved Hide resolved
Expand Down