diff --git a/hugr-py/src/hugr/_cfg.py b/hugr-py/src/hugr/_cfg.py new file mode 100644 index 000000000..9a2452c4a --- /dev/null +++ b/hugr-py/src/hugr/_cfg.py @@ -0,0 +1,74 @@ +from __future__ import annotations +from dataclasses import dataclass +from typing import Sequence +from ._hugr import Hugr, Node, Wire +from ._dfg import DfBase, _from_base +from ._tys import Type, FunctionType, TypeRow, Sum +import hugr._ops as ops + + +class Block(DfBase[ops.DataflowBlock]): + def block_outputs(self, branching: Wire, *other_outputs: Wire) -> None: + self.set_outputs(branching, *other_outputs) + + def single_successor_outputs(self, *outputs: Wire) -> None: + # TODO requires constants + raise NotImplementedError + + +@dataclass +class Cfg: + hugr: Hugr + root: Node + _entry_block: Block + exit: Node + + def __init__( + self, input_types: Sequence[Type], output_types: Sequence[Type] + ) -> None: + input_types = list(input_types) + output_types = list(output_types) + root_op = ops.CFG(FunctionType(input=input_types, output=output_types)) + self.hugr = Hugr(root_op) + self.root = self.hugr.root + # to ensure entry is first child, add a dummy entry at the start + self._entry_block = _from_base( + Block, self.hugr.add_dfg(ops.DataflowBlock(input_types, [])) + ) + + self.exit = self.hugr.add_node(ops.ExitBlock(output_types), self.root) + + @property + def entry(self) -> Node: + return self._entry_block.root + + def _entry_op(self) -> ops.DataflowBlock: + dop = self.hugr[self.entry].op + assert isinstance(dop, ops.DataflowBlock) + return dop + + def add_entry(self, sum_rows: Sequence[TypeRow], other_outputs: TypeRow) -> Block: + # update entry block types + self._entry_op().sum_rows = list(sum_rows) + self._entry_op().other_outputs = other_outputs + self._entry_block._output_op().types = [Sum(list(sum_rows)), *other_outputs] + return self._entry_block + + def simple_entry(self, n_branches: int, other_outputs: TypeRow) -> Block: + return self.add_entry([[]] * n_branches, other_outputs) + + def add_block( + self, input_types: TypeRow, sum_rows: Sequence[TypeRow], other_outputs: TypeRow + ) -> Block: + new_block = self.hugr.add_dfg( + ops.DataflowBlock(input_types, list(sum_rows), other_outputs) + ) + return _from_base(Block, new_block) + + def simple_block( + self, input_types: TypeRow, n_branches: int, other_outputs: TypeRow + ) -> Block: + return self.add_block(input_types, [[]] * n_branches, other_outputs) + + def branch(self, src: Wire, dst: Node) -> None: + self.hugr.add_link(src.out_port(), dst.inp(0)) diff --git a/hugr-py/src/hugr/_dfg.py b/hugr-py/src/hugr/_dfg.py index b371e0ba4..54c5c6b58 100644 --- a/hugr-py/src/hugr/_dfg.py +++ b/hugr-py/src/hugr/_dfg.py @@ -73,6 +73,16 @@ def add_nested( self._wire_up(dfg.root, args) return _from_base(Dfg, dfg) + def add_cfg( + self, + input_types: Sequence[Type], + output_types: Sequence[Type], + *args: Wire, + ) -> Cfg: + cfg = self.hugr.add_cfg(input_types, output_types) + self._wire_up(cfg.root, args) + return cfg + def insert_cfg(self, cfg: Cfg, *args: Wire) -> Node: mapping = self.hugr.insert_hugr(cfg.hugr, self.root) self._wire_up(mapping[cfg.root], args) diff --git a/hugr-py/src/hugr/_hugr.py b/hugr-py/src/hugr/_hugr.py index 40882d0bc..a35b87f4a 100644 --- a/hugr-py/src/hugr/_hugr.py +++ b/hugr-py/src/hugr/_hugr.py @@ -13,11 +13,13 @@ TypeVar, cast, overload, + Sequence, ) from typing_extensions import Self from hugr._ops import Op +from hugr._tys import Type from hugr.serialization.ops import OpType as SerialOp from hugr.serialization.serial_hugr import SerialHugr from hugr.utils import BiMap @@ -26,6 +28,7 @@ if TYPE_CHECKING: from ._dfg import DfBase, DP + from ._cfg import Cfg class Direction(Enum): @@ -346,6 +349,21 @@ def add_dfg(self, root_op: DP) -> DfBase[DP]: dfg.root = mapping[dfg.root] return dfg + def add_cfg(self, input_types: Sequence[Type], output_types: Sequence[Type]) -> Cfg: + from ._cfg import Cfg + + cfg = Cfg(input_types, output_types) + mapping = self.insert_hugr(cfg.hugr, self.root) + cfg.hugr = self + cfg._entry_block.root = mapping[cfg.entry] + cfg._entry_block.input_node = mapping[cfg._entry_block.input_node] + cfg._entry_block.output_node = mapping[cfg._entry_block.output_node] + cfg._entry_block.hugr = self + cfg.exit = mapping[cfg.exit] + cfg.root = mapping[cfg.root] + # TODO this is horrible + return cfg + def to_serial(self) -> SerialHugr: node_it = (node for node in self._nodes if node is not None) diff --git a/hugr-py/src/hugr/_tys.py b/hugr-py/src/hugr/_tys.py index 7cbd73f2f..6f199959c 100644 --- a/hugr-py/src/hugr/_tys.py +++ b/hugr-py/src/hugr/_tys.py @@ -269,3 +269,4 @@ def to_serial(self) -> stys.Qubit: Qubit = QubitDef() Bool = UnitSum(size=2) +Unit = UnitSum(size=1) diff --git a/hugr-py/tests/test_cfg.py b/hugr-py/tests/test_cfg.py new file mode 100644 index 000000000..ce21186b3 --- /dev/null +++ b/hugr-py/tests/test_cfg.py @@ -0,0 +1,49 @@ +from hugr._cfg import Cfg +import hugr._tys as tys +from hugr._dfg import Dfg +from .test_hugr_build import _validate, INT_T, DivMod + + +def build_basic_cfg(cfg: Cfg) -> None: + entry = cfg.simple_entry(1, [tys.Bool]) + + entry.block_outputs(*entry.inputs()) + cfg.branch(entry.root.out(0), cfg.exit) + + +def test_basic_cfg() -> None: + cfg = Cfg([tys.Unit, tys.Bool], [tys.Bool]) + build_basic_cfg(cfg) + _validate(cfg.hugr) + + +def test_branch() -> None: + cfg = Cfg([tys.Bool, tys.Unit, INT_T], [INT_T]) + entry = cfg.simple_entry(2, [tys.Unit, INT_T]) + entry.block_outputs(*entry.inputs()) + + middle_1 = cfg.simple_block([tys.Unit, INT_T], 1, [INT_T]) + middle_1.block_outputs(*middle_1.inputs()) + middle_2 = cfg.simple_block([tys.Unit, INT_T], 1, [INT_T]) + u, i = middle_2.inputs() + n = middle_2.add(DivMod(i, i)) + middle_2.block_outputs(u, n[0]) + + cfg.branch(entry.root.out(0), middle_1.root) + cfg.branch(entry.root.out(1), middle_2.root) + + cfg.branch(middle_1.root.out(0), cfg.exit) + cfg.branch(middle_2.root.out(0), cfg.exit) + + _validate(cfg.hugr) + + +def test_nested_cfg() -> None: + dfg = Dfg([tys.Unit, tys.Bool], [tys.Bool]) + + cfg = dfg.add_cfg([tys.Unit, tys.Bool], [tys.Bool], *dfg.inputs()) + + build_basic_cfg(cfg) + dfg.set_outputs(cfg.root) + + _validate(dfg.hugr, True) diff --git a/hugr-py/tests/test_hugr_build.py b/hugr-py/tests/test_hugr_build.py index 37a3b624d..47dd9ce36 100644 --- a/hugr-py/tests/test_hugr_build.py +++ b/hugr-py/tests/test_hugr_build.py @@ -240,7 +240,7 @@ def test_build_inter_graph(): h.set_outputs(nested.root, b) - _validate(h.hugr, True) + _validate(h.hugr) assert _SubPort(h.input_node.out(-1)) in h.hugr._links assert h.hugr.num_outgoing(h.input_node) == 2 # doesn't count state order