From 85db01431695bf1e10b356f9132ab9f61b4d315b Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Wed, 5 Jun 2024 14:24:17 +0100 Subject: [PATCH] feat(hugr-py): store children in node weight Could instead bit the bullet and iterate through every node for calculating the children of a node, but this implementation seems like a decent code complexity/runtime balance Closes #1159 --- hugr-py/src/hugr/_exceptions.py | 4 +++ hugr-py/src/hugr/_hugr.py | 58 ++++++++++++++++++++++---------- hugr-py/tests/test_hugr_build.py | 6 ++-- 3 files changed, 49 insertions(+), 19 deletions(-) diff --git a/hugr-py/src/hugr/_exceptions.py b/hugr-py/src/hugr/_exceptions.py index 3245af0cc..d59d99972 100644 --- a/hugr-py/src/hugr/_exceptions.py +++ b/hugr-py/src/hugr/_exceptions.py @@ -9,3 +9,7 @@ class NoSiblingAncestor(Exception): @property def msg(self): return f"Source {self.src} has no sibling ancestor of target {self.tgt}, so cannot wire up." + + +class ParentBeforeChild(Exception): + msg: str = "Parent node must be added before child node." diff --git a/hugr-py/src/hugr/_hugr.py b/hugr-py/src/hugr/_hugr.py index d42f2edf1..acc668191 100644 --- a/hugr-py/src/hugr/_hugr.py +++ b/hugr-py/src/hugr/_hugr.py @@ -1,28 +1,31 @@ from __future__ import annotations -from dataclasses import dataclass, field, replace + from collections.abc import Mapping +from dataclasses import dataclass, field, replace from enum import Enum from typing import ( + TYPE_CHECKING, + ClassVar, + Generic, Iterable, Iterator, - Sequence, Protocol, - Generic, + Sequence, TypeVar, cast, overload, - ClassVar, - TYPE_CHECKING, ) from typing_extensions import Self -from hugr.serialization.serial_hugr import SerialHugr +from hugr._ops import Op from hugr.serialization.ops import OpType as SerialOp +from hugr.serialization.serial_hugr import SerialHugr from hugr.serialization.tys import Type -from hugr._ops import Op from hugr.utils import BiMap +from ._exceptions import ParentBeforeChild + if TYPE_CHECKING: from ._dfg import Dfg @@ -111,7 +114,7 @@ class NodeData: parent: Node | None _num_inps: int = 0 _num_outs: int = 0 - # TODO children field? + children: list[Node] = field(default_factory=list) def to_serial(self, node: Node, hugr: Hugr) -> SerialOp: o = self.op.to_serial(node, self.parent if self.parent else node, hugr) @@ -147,7 +150,7 @@ def __init__(self, root_op: Op) -> None: self._free_nodes = [] self._links = BiMap() self._nodes = [] - self.root = self.add_node(root_op) + self.root = self._add_node(root_op, None, 0) def __getitem__(self, key: Node) -> NodeData: try: @@ -164,7 +167,11 @@ def __iter__(self): def __len__(self) -> int: return self.num_nodes() - def add_node( + def children(self, node: Node | None = None) -> list[Node]: + node = node or self.root + return self[node].children + + def _add_node( self, op: Op, parent: Node | None = None, @@ -178,9 +185,24 @@ def add_node( else: node = Node(len(self._nodes)) self._nodes.append(node_data) - return replace(node, _num_out_ports=num_outs) + node = replace(node, _num_out_ports=num_outs) + if parent: + self[parent].children.append(node) + return node + + def add_node( + self, + op: Op, + parent: Node | None = None, + num_outs: int | None = None, + ) -> Node: + parent = parent or self.root + return self._add_node(op, parent, num_outs) def delete_node(self, node: Node) -> NodeData | None: + parent = self[node].parent + if parent: + self[parent].children.remove(node) for offset in range(self.num_in_ports(node)): self._links.delete_right(_SubPort(node.inp(offset))) for offset in range(self.num_out_ports(node)): @@ -299,12 +321,14 @@ def insert_hugr(self, hugr: Hugr, parent: Node | None = None) -> dict[Node, Node for idx, node_data in enumerate(hugr._nodes): if node_data is not None: - mapping[Node(idx)] = self.add_node(node_data.op, node_data.parent) - - for new_node in mapping.values(): - # update mapped parent - node_data = self[new_node] - node_data.parent = mapping[node_data.parent] if node_data.parent else parent + # relies on parents being inserted before any children + try: + node_parent = ( + mapping[node_data.parent] if node_data.parent else parent + ) + except KeyError as e: + raise ParentBeforeChild() from e + mapping[Node(idx)] = self.add_node(node_data.op, node_parent) for src, dst in hugr._links.items(): self.add_link( diff --git a/hugr-py/tests/test_hugr_build.py b/hugr-py/tests/test_hugr_build.py index 52f7a2b07..e13c2198d 100644 --- a/hugr-py/tests/test_hugr_build.py +++ b/hugr-py/tests/test_hugr_build.py @@ -88,12 +88,14 @@ def test_stable_indices(): assert len(h) == 4 h.add_link(nodes[0].out(0), nodes[1].inp(0)) + assert h.children() == nodes assert h.num_outgoing(nodes[0]) == 1 assert h.num_incoming(nodes[1]) == 1 assert h.delete_node(nodes[1]) is not None assert h._nodes[nodes[1].idx] is None + assert nodes[1] not in h.children(h.root) assert len(h) == 3 assert len(h._nodes) == 4 @@ -204,7 +206,7 @@ def test_insert_nested(): (a,) = h.inputs() nested = h.insert_nested(h1, a) h.set_outputs(nested) - + assert len(h.hugr.children(nested)) == 3 _validate(h.hugr) @@ -219,7 +221,7 @@ def _nested_nop(dfg: Dfg): nested = h.add_nested([BOOL_T], [BOOL_T], a) _nested_nop(nested) - + assert len(h.hugr.children(nested.root)) == 3 h.set_outputs(nested.root) _validate(h.hugr)