diff --git a/hugr-py/src/hugr/hugr.py b/hugr-py/src/hugr/hugr.py
index c067bc38c5..7631c5e903 100644
--- a/hugr-py/src/hugr/hugr.py
+++ b/hugr-py/src/hugr/hugr.py
@@ -145,7 +145,7 @@ def nodes(self) -> Iterable[tuple[Node, NodeData]]:
"""Iterator over nodes of the hugr and their data."""
return self.items()
- def links(self) -> Iterable[tuple[OutPort, InPort]]:
+ def links(self) -> Iterator[tuple[OutPort, InPort]]:
"""Iterator over all the links in the HUGR.
Returns:
diff --git a/hugr-py/src/hugr/render.py b/hugr-py/src/hugr/render.py
index 6022f65287..d252f28f82 100644
--- a/hugr-py/src/hugr/render.py
+++ b/hugr-py/src/hugr/render.py
@@ -2,12 +2,13 @@
from collections.abc import Iterable
from dataclasses import dataclass
+from typing import assert_never
import graphviz as gv # type: ignore[import-untyped]
from graphviz import Digraph
from hugr.hugr import Hugr
-from hugr.tys import ConstKind, Kind, OrderKind, ValueKind
+from hugr.tys import CFKind, ConstKind, FunctionKind, Kind, OrderKind, ValueKind
from .node_port import InPort, Node, OutPort
@@ -86,7 +87,7 @@ def render(self, hugr: Hugr) -> Digraph:
"margin": "0",
"bgcolor": self.palette.background,
}
- if not (name := hugr[hugr.root].metadata.get("name")):
+ if not (name := hugr[hugr.root].metadata.get("name", None)):
name = ""
graph = gv.Digraph(name, strict=False)
@@ -197,10 +198,6 @@ def _out_order_name(self, n: Node) -> str:
def _viz_node(self, node: Node, hugr: Hugr, graph: Digraph) -> None:
"""Render a (possibly nested) node to a graphviz graph."""
- # TODO: Port the CFG special-case rendering from guppy, and use it here
- # when a node is a CFG node.
- # See https://github.com/CQCL/guppylang/blob/7d5106cd59ad452046d0dfffd10eea9d9b617431/guppylang/hugr_builder/visualise.py#L250
-
meta = hugr[node].metadata
if len(meta) > 0:
data = "
" + "
".join(
@@ -266,10 +263,12 @@ def _viz_link(
color = self.palette.edge
case OrderKind():
color = self.palette.dark
- case ConstKind():
+ case ConstKind() | FunctionKind():
color = self.palette.const
- case _:
+ case CFKind():
color = self.palette.dark
+ case _:
+ assert_never(kind)
graph.edge(
self._out_port_name(src_port),
diff --git a/hugr-py/tests/__snapshots__/test_hugr_build.ambr b/hugr-py/tests/__snapshots__/test_hugr_build.ambr
index 52eb401356..dcd50978eb 100644
--- a/hugr-py/tests/__snapshots__/test_hugr_build.ambr
+++ b/hugr-py/tests/__snapshots__/test_hugr_build.ambr
@@ -1213,7 +1213,7 @@
> shape=plain]
color="#1CADE4" label="" margin=10
}
- 1:"out.0" -> 4:"in.1" [label="" arrowhead=none arrowsize=1.0 color=black fontcolor=black fontname=monospace fontsize=9 penwidth=1.5]
+ 1:"out.0" -> 4:"in.1" [label="" arrowhead=none arrowsize=1.0 color="#77CEEF" fontcolor=black fontname=monospace fontsize=9 penwidth=1.5]
2:"out.0" -> 4:"in.0" [label=Qubit arrowhead=none arrowsize=1.0 color="#1CADE4" fontcolor=black fontname=monospace fontsize=9 penwidth=1.5]
4:"out.0" -> 3:"in.0" [label=Qubit arrowhead=none arrowsize=1.0 color="#1CADE4" fontcolor=black fontname=monospace fontsize=9 penwidth=1.5]
}