Skip to content

Commit

Permalink
feat(hugr-py): store children in node weight (#1160)
Browse files Browse the repository at this point in the history
Could instead bite the bullet and iterate through every node for
calculating the children of a node, but this implementation seems like a
decent code complexity/runtime balance

Closes #1159
  • Loading branch information
ss2165 authored Jun 6, 2024
1 parent cf542b4 commit 1cdaeed
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 19 deletions.
4 changes: 4 additions & 0 deletions hugr-py/src/hugr/_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,7 @@ class NoSiblingAncestor(Exception):
@property
def msg(self):
return f"Source {self.src} has no sibling ancestor of target {self.tgt}, so cannot wire up."


class ParentBeforeChild(Exception):
msg: str = "Parent node must be added before child node."
58 changes: 41 additions & 17 deletions hugr-py/src/hugr/_hugr.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,31 @@
from __future__ import annotations
from dataclasses import dataclass, field, replace

from collections.abc import Mapping
from dataclasses import dataclass, field, replace
from enum import Enum
from typing import (
TYPE_CHECKING,
ClassVar,
Generic,
Iterable,
Iterator,
Sequence,
Protocol,
Generic,
Sequence,
TypeVar,
cast,
overload,
ClassVar,
TYPE_CHECKING,
)

from typing_extensions import Self

from hugr.serialization.serial_hugr import SerialHugr
from hugr._ops import Op
from hugr.serialization.ops import OpType as SerialOp
from hugr.serialization.serial_hugr import SerialHugr
from hugr.serialization.tys import Type
from hugr._ops import Op
from hugr.utils import BiMap

from ._exceptions import ParentBeforeChild

if TYPE_CHECKING:
from ._dfg import Dfg

Expand Down Expand Up @@ -111,7 +114,7 @@ class NodeData:
parent: Node | None
_num_inps: int = 0
_num_outs: int = 0
# TODO children field?
children: list[Node] = field(default_factory=list)

def to_serial(self, node: Node, hugr: Hugr) -> SerialOp:
o = self.op.to_serial(node, self.parent if self.parent else node, hugr)
Expand Down Expand Up @@ -147,7 +150,7 @@ def __init__(self, root_op: Op) -> None:
self._free_nodes = []
self._links = BiMap()
self._nodes = []
self.root = self.add_node(root_op)
self.root = self._add_node(root_op, None, 0)

def __getitem__(self, key: Node) -> NodeData:
try:
Expand All @@ -164,7 +167,11 @@ def __iter__(self):
def __len__(self) -> int:
return self.num_nodes()

def add_node(
def children(self, node: Node | None = None) -> list[Node]:
node = node or self.root
return self[node].children

def _add_node(
self,
op: Op,
parent: Node | None = None,
Expand All @@ -178,9 +185,24 @@ def add_node(
else:
node = Node(len(self._nodes))
self._nodes.append(node_data)
return replace(node, _num_out_ports=num_outs)
node = replace(node, _num_out_ports=num_outs)
if parent:
self[parent].children.append(node)
return node

def add_node(
self,
op: Op,
parent: Node | None = None,
num_outs: int | None = None,
) -> Node:
parent = parent or self.root
return self._add_node(op, parent, num_outs)

def delete_node(self, node: Node) -> NodeData | None:
parent = self[node].parent
if parent:
self[parent].children.remove(node)
for offset in range(self.num_in_ports(node)):
self._links.delete_right(_SubPort(node.inp(offset)))
for offset in range(self.num_out_ports(node)):
Expand Down Expand Up @@ -299,12 +321,14 @@ def insert_hugr(self, hugr: Hugr, parent: Node | None = None) -> dict[Node, Node

for idx, node_data in enumerate(hugr._nodes):
if node_data is not None:
mapping[Node(idx)] = self.add_node(node_data.op, node_data.parent)

for new_node in mapping.values():
# update mapped parent
node_data = self[new_node]
node_data.parent = mapping[node_data.parent] if node_data.parent else parent
# relies on parents being inserted before any children
try:
node_parent = (
mapping[node_data.parent] if node_data.parent else parent
)
except KeyError as e:
raise ParentBeforeChild() from e
mapping[Node(idx)] = self.add_node(node_data.op, node_parent)

for src, dst in hugr._links.items():
self.add_link(
Expand Down
6 changes: 4 additions & 2 deletions hugr-py/tests/test_hugr_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,12 +88,14 @@ def test_stable_indices():
assert len(h) == 4

h.add_link(nodes[0].out(0), nodes[1].inp(0))
assert h.children() == nodes

assert h.num_outgoing(nodes[0]) == 1
assert h.num_incoming(nodes[1]) == 1

assert h.delete_node(nodes[1]) is not None
assert h._nodes[nodes[1].idx] is None
assert nodes[1] not in h.children(h.root)

assert len(h) == 3
assert len(h._nodes) == 4
Expand Down Expand Up @@ -204,7 +206,7 @@ def test_insert_nested():
(a,) = h.inputs()
nested = h.insert_nested(h1, a)
h.set_outputs(nested)

assert len(h.hugr.children(nested)) == 3
_validate(h.hugr)


Expand All @@ -219,7 +221,7 @@ def _nested_nop(dfg: Dfg):
nested = h.add_nested([BOOL_T], [BOOL_T], a)

_nested_nop(nested)

assert len(h.hugr.children(nested.root)) == 3
h.set_outputs(nested.root)

_validate(h.hugr)
Expand Down

0 comments on commit 1cdaeed

Please sign in to comment.