From c5ea47fd77cfbdda5f32d651618ed69b97740e2e Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Mon, 17 Jun 2024 12:43:30 +0100 Subject: [PATCH] feat(hugr-py): CFG builder (#1192) Closes #1188 --- hugr-py/src/hugr/_cfg.py | 90 +++++++++++++++++++++++++ hugr-py/src/hugr/_dfg.py | 112 ++++++++++++++++++++++--------- hugr-py/src/hugr/_exceptions.py | 10 +++ hugr-py/src/hugr/_hugr.py | 27 ++++++-- hugr-py/src/hugr/_ops.py | 67 +++++++++++++++++- hugr-py/src/hugr/_tys.py | 1 + hugr-py/tests/test_cfg.py | 71 ++++++++++++++++++++ hugr-py/tests/test_hugr_build.py | 2 +- 8 files changed, 341 insertions(+), 39 deletions(-) create mode 100644 hugr-py/src/hugr/_cfg.py create mode 100644 hugr-py/tests/test_cfg.py diff --git a/hugr-py/src/hugr/_cfg.py b/hugr-py/src/hugr/_cfg.py new file mode 100644 index 000000000..8ac058a23 --- /dev/null +++ b/hugr-py/src/hugr/_cfg.py @@ -0,0 +1,90 @@ +from __future__ import annotations +from dataclasses import dataclass +from typing import Iterable, Sequence +from ._hugr import Hugr, Node, Wire +from ._dfg import DfBase, _from_base +from ._tys import FunctionType, TypeRow, Sum +from ._exceptions import NoSiblingAncestor, NotInSameCfg +import hugr._ops as ops + + +class Block(DfBase[ops.DataflowBlock]): + def set_block_outputs(self, branching: Wire, *other_outputs: Wire) -> None: + self.set_outputs(branching, *other_outputs) + + def set_single_successor_outputs(self, *outputs: Wire) -> None: + # TODO requires constants + raise NotImplementedError + + def _wire_up(self, node: Node, ports: Iterable[Wire]): + for i, p in enumerate(ports): + src = p.out_port() + cfg_node = self.hugr[self.root].parent + assert cfg_node is not None + src_parent = self.hugr[src.node].parent + try: + self._wire_up_port(node, i, p) + except NoSiblingAncestor: + # note this just checks if there is a common CFG ancestor + # it does not check for valid dominance between basic blocks + # that is deferred to full HUGR validation. + while cfg_node != src_parent: + if src_parent is None or src_parent == self.hugr.root: + raise NotInSameCfg(src.node.idx, node.idx) + src_parent = self.hugr[src_parent].parent + + self.hugr.add_link(src, node.inp(i)) + + +@dataclass +class Cfg: + hugr: Hugr + root: Node + _entry_block: Block + exit: Node + + def __init__(self, input_types: TypeRow, output_types: TypeRow) -> None: + root_op = ops.CFG(FunctionType(input=input_types, output=output_types)) + self.hugr = Hugr(root_op) + self.root = self.hugr.root + # to ensure entry is first child, add a dummy entry at the start + self._entry_block = _from_base( + Block, self.hugr.add_dfg(ops.DataflowBlock(input_types, [])) + ) + + self.exit = self.hugr.add_node(ops.ExitBlock(output_types), self.root) + + @property + def entry(self) -> Node: + return self._entry_block.root + + def _entry_op(self) -> ops.DataflowBlock: + dop = self.hugr[self.entry].op + assert isinstance(dop, ops.DataflowBlock) + return dop + + def add_entry(self, sum_rows: Sequence[TypeRow], other_outputs: TypeRow) -> Block: + # update entry block types + self._entry_op().sum_rows = list(sum_rows) + self._entry_op().other_outputs = other_outputs + self._entry_block._output_op().types = [Sum(list(sum_rows)), *other_outputs] + return self._entry_block + + def simple_entry(self, n_branches: int, other_outputs: TypeRow) -> Block: + return self.add_entry([[]] * n_branches, other_outputs) + + def add_block( + self, input_types: TypeRow, sum_rows: Sequence[TypeRow], other_outputs: TypeRow + ) -> Block: + new_block = self.hugr.add_dfg( + ops.DataflowBlock(input_types, list(sum_rows), other_outputs) + ) + return _from_base(Block, new_block) + + def simple_block( + self, input_types: TypeRow, n_branches: int, other_outputs: TypeRow + ) -> Block: + return self.add_block(input_types, [[]] * n_branches, other_outputs) + + def branch(self, src: Wire, dst: Node) -> None: + self.hugr.add_link(src.out_port(), dst.inp(0)) diff --git a/hugr-py/src/hugr/_dfg.py b/hugr-py/src/hugr/_dfg.py index 4c4ab3da7..17499a7b9 100644 --- a/hugr-py/src/hugr/_dfg.py +++ b/hugr-py/src/hugr/_dfg.py @@ -1,51 +1,59 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Sequence, Iterable +from typing import Iterable, TYPE_CHECKING, Generic, TypeVar, cast +import typing from ._hugr import Hugr, Node, Wire, OutPort -from ._ops import Op, Command, Input, Output, DFG +import hugr._ops as ops from ._exceptions import NoSiblingAncestor -from hugr._tys import FunctionType, Type +from hugr._tys import FunctionType, TypeRow + +if TYPE_CHECKING: + from ._cfg import Cfg + + +DP = TypeVar("DP", bound=ops.DfParentOp) @dataclass() -class Dfg: +class DfBase(Generic[DP]): hugr: Hugr root: Node input_node: Node output_node: Node - def __init__( - self, input_types: Sequence[Type], output_types: Sequence[Type] - ) -> None: - input_types = list(input_types) - output_types = list(output_types) - root_op = DFG(FunctionType(input=input_types, output=output_types)) + def __init__(self, root_op: DP) -> None: + input_types = root_op.input_types() + output_types = root_op.output_types() self.hugr = Hugr(root_op) self.root = self.hugr.root self.input_node = self.hugr.add_node( - Input(input_types), self.root, len(input_types) + ops.Input(input_types), self.root, len(input_types) ) - self.output_node = self.hugr.add_node(Output(output_types), self.root) + self.output_node = self.hugr.add_node(ops.Output(output_types), self.root) - @classmethod - def endo(cls, types: Sequence[Type]) -> Dfg: - return Dfg(types, types) - - def _input_op(self) -> Input: + def _input_op(self) -> ops.Input: dop = self.hugr[self.input_node].op - assert isinstance(dop, Input) + assert isinstance(dop, ops.Input) return dop + def _output_op(self) -> ops.Output: + dop = self.hugr[self.output_node].op + assert isinstance(dop, ops.Output) + return dop + + def root_op(self) -> DP: + return cast(DP, self.hugr[self.root].op) + def inputs(self) -> list[OutPort]: return [self.input_node.out(i) for i in range(len(self._input_op().types))] - def add_op(self, op: Op, /, *args: Wire, num_outs: int | None = None) -> Node: + def add_op(self, op: ops.Op, /, *args: Wire, num_outs: int | None = None) -> Node: new_n = self.hugr.add_node(op, self.root, num_outs=num_outs) self._wire_up(new_n, args) return new_n - def add(self, com: Command) -> Node: + def add(self, com: ops.Command) -> Node: return self.add_op(com.op, *com.incoming, num_outs=com.op.num_out) def insert_nested(self, dfg: Dfg, *args: Wire) -> Node: @@ -55,13 +63,30 @@ def insert_nested(self, dfg: Dfg, *args: Wire) -> Node: def add_nested( self, - input_types: Sequence[Type], - output_types: Sequence[Type], + input_types: TypeRow, + output_types: TypeRow, *args: Wire, ) -> Dfg: - dfg = self.hugr.add_dfg(input_types, output_types) + dfg = self.hugr.add_dfg( + ops.DFG(FunctionType(input=input_types, output=output_types)) + ) self._wire_up(dfg.root, args) - return dfg + return _from_base(Dfg, dfg) + + def add_cfg( + self, + input_types: TypeRow, + output_types: TypeRow, + *args: Wire, + ) -> Cfg: + cfg = self.hugr.add_cfg(input_types, output_types) + self._wire_up(cfg.root, args) + return cfg + + def insert_cfg(self, cfg: Cfg, *args: Wire) -> Node: + mapping = self.hugr.insert_hugr(cfg.hugr, self.root) + self._wire_up(mapping[cfg.root], args) + return mapping[cfg.root] def set_outputs(self, *args: Wire) -> None: self._wire_up(self.output_node, args) @@ -72,13 +97,38 @@ def add_state_order(self, src: Node, dst: Node) -> None: def _wire_up(self, node: Node, ports: Iterable[Wire]): for i, p in enumerate(ports): - src = p.out_port() - node_ancestor = _ancestral_sibling(self.hugr, src.node, node) - if node_ancestor is None: - raise NoSiblingAncestor(src.node.idx, node.idx) - if node_ancestor != node: - self.add_state_order(src.node, node_ancestor) - self.hugr.add_link(src, node.inp(i)) + self._wire_up_port(node, i, p) + + def _wire_up_port(self, node: Node, offset: int, p: Wire): + src = p.out_port() + node_ancestor = _ancestral_sibling(self.hugr, src.node, node) + if node_ancestor is None: + raise NoSiblingAncestor(src.node.idx, node.idx) + if node_ancestor != node: + self.add_state_order(src.node, node_ancestor) + self.hugr.add_link(src, node.inp(offset)) + + +C = TypeVar("C", bound=DfBase) + + +def _from_base(cls: typing.Type[C], base: DfBase[DP]) -> C: + new = cls.__new__(cls) + new.hugr = base.hugr + new.root = base.root + new.input_node = base.input_node + new.output_node = base.output_node + return new + + +class Dfg(DfBase[ops.DFG]): + def __init__(self, input_types: TypeRow, output_types: TypeRow) -> None: + root_op = ops.DFG(FunctionType(input=input_types, output=output_types)) + super().__init__(root_op) + + @classmethod + def endo(cls, types: TypeRow) -> Dfg: + return cls(types, types) def _ancestral_sibling(h: Hugr, src: Node, tgt: Node) -> Node | None: diff --git a/hugr-py/src/hugr/_exceptions.py b/hugr-py/src/hugr/_exceptions.py index d59d99972..92ba7ceb0 100644 --- a/hugr-py/src/hugr/_exceptions.py +++ b/hugr-py/src/hugr/_exceptions.py @@ -11,5 +11,15 @@ def msg(self): return f"Source {self.src} has no sibling ancestor of target {self.tgt}, so cannot wire up." +@dataclass +class NotInSameCfg(Exception): + src: int + tgt: int + + @property + def msg(self): + return f"Source {self.src} is not in the same CFG as target {self.tgt}, so cannot wire up." + + class ParentBeforeChild(Exception): msg: str = "Parent node must be added before child node." diff --git a/hugr-py/src/hugr/_hugr.py b/hugr-py/src/hugr/_hugr.py index 45b3c180c..3bee26bd6 100644 --- a/hugr-py/src/hugr/_hugr.py +++ b/hugr-py/src/hugr/_hugr.py @@ -10,7 +10,6 @@ Iterable, Iterator, Protocol, - Sequence, TypeVar, cast, overload, @@ -19,7 +18,7 @@ from typing_extensions import Self from hugr._ops import Op -from hugr._tys import Type +from hugr._tys import TypeRow from hugr.serialization.ops import OpType as SerialOp from hugr.serialization.serial_hugr import SerialHugr from hugr.utils import BiMap @@ -27,7 +26,8 @@ from ._exceptions import ParentBeforeChild if TYPE_CHECKING: - from ._dfg import Dfg + from ._dfg import DfBase, DP + from ._cfg import Cfg class Direction(Enum): @@ -337,10 +337,10 @@ def insert_hugr(self, hugr: Hugr, parent: Node | None = None) -> dict[Node, Node ) return mapping - def add_dfg(self, input_types: Sequence[Type], output_types: Sequence[Type]) -> Dfg: - from ._dfg import Dfg + def add_dfg(self, root_op: DP) -> DfBase[DP]: + from ._dfg import DfBase - dfg = Dfg(input_types, output_types) + dfg = DfBase(root_op) mapping = self.insert_hugr(dfg.hugr, self.root) dfg.hugr = self dfg.input_node = mapping[dfg.input_node] @@ -348,6 +348,21 @@ def add_dfg(self, input_types: Sequence[Type], output_types: Sequence[Type]) -> dfg.root = mapping[dfg.root] return dfg + def add_cfg(self, input_types: TypeRow, output_types: TypeRow) -> Cfg: + from ._cfg import Cfg + + cfg = Cfg(input_types, output_types) + mapping = self.insert_hugr(cfg.hugr, self.root) + cfg.hugr = self + cfg._entry_block.root = mapping[cfg.entry] + cfg._entry_block.input_node = mapping[cfg._entry_block.input_node] + cfg._entry_block.output_node = mapping[cfg._entry_block.output_node] + cfg._entry_block.hugr = self + cfg.exit = mapping[cfg.exit] + cfg.root = mapping[cfg.root] + # TODO this is horrible + return cfg + def to_serial(self) -> SerialHugr: node_it = (node for node in self._nodes if node is not None) diff --git a/hugr-py/src/hugr/_ops.py b/hugr-py/src/hugr/_ops.py index 9724162dd..6b6986081 100644 --- a/hugr-py/src/hugr/_ops.py +++ b/hugr-py/src/hugr/_ops.py @@ -121,8 +121,13 @@ def __call__(self, tuple_: Wire) -> Command: return super().__call__(tuple_) +class DfParentOp(Op, Protocol): + def input_types(self) -> tys.TypeRow: ... + def output_types(self) -> tys.TypeRow: ... + + @dataclass() -class DFG(Op): +class DFG(DfParentOp): signature: tys.FunctionType = field(default_factory=tys.FunctionType.empty) @property @@ -134,3 +139,63 @@ def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.DFG: parent=parent.idx, signature=self.signature.to_serial(), ) + + def input_types(self) -> tys.TypeRow: + return self.signature.input + + def output_types(self) -> tys.TypeRow: + return self.signature.output + + +@dataclass() +class CFG(Op): + signature: tys.FunctionType = field(default_factory=tys.FunctionType.empty) + + @property + def num_out(self) -> int | None: + return len(self.signature.output) + + def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.CFG: + return sops.CFG( + parent=parent.idx, + signature=self.signature.to_serial(), + ) + + +@dataclass +class DataflowBlock(DfParentOp): + inputs: tys.TypeRow + sum_rows: list[tys.TypeRow] + other_outputs: tys.TypeRow = field(default_factory=list) + extension_delta: tys.ExtensionSet = field(default_factory=list) + + @property + def num_out(self) -> int | None: + return len(self.sum_rows) + + def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.DataflowBlock: + return sops.DataflowBlock( + parent=parent.idx, + inputs=ser_it(self.inputs), + sum_rows=list(map(ser_it, self.sum_rows)), + other_outputs=ser_it(self.other_outputs), + extension_delta=self.extension_delta, + ) + + def input_types(self) -> tys.TypeRow: + return self.inputs + + def output_types(self) -> tys.TypeRow: + return [tys.Sum(self.sum_rows), *self.other_outputs] + + +@dataclass +class ExitBlock(Op): + cfg_outputs: tys.TypeRow + num_out: int | None = 0 + + def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.ExitBlock: + return sops.ExitBlock( + parent=parent.idx, + cfg_outputs=ser_it(self.cfg_outputs), + ) diff --git a/hugr-py/src/hugr/_tys.py b/hugr-py/src/hugr/_tys.py index 7cbd73f2f..6f199959c 100644 --- a/hugr-py/src/hugr/_tys.py +++ b/hugr-py/src/hugr/_tys.py @@ -269,3 +269,4 @@ def to_serial(self) -> stys.Qubit: Qubit = QubitDef() Bool = UnitSum(size=2) +Unit = UnitSum(size=1) diff --git a/hugr-py/tests/test_cfg.py b/hugr-py/tests/test_cfg.py new file mode 100644 index 000000000..457334526 --- /dev/null +++ b/hugr-py/tests/test_cfg.py @@ -0,0 +1,71 @@ +from hugr._cfg import Cfg +import hugr._tys as tys +from hugr._dfg import Dfg +from .test_hugr_build import _validate, INT_T, DivMod + + +def build_basic_cfg(cfg: Cfg) -> None: + entry = cfg.simple_entry(1, [tys.Bool]) + + entry.set_block_outputs(*entry.inputs()) + cfg.branch(entry.root.out(0), cfg.exit) + + +def test_basic_cfg() -> None: + cfg = Cfg([tys.Unit, tys.Bool], [tys.Bool]) + build_basic_cfg(cfg) + _validate(cfg.hugr) + + +def test_branch() -> None: + cfg = Cfg([tys.Bool, tys.Unit, INT_T], [INT_T]) + entry = cfg.simple_entry(2, [tys.Unit, INT_T]) + entry.set_block_outputs(*entry.inputs()) + + middle_1 = cfg.simple_block([tys.Unit, INT_T], 1, [INT_T]) + middle_1.set_block_outputs(*middle_1.inputs()) + middle_2 = cfg.simple_block([tys.Unit, INT_T], 1, [INT_T]) + u, i = middle_2.inputs() + n = middle_2.add(DivMod(i, i)) + middle_2.set_block_outputs(u, n[0]) + + cfg.branch(entry.root.out(0), middle_1.root) + cfg.branch(entry.root.out(1), middle_2.root) + + cfg.branch(middle_1.root.out(0), cfg.exit) + cfg.branch(middle_2.root.out(0), cfg.exit) + + _validate(cfg.hugr) + + +def test_nested_cfg() -> None: + dfg = Dfg([tys.Unit, tys.Bool], [tys.Bool]) + + cfg = dfg.add_cfg([tys.Unit, tys.Bool], [tys.Bool], *dfg.inputs()) + + build_basic_cfg(cfg) + dfg.set_outputs(cfg.root) + + _validate(dfg.hugr) + + +def test_dom_edge() -> None: + cfg = Cfg([tys.Bool, tys.Unit, INT_T], [INT_T]) + entry = cfg.simple_entry(2, [INT_T]) + b, u, i = entry.inputs() + entry.set_block_outputs(b, i) + + # entry dominates both middles so Unit type can be used as inter-graph + # value between basic blocks + middle_1 = cfg.simple_block([INT_T], 1, [INT_T]) + middle_1.set_block_outputs(u, *middle_1.inputs()) + middle_2 = cfg.simple_block([INT_T], 1, [INT_T]) + middle_2.set_block_outputs(u, *middle_2.inputs()) + + cfg.branch(entry.root.out(0), middle_1.root) + cfg.branch(entry.root.out(1), middle_2.root) + + cfg.branch(middle_1.root.out(0), cfg.exit) + cfg.branch(middle_2.root.out(0), cfg.exit) + + _validate(cfg.hugr) diff --git a/hugr-py/tests/test_hugr_build.py b/hugr-py/tests/test_hugr_build.py index 37a3b624d..47dd9ce36 100644 --- a/hugr-py/tests/test_hugr_build.py +++ b/hugr-py/tests/test_hugr_build.py @@ -240,7 +240,7 @@ def test_build_inter_graph(): h.set_outputs(nested.root, b) - _validate(h.hugr, True) + _validate(h.hugr) assert _SubPort(h.input_node.out(-1)) in h.hugr._links assert h.hugr.num_outgoing(h.input_node) == 2 # doesn't count state order