Skip to content

Commit

Permalink
feat(py): Add node metadata (#1428)
Browse files Browse the repository at this point in the history
Adds a dictionary with metadata to the nodes.
`ToNode` now has a `metadata` property, so we can use
```python
d = Dfg(...)
d.metadata["key"] = 42

n = d.hugr.add_node(..., metadata={"something": "value"})
assert n.metadata["something"] == "value"
```

I chosed not to add a `metadata` argument to all the `add_{container}`
methods in DfgBase to avoid cluttering the API.

Closes #1319
  • Loading branch information
aborgna-q authored Aug 14, 2024
1 parent b89c08f commit b229be6
Show file tree
Hide file tree
Showing 6 changed files with 91 additions and 33 deletions.
13 changes: 9 additions & 4 deletions hugr-py/src/hugr/dfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from dataclasses import dataclass, field, replace
from typing import (
TYPE_CHECKING,
Any,
Generic,
TypeVar,
)
Expand Down Expand Up @@ -170,12 +171,15 @@ def inputs(self) -> list[OutPort]:
"""
return [self.input_node.out(i) for i in range(len(self._input_op().types))]

def add_op(self, op: ops.DataflowOp, /, *args: Wire) -> Node:
def add_op(
self, op: ops.DataflowOp, /, *args: Wire, metadata: dict[str, Any] | None = None
) -> Node:
"""Add a dataflow operation to the graph, wiring in input ports.
Args:
op: The operation to add.
args: The input wires to the operation.
metadata: Metadata to attach to the function definition. Defaults to None.
Returns:
The node holding the new operation.
Expand All @@ -185,17 +189,18 @@ def add_op(self, op: ops.DataflowOp, /, *args: Wire) -> Node:
>>> dfg.add_op(ops.Noop(), dfg.inputs()[0])
Node(3)
"""
new_n = self.hugr.add_node(op, self.parent_node)
new_n = self.hugr.add_node(op, self.parent_node, metadata=metadata)
self._wire_up(new_n, args)

return replace(new_n, _num_out_ports=op.num_out)

def add(self, com: ops.Command) -> Node:
def add(self, com: ops.Command, *, metadata: dict[str, Any] | None = None) -> Node:
"""Add a command (holding a dataflow operation and the incoming wires)
to the graph.
Args:
com: The command to add.
metadata: Metadata to attach to the function definition. Defaults to None.
Example:
>>> dfg = Dfg(tys.Bool)
Expand All @@ -212,7 +217,7 @@ def raise_no_ints():
wires = (
(w if not isinstance(w, int) else raise_no_ints()) for w in com.incoming
)
return self.add_op(com.op, *wires)
return self.add_op(com.op, *wires, metadata=metadata)

def extend(self, *coms: ops.Command) -> list[Node]:
"""Add a series of commands to the DFG.
Expand Down
5 changes: 5 additions & 0 deletions hugr-py/src/hugr/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,8 @@ def declare_function(self, name: str, signature: PolyFuncType) -> Node:
def add_alias_decl(self, name: str, bound: TypeBound) -> Node:
"""Add a type alias declaration."""
return self.hugr.add_node(ops.AliasDecl(name, bound), self.hugr.root)

@property
def metadata(self) -> dict[str, object]:
"""Metadata associated with this module."""
return self.hugr.root.metadata
79 changes: 58 additions & 21 deletions hugr-py/src/hugr/hugr.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from dataclasses import dataclass, field, replace
from typing import (
TYPE_CHECKING,
Any,
Generic,
Protocol,
TypeVar,
Expand Down Expand Up @@ -49,6 +50,7 @@ class NodeData:
_num_inps: int = field(default=0, repr=False)
_num_outs: int = field(default=0, repr=False)
children: list[Node] = field(default_factory=list, repr=False)
metadata: dict[str, Any] = field(default_factory=dict)

def to_serial(self, node: Node) -> SerialOp:
o = self.op.to_serial(self.parent if self.parent else node)
Expand Down Expand Up @@ -123,7 +125,11 @@ def __getitem__(self, key: ToNode) -> NodeData:
return n

def __iter__(self) -> Iterator[Node]:
return (Node(idx) for idx, data in enumerate(self._nodes) if data is not None)
return (
Node(idx, data.metadata)
for idx, data in enumerate(self._nodes)
if data is not None
)

def __len__(self) -> int:
return self.num_nodes()
Expand Down Expand Up @@ -160,17 +166,18 @@ def _add_node(
op: Op,
parent: ToNode | None = None,
num_outs: int | None = None,
metadata: dict[str, Any] | None = None,
) -> Node:
parent = parent.to_node() if parent else None
node_data = NodeData(op, parent)
node_data = NodeData(op, parent, metadata=metadata or {})

if self._free_nodes:
node = self._free_nodes.pop()
self._nodes[node.idx] = node_data
else:
node = Node(len(self._nodes))
node = Node(len(self._nodes), {})
self._nodes.append(node_data)
node = replace(node, _num_out_ports=num_outs)
node = replace(node, _num_out_ports=num_outs, _metadata=node_data.metadata)
if parent:
self[parent].children.append(node)
return node
Expand All @@ -194,26 +201,36 @@ def add_node(
op: Op,
parent: ToNode | None = None,
num_outs: int | None = None,
metadata: dict[str, Any] | None = None,
) -> Node:
"""Add a node to the HUGR.
Args:
op: Operation of the node.
parent: Parent node of added node. Defaults to HUGR root if None.
num_outs: Number of output ports expected for this node. Defaults to None.
metadata: A dictionary of metadata to associate with the node.
Defaults to None.
Returns:
Handle to the added node.
"""
parent = parent or self.root
return self._add_node(op, parent, num_outs)
return self._add_node(op, parent, num_outs, metadata)

def add_const(self, value: Value, parent: ToNode | None = None) -> Node:
def add_const(
self,
value: Value,
parent: ToNode | None = None,
metadata: dict[str, Any] | None = None,
) -> Node:
"""Add a constant node to the HUGR.
Args:
value: Value of the constant.
parent: Parent node of added node. Defaults to HUGR root if None.
metadata: A dictionary of metadata to associate with the node.
Defaults to None.
Returns:
Handle to the added node.
Expand All @@ -224,7 +241,7 @@ def add_const(self, value: Value, parent: ToNode | None = None) -> Node:
>>> h[n].op
Const(TRUE)
"""
return self.add_node(Const(value), parent)
return self.add_node(Const(value), parent, metadata=metadata)

def delete_node(self, node: ToNode) -> NodeData | None:
"""Delete a node from the HUGR.
Expand Down Expand Up @@ -254,6 +271,10 @@ def delete_node(self, node: ToNode) -> NodeData | None:
self._links.delete_left(_SubPort(node.out(offset)))

weight, self._nodes[node.idx] = self._nodes[node.idx], None

# Free up the metadata dictionary
node = replace(node, _metadata={})

self._free_nodes.append(node)
return weight

Expand Down Expand Up @@ -550,16 +571,18 @@ def insert_hugr(self, hugr: Hugr, parent: ToNode | None = None) -> dict[Node, No
"""
mapping: dict[Node, Node] = {}

for idx, node_data in enumerate(hugr._nodes):
if node_data is not None:
# relies on parents being inserted before any children
try:
node_parent = (
mapping[node_data.parent] if node_data.parent else parent
)
except KeyError as e:
raise ParentBeforeChild from e
mapping[Node(idx)] = self.add_node(node_data.op, node_parent)
for node, node_data in hugr.nodes():
# relies on parents being inserted before any children
try:
node_parent = mapping[node_data.parent] if node_data.parent else parent
except KeyError as e:
raise ParentBeforeChild from e
mapping[node] = self.add_node(
node_data.op,
node_parent,
num_outs=node_data._num_outs,
metadata=node_data.metadata,
)

for src, dst in hugr._links.items():
self.add_link(
Expand All @@ -581,8 +604,9 @@ def _serialize_link(

return SerialHugr(
# non contiguous indices will be erased
nodes=[node.to_serial(Node(idx)) for idx, node in enumerate(node_it)],
nodes=[node.to_serial(Node(idx, {})) for idx, node in enumerate(node_it)],
edges=[_serialize_link(link) for link in self._links.items()],
metadata=[node.metadata if node.metadata else None for node in node_it],
)

def _constrain_offset(self, p: P) -> PortOffset:
Expand Down Expand Up @@ -619,19 +643,32 @@ def from_serial(cls, serial: SerialHugr) -> Hugr:
hugr._links = BiMap()
hugr._free_nodes = []
hugr.root = Node(0)

def get_meta(idx: int) -> dict[str, Any]:
if not serial.metadata:
return {}
if idx < len(serial.metadata):
return serial.metadata[idx] or {}
return {}

for idx, serial_node in enumerate(serial.nodes):
node_meta = get_meta(idx)
parent: Node | None = Node(serial_node.root.parent)
if serial_node.root.parent == idx:
hugr.root = Node(idx)
hugr.root = Node(idx, _metadata=node_meta)
parent = None

serial_node.root.parent = -1
hugr._nodes.append(NodeData(serial_node.root.deserialize(), parent))
hugr._nodes.append(
NodeData(serial_node.root.deserialize(), parent, metadata=node_meta)
)

for (src_node, src_offset), (dst_node, dst_offset) in serial.edges:
if src_offset is None or dst_offset is None:
continue
hugr.add_link(
Node(src_node).out(src_offset), Node(dst_node).inp(dst_offset)
Node(src_node, _metadata=get_meta(src_node)).out(src_offset),
Node(dst_node, _metadata=get_meta(dst_node)).inp(dst_offset),
)

return hugr
Expand Down
10 changes: 9 additions & 1 deletion hugr-py/src/hugr/node_port.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,11 @@ def port(self, offset: PortOffset, direction: Direction) -> InPort | OutPort:
else:
return self.out(offset)

@property
def metadata(self) -> dict[str, object]:
"""Metadata associated with this node."""
return self.to_node()._metadata


@dataclass(frozen=True, eq=True, order=True)
class Node(ToNode):
Expand All @@ -149,7 +154,10 @@ class Node(ToNode):
"""

idx: NodeIdx
_num_out_ports: int | None = field(default=None, compare=False)
_metadata: dict[str, object] = field(
repr=False, compare=False, default_factory=dict
)
_num_out_ports: int | None = field(default=None, compare=False, repr=False)

def _index(
self, index: PortOffset | slice | tuple[PortOffset, ...]
Expand Down
4 changes: 3 additions & 1 deletion hugr-py/src/hugr/tracked_dfg.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Dfg builder that allows tracking a set of wires and appending operations by index."""

from collections.abc import Iterable
from typing import Any

from hugr import tys
from hugr.dfg import Dfg
Expand Down Expand Up @@ -123,7 +124,7 @@ def tracked_wire(self, index: int) -> Wire:
raise IndexError(msg)
return tracked

def add(self, com: Command) -> Node:
def add(self, com: Command, *, metadata: dict[str, Any] | None = None) -> Node:
"""Add a command to the DFG.
Overrides :meth:`Dfg.add <hugr.dfg.Dfg.add>` to allow Command inputs
Expand All @@ -138,6 +139,7 @@ def add(self, com: Command) -> Node:
Args:
com: Command to append.
metadata: Metadata to attach to the function definition. Defaults to None.
Returns:
The new node.
Expand Down
13 changes: 7 additions & 6 deletions hugr-py/tests/test_hugr_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,14 +69,15 @@ def test_simple_id():
validate(simple_id().hugr)


def test_json_roundtrip():
hugr = simple_id().hugr
json = hugr.to_json()
def test_metadata():
h = Dfg(tys.Bool)
h.metadata["name"] = "simple_id"

hugr2 = Hugr.load_json(json)
json2 = hugr2.to_json()
(b,) = h.inputs()
b = h.add_op(Not, b, metadata={"name": "not"})

assert json2 == json
h.set_outputs(b)
validate(h.hugr)


def test_multiport():
Expand Down

0 comments on commit b229be6

Please sign in to comment.