diff --git a/hugr-py/src/hugr/hugr.py b/hugr-py/src/hugr/hugr.py index c067bc38c5..7631c5e903 100644 --- a/hugr-py/src/hugr/hugr.py +++ b/hugr-py/src/hugr/hugr.py @@ -145,7 +145,7 @@ def nodes(self) -> Iterable[tuple[Node, NodeData]]: """Iterator over nodes of the hugr and their data.""" return self.items() - def links(self) -> Iterable[tuple[OutPort, InPort]]: + def links(self) -> Iterator[tuple[OutPort, InPort]]: """Iterator over all the links in the HUGR. Returns: diff --git a/hugr-py/src/hugr/render.py b/hugr-py/src/hugr/render.py index 6022f65287..d252f28f82 100644 --- a/hugr-py/src/hugr/render.py +++ b/hugr-py/src/hugr/render.py @@ -2,12 +2,13 @@ from collections.abc import Iterable from dataclasses import dataclass +from typing import assert_never import graphviz as gv # type: ignore[import-untyped] from graphviz import Digraph from hugr.hugr import Hugr -from hugr.tys import ConstKind, Kind, OrderKind, ValueKind +from hugr.tys import CFKind, ConstKind, FunctionKind, Kind, OrderKind, ValueKind from .node_port import InPort, Node, OutPort @@ -86,7 +87,7 @@ def render(self, hugr: Hugr) -> Digraph: "margin": "0", "bgcolor": self.palette.background, } - if not (name := hugr[hugr.root].metadata.get("name")): + if not (name := hugr[hugr.root].metadata.get("name", None)): name = "" graph = gv.Digraph(name, strict=False) @@ -197,10 +198,6 @@ def _out_order_name(self, n: Node) -> str: def _viz_node(self, node: Node, hugr: Hugr, graph: Digraph) -> None: """Render a (possibly nested) node to a graphviz graph.""" - # TODO: Port the CFG special-case rendering from guppy, and use it here - # when a node is a CFG node. - # See https://github.com/CQCL/guppylang/blob/7d5106cd59ad452046d0dfffd10eea9d9b617431/guppylang/hugr_builder/visualise.py#L250 - meta = hugr[node].metadata if len(meta) > 0: data = "

" + "
".join( @@ -266,10 +263,12 @@ def _viz_link( color = self.palette.edge case OrderKind(): color = self.palette.dark - case ConstKind(): + case ConstKind() | FunctionKind(): color = self.palette.const - case _: + case CFKind(): color = self.palette.dark + case _: + assert_never(kind) graph.edge( self._out_port_name(src_port), diff --git a/hugr-py/tests/__snapshots__/test_hugr_build.ambr b/hugr-py/tests/__snapshots__/test_hugr_build.ambr index 52eb401356..dcd50978eb 100644 --- a/hugr-py/tests/__snapshots__/test_hugr_build.ambr +++ b/hugr-py/tests/__snapshots__/test_hugr_build.ambr @@ -1213,7 +1213,7 @@ > shape=plain] color="#1CADE4" label="" margin=10 } - 1:"out.0" -> 4:"in.1" [label="" arrowhead=none arrowsize=1.0 color=black fontcolor=black fontname=monospace fontsize=9 penwidth=1.5] + 1:"out.0" -> 4:"in.1" [label="" arrowhead=none arrowsize=1.0 color="#77CEEF" fontcolor=black fontname=monospace fontsize=9 penwidth=1.5] 2:"out.0" -> 4:"in.0" [label=Qubit arrowhead=none arrowsize=1.0 color="#1CADE4" fontcolor=black fontname=monospace fontsize=9 penwidth=1.5] 4:"out.0" -> 3:"in.0" [label=Qubit arrowhead=none arrowsize=1.0 color="#1CADE4" fontcolor=black fontname=monospace fontsize=9 penwidth=1.5] }