diff --git a/hugr-py/src/hugr/dfg.py b/hugr-py/src/hugr/dfg.py index 8fa9e0082..76ca42b2a 100644 --- a/hugr-py/src/hugr/dfg.py +++ b/hugr-py/src/hugr/dfg.py @@ -6,6 +6,7 @@ from dataclasses import dataclass, field, replace from typing import ( TYPE_CHECKING, + Any, Generic, TypeVar, ) @@ -170,12 +171,15 @@ def inputs(self) -> list[OutPort]: """ return [self.input_node.out(i) for i in range(len(self._input_op().types))] - def add_op(self, op: ops.DataflowOp, /, *args: Wire) -> Node: + def add_op( + self, op: ops.DataflowOp, /, *args: Wire, metadata: dict[str, Any] | None = None + ) -> Node: """Add a dataflow operation to the graph, wiring in input ports. Args: op: The operation to add. args: The input wires to the operation. + metadata: Metadata to attach to the function definition. Defaults to None. Returns: The node holding the new operation. @@ -185,17 +189,18 @@ def add_op(self, op: ops.DataflowOp, /, *args: Wire) -> Node: >>> dfg.add_op(ops.Noop(), dfg.inputs()[0]) Node(3) """ - new_n = self.hugr.add_node(op, self.parent_node) + new_n = self.hugr.add_node(op, self.parent_node, metadata=metadata) self._wire_up(new_n, args) return replace(new_n, _num_out_ports=op.num_out) - def add(self, com: ops.Command) -> Node: + def add(self, com: ops.Command, *, metadata: dict[str, Any] | None = None) -> Node: """Add a command (holding a dataflow operation and the incoming wires) to the graph. Args: com: The command to add. + metadata: Metadata to attach to the function definition. Defaults to None. Example: >>> dfg = Dfg(tys.Bool) @@ -212,7 +217,7 @@ def raise_no_ints(): wires = ( (w if not isinstance(w, int) else raise_no_ints()) for w in com.incoming ) - return self.add_op(com.op, *wires) + return self.add_op(com.op, *wires, metadata=metadata) def extend(self, *coms: ops.Command) -> list[Node]: """Add a series of commands to the DFG. diff --git a/hugr-py/src/hugr/function.py b/hugr-py/src/hugr/function.py index 66be9c58f..e95042655 100644 --- a/hugr-py/src/hugr/function.py +++ b/hugr-py/src/hugr/function.py @@ -57,3 +57,8 @@ def declare_function(self, name: str, signature: PolyFuncType) -> Node: def add_alias_decl(self, name: str, bound: TypeBound) -> Node: """Add a type alias declaration.""" return self.hugr.add_node(ops.AliasDecl(name, bound), self.hugr.root) + + @property + def metadata(self) -> dict[str, object]: + """Metadata associated with this module.""" + return self.hugr.root.metadata diff --git a/hugr-py/src/hugr/hugr.py b/hugr-py/src/hugr/hugr.py index 3a2cb7a56..2873cb258 100644 --- a/hugr-py/src/hugr/hugr.py +++ b/hugr-py/src/hugr/hugr.py @@ -7,6 +7,7 @@ from dataclasses import dataclass, field, replace from typing import ( TYPE_CHECKING, + Any, Generic, Protocol, TypeVar, @@ -49,6 +50,7 @@ class NodeData: _num_inps: int = field(default=0, repr=False) _num_outs: int = field(default=0, repr=False) children: list[Node] = field(default_factory=list, repr=False) + metadata: dict[str, Any] = field(default_factory=dict) def to_serial(self, node: Node) -> SerialOp: o = self.op.to_serial(self.parent if self.parent else node) @@ -123,7 +125,11 @@ def __getitem__(self, key: ToNode) -> NodeData: return n def __iter__(self) -> Iterator[Node]: - return (Node(idx) for idx, data in enumerate(self._nodes) if data is not None) + return ( + Node(idx, data.metadata) + for idx, data in enumerate(self._nodes) + if data is not None + ) def __len__(self) -> int: return self.num_nodes() @@ -160,17 +166,18 @@ def _add_node( op: Op, parent: ToNode | None = None, num_outs: int | None = None, + metadata: dict[str, Any] | None = None, ) -> Node: parent = parent.to_node() if parent else None - node_data = NodeData(op, parent) + node_data = NodeData(op, parent, metadata=metadata or {}) if self._free_nodes: node = self._free_nodes.pop() self._nodes[node.idx] = node_data else: - node = Node(len(self._nodes)) + node = Node(len(self._nodes), {}) self._nodes.append(node_data) - node = replace(node, _num_out_ports=num_outs) + node = replace(node, _num_out_ports=num_outs, _metadata=node_data.metadata) if parent: self[parent].children.append(node) return node @@ -194,6 +201,7 @@ def add_node( op: Op, parent: ToNode | None = None, num_outs: int | None = None, + metadata: dict[str, Any] | None = None, ) -> Node: """Add a node to the HUGR. @@ -201,19 +209,28 @@ def add_node( op: Operation of the node. parent: Parent node of added node. Defaults to HUGR root if None. num_outs: Number of output ports expected for this node. Defaults to None. + metadata: A dictionary of metadata to associate with the node. + Defaults to None. Returns: Handle to the added node. """ parent = parent or self.root - return self._add_node(op, parent, num_outs) + return self._add_node(op, parent, num_outs, metadata) - def add_const(self, value: Value, parent: ToNode | None = None) -> Node: + def add_const( + self, + value: Value, + parent: ToNode | None = None, + metadata: dict[str, Any] | None = None, + ) -> Node: """Add a constant node to the HUGR. Args: value: Value of the constant. parent: Parent node of added node. Defaults to HUGR root if None. + metadata: A dictionary of metadata to associate with the node. + Defaults to None. Returns: Handle to the added node. @@ -224,7 +241,7 @@ def add_const(self, value: Value, parent: ToNode | None = None) -> Node: >>> h[n].op Const(TRUE) """ - return self.add_node(Const(value), parent) + return self.add_node(Const(value), parent, metadata=metadata) def delete_node(self, node: ToNode) -> NodeData | None: """Delete a node from the HUGR. @@ -254,6 +271,10 @@ def delete_node(self, node: ToNode) -> NodeData | None: self._links.delete_left(_SubPort(node.out(offset))) weight, self._nodes[node.idx] = self._nodes[node.idx], None + + # Free up the metadata dictionary + node = replace(node, _metadata={}) + self._free_nodes.append(node) return weight @@ -550,16 +571,18 @@ def insert_hugr(self, hugr: Hugr, parent: ToNode | None = None) -> dict[Node, No """ mapping: dict[Node, Node] = {} - for idx, node_data in enumerate(hugr._nodes): - if node_data is not None: - # relies on parents being inserted before any children - try: - node_parent = ( - mapping[node_data.parent] if node_data.parent else parent - ) - except KeyError as e: - raise ParentBeforeChild from e - mapping[Node(idx)] = self.add_node(node_data.op, node_parent) + for node, node_data in hugr.nodes(): + # relies on parents being inserted before any children + try: + node_parent = mapping[node_data.parent] if node_data.parent else parent + except KeyError as e: + raise ParentBeforeChild from e + mapping[node] = self.add_node( + node_data.op, + node_parent, + num_outs=node_data._num_outs, + metadata=node_data.metadata, + ) for src, dst in hugr._links.items(): self.add_link( @@ -581,8 +604,9 @@ def _serialize_link( return SerialHugr( # non contiguous indices will be erased - nodes=[node.to_serial(Node(idx)) for idx, node in enumerate(node_it)], + nodes=[node.to_serial(Node(idx, {})) for idx, node in enumerate(node_it)], edges=[_serialize_link(link) for link in self._links.items()], + metadata=[node.metadata if node.metadata else None for node in node_it], ) def _constrain_offset(self, p: P) -> PortOffset: @@ -619,19 +643,32 @@ def from_serial(cls, serial: SerialHugr) -> Hugr: hugr._links = BiMap() hugr._free_nodes = [] hugr.root = Node(0) + + def get_meta(idx: int) -> dict[str, Any]: + if not serial.metadata: + return {} + if idx < len(serial.metadata): + return serial.metadata[idx] or {} + return {} + for idx, serial_node in enumerate(serial.nodes): + node_meta = get_meta(idx) parent: Node | None = Node(serial_node.root.parent) if serial_node.root.parent == idx: - hugr.root = Node(idx) + hugr.root = Node(idx, _metadata=node_meta) parent = None + serial_node.root.parent = -1 - hugr._nodes.append(NodeData(serial_node.root.deserialize(), parent)) + hugr._nodes.append( + NodeData(serial_node.root.deserialize(), parent, metadata=node_meta) + ) for (src_node, src_offset), (dst_node, dst_offset) in serial.edges: if src_offset is None or dst_offset is None: continue hugr.add_link( - Node(src_node).out(src_offset), Node(dst_node).inp(dst_offset) + Node(src_node, _metadata=get_meta(src_node)).out(src_offset), + Node(dst_node, _metadata=get_meta(dst_node)).inp(dst_offset), ) return hugr diff --git a/hugr-py/src/hugr/node_port.py b/hugr-py/src/hugr/node_port.py index e558aaa79..0f268fe88 100644 --- a/hugr-py/src/hugr/node_port.py +++ b/hugr-py/src/hugr/node_port.py @@ -141,6 +141,11 @@ def port(self, offset: PortOffset, direction: Direction) -> InPort | OutPort: else: return self.out(offset) + @property + def metadata(self) -> dict[str, object]: + """Metadata associated with this node.""" + return self.to_node()._metadata + @dataclass(frozen=True, eq=True, order=True) class Node(ToNode): @@ -149,7 +154,10 @@ class Node(ToNode): """ idx: NodeIdx - _num_out_ports: int | None = field(default=None, compare=False) + _metadata: dict[str, object] = field( + repr=False, compare=False, default_factory=dict + ) + _num_out_ports: int | None = field(default=None, compare=False, repr=False) def _index( self, index: PortOffset | slice | tuple[PortOffset, ...] diff --git a/hugr-py/src/hugr/tracked_dfg.py b/hugr-py/src/hugr/tracked_dfg.py index 956325653..aa8d768bd 100644 --- a/hugr-py/src/hugr/tracked_dfg.py +++ b/hugr-py/src/hugr/tracked_dfg.py @@ -1,6 +1,7 @@ """Dfg builder that allows tracking a set of wires and appending operations by index.""" from collections.abc import Iterable +from typing import Any from hugr import tys from hugr.dfg import Dfg @@ -123,7 +124,7 @@ def tracked_wire(self, index: int) -> Wire: raise IndexError(msg) return tracked - def add(self, com: Command) -> Node: + def add(self, com: Command, *, metadata: dict[str, Any] | None = None) -> Node: """Add a command to the DFG. Overrides :meth:`Dfg.add ` to allow Command inputs @@ -138,6 +139,7 @@ def add(self, com: Command) -> Node: Args: com: Command to append. + metadata: Metadata to attach to the function definition. Defaults to None. Returns: The new node. diff --git a/hugr-py/tests/test_hugr_build.py b/hugr-py/tests/test_hugr_build.py index 48f544360..8f82159cb 100644 --- a/hugr-py/tests/test_hugr_build.py +++ b/hugr-py/tests/test_hugr_build.py @@ -69,14 +69,15 @@ def test_simple_id(): validate(simple_id().hugr) -def test_json_roundtrip(): - hugr = simple_id().hugr - json = hugr.to_json() +def test_metadata(): + h = Dfg(tys.Bool) + h.metadata["name"] = "simple_id" - hugr2 = Hugr.load_json(json) - json2 = hugr2.to_json() + (b,) = h.inputs() + b = h.add_op(Not, b, metadata={"name": "not"}) - assert json2 == json + h.set_outputs(b) + validate(h.hugr) def test_multiport():