Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(hugr-py): CFG builder #1192

Merged
merged 6 commits into from
Jun 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 90 additions & 0 deletions hugr-py/src/hugr/_cfg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Iterable, Sequence
from ._hugr import Hugr, Node, Wire
from ._dfg import DfBase, _from_base
from ._tys import FunctionType, TypeRow, Sum
from ._exceptions import NoSiblingAncestor, NotInSameCfg
import hugr._ops as ops


class Block(DfBase[ops.DataflowBlock]):
mark-koch marked this conversation as resolved.
Show resolved Hide resolved
def set_block_outputs(self, branching: Wire, *other_outputs: Wire) -> None:
self.set_outputs(branching, *other_outputs)

def set_single_successor_outputs(self, *outputs: Wire) -> None:
# TODO requires constants
raise NotImplementedError

Check warning on line 17 in hugr-py/src/hugr/_cfg.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_cfg.py#L17

Added line #L17 was not covered by tests

def _wire_up(self, node: Node, ports: Iterable[Wire]):
for i, p in enumerate(ports):
src = p.out_port()
cfg_node = self.hugr[self.root].parent
assert cfg_node is not None
src_parent = self.hugr[src.node].parent
try:
self._wire_up_port(node, i, p)
except NoSiblingAncestor:
# note this just checks if there is a common CFG ancestor
# it does not check for valid dominance between basic blocks
# that is deferred to full HUGR validation.
while cfg_node != src_parent:
if src_parent is None or src_parent == self.hugr.root:
raise NotInSameCfg(src.node.idx, node.idx)

Check warning on line 33 in hugr-py/src/hugr/_cfg.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_cfg.py#L33

Added line #L33 was not covered by tests
src_parent = self.hugr[src_parent].parent

self.hugr.add_link(src, node.inp(i))


@dataclass
class Cfg:
hugr: Hugr
root: Node
_entry_block: Block
exit: Node

def __init__(self, input_types: TypeRow, output_types: TypeRow) -> None:
root_op = ops.CFG(FunctionType(input=input_types, output=output_types))
self.hugr = Hugr(root_op)
self.root = self.hugr.root
# to ensure entry is first child, add a dummy entry at the start
self._entry_block = _from_base(
Block, self.hugr.add_dfg(ops.DataflowBlock(input_types, []))
)

self.exit = self.hugr.add_node(ops.ExitBlock(output_types), self.root)

@property
def entry(self) -> Node:
return self._entry_block.root

def _entry_op(self) -> ops.DataflowBlock:
dop = self.hugr[self.entry].op
assert isinstance(dop, ops.DataflowBlock)
return dop

def add_entry(self, sum_rows: Sequence[TypeRow], other_outputs: TypeRow) -> Block:
# update entry block types
self._entry_op().sum_rows = list(sum_rows)
self._entry_op().other_outputs = other_outputs
self._entry_block._output_op().types = [Sum(list(sum_rows)), *other_outputs]
return self._entry_block

def simple_entry(self, n_branches: int, other_outputs: TypeRow) -> Block:
return self.add_entry([[]] * n_branches, other_outputs)

def add_block(
self, input_types: TypeRow, sum_rows: Sequence[TypeRow], other_outputs: TypeRow
) -> Block:
new_block = self.hugr.add_dfg(
ops.DataflowBlock(input_types, list(sum_rows), other_outputs)
)
return _from_base(Block, new_block)

def simple_block(
self, input_types: TypeRow, n_branches: int, other_outputs: TypeRow
) -> Block:
return self.add_block(input_types, [[]] * n_branches, other_outputs)

def branch(self, src: Wire, dst: Node) -> None:
self.hugr.add_link(src.out_port(), dst.inp(0))
112 changes: 81 additions & 31 deletions hugr-py/src/hugr/_dfg.py
Original file line number Diff line number Diff line change
@@ -1,51 +1,59 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Sequence, Iterable
from typing import Iterable, TYPE_CHECKING, Generic, TypeVar, cast
import typing
from ._hugr import Hugr, Node, Wire, OutPort

from ._ops import Op, Command, Input, Output, DFG
import hugr._ops as ops
from ._exceptions import NoSiblingAncestor
from hugr._tys import FunctionType, Type
from hugr._tys import FunctionType, TypeRow

if TYPE_CHECKING:
from ._cfg import Cfg

Check warning on line 12 in hugr-py/src/hugr/_dfg.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_dfg.py#L12

Added line #L12 was not covered by tests


DP = TypeVar("DP", bound=ops.DfParentOp)


@dataclass()
class Dfg:
class DfBase(Generic[DP]):
hugr: Hugr
root: Node
input_node: Node
output_node: Node

def __init__(
self, input_types: Sequence[Type], output_types: Sequence[Type]
) -> None:
input_types = list(input_types)
output_types = list(output_types)
root_op = DFG(FunctionType(input=input_types, output=output_types))
def __init__(self, root_op: DP) -> None:
input_types = root_op.input_types()
output_types = root_op.output_types()
self.hugr = Hugr(root_op)
self.root = self.hugr.root
self.input_node = self.hugr.add_node(
Input(input_types), self.root, len(input_types)
ops.Input(input_types), self.root, len(input_types)
)
self.output_node = self.hugr.add_node(Output(output_types), self.root)
self.output_node = self.hugr.add_node(ops.Output(output_types), self.root)

@classmethod
def endo(cls, types: Sequence[Type]) -> Dfg:
return Dfg(types, types)

def _input_op(self) -> Input:
def _input_op(self) -> ops.Input:
dop = self.hugr[self.input_node].op
assert isinstance(dop, Input)
assert isinstance(dop, ops.Input)
return dop

def _output_op(self) -> ops.Output:
dop = self.hugr[self.output_node].op
assert isinstance(dop, ops.Output)
return dop

def root_op(self) -> DP:
return cast(DP, self.hugr[self.root].op)

Check warning on line 46 in hugr-py/src/hugr/_dfg.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_dfg.py#L46

Added line #L46 was not covered by tests

def inputs(self) -> list[OutPort]:
return [self.input_node.out(i) for i in range(len(self._input_op().types))]

def add_op(self, op: Op, /, *args: Wire, num_outs: int | None = None) -> Node:
def add_op(self, op: ops.Op, /, *args: Wire, num_outs: int | None = None) -> Node:
new_n = self.hugr.add_node(op, self.root, num_outs=num_outs)
self._wire_up(new_n, args)
return new_n

def add(self, com: Command) -> Node:
def add(self, com: ops.Command) -> Node:
return self.add_op(com.op, *com.incoming, num_outs=com.op.num_out)

def insert_nested(self, dfg: Dfg, *args: Wire) -> Node:
Expand All @@ -55,13 +63,30 @@

def add_nested(
self,
input_types: Sequence[Type],
output_types: Sequence[Type],
input_types: TypeRow,
output_types: TypeRow,
*args: Wire,
) -> Dfg:
dfg = self.hugr.add_dfg(input_types, output_types)
dfg = self.hugr.add_dfg(
ops.DFG(FunctionType(input=input_types, output=output_types))
)
self._wire_up(dfg.root, args)
return dfg
return _from_base(Dfg, dfg)

def add_cfg(
self,
input_types: TypeRow,
output_types: TypeRow,
*args: Wire,
) -> Cfg:
cfg = self.hugr.add_cfg(input_types, output_types)
self._wire_up(cfg.root, args)
return cfg

def insert_cfg(self, cfg: Cfg, *args: Wire) -> Node:
mapping = self.hugr.insert_hugr(cfg.hugr, self.root)
self._wire_up(mapping[cfg.root], args)
return mapping[cfg.root]

Check warning on line 89 in hugr-py/src/hugr/_dfg.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_dfg.py#L87-L89

Added lines #L87 - L89 were not covered by tests

def set_outputs(self, *args: Wire) -> None:
self._wire_up(self.output_node, args)
Expand All @@ -72,13 +97,38 @@

def _wire_up(self, node: Node, ports: Iterable[Wire]):
for i, p in enumerate(ports):
src = p.out_port()
node_ancestor = _ancestral_sibling(self.hugr, src.node, node)
if node_ancestor is None:
raise NoSiblingAncestor(src.node.idx, node.idx)
if node_ancestor != node:
self.add_state_order(src.node, node_ancestor)
self.hugr.add_link(src, node.inp(i))
self._wire_up_port(node, i, p)

def _wire_up_port(self, node: Node, offset: int, p: Wire):
src = p.out_port()
node_ancestor = _ancestral_sibling(self.hugr, src.node, node)
if node_ancestor is None:
raise NoSiblingAncestor(src.node.idx, node.idx)
if node_ancestor != node:
self.add_state_order(src.node, node_ancestor)
self.hugr.add_link(src, node.inp(offset))


C = TypeVar("C", bound=DfBase)


def _from_base(cls: typing.Type[C], base: DfBase[DP]) -> C:
new = cls.__new__(cls)
new.hugr = base.hugr
new.root = base.root
new.input_node = base.input_node
new.output_node = base.output_node
return new


class Dfg(DfBase[ops.DFG]):
def __init__(self, input_types: TypeRow, output_types: TypeRow) -> None:
root_op = ops.DFG(FunctionType(input=input_types, output=output_types))
super().__init__(root_op)

@classmethod
def endo(cls, types: TypeRow) -> Dfg:
return cls(types, types)


def _ancestral_sibling(h: Hugr, src: Node, tgt: Node) -> Node | None:
Expand Down
10 changes: 10 additions & 0 deletions hugr-py/src/hugr/_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,15 @@
return f"Source {self.src} has no sibling ancestor of target {self.tgt}, so cannot wire up."


@dataclass
class NotInSameCfg(Exception):
src: int
tgt: int

@property
def msg(self):
return f"Source {self.src} is not in the same CFG as target {self.tgt}, so cannot wire up."

Check warning on line 21 in hugr-py/src/hugr/_exceptions.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_exceptions.py#L21

Added line #L21 was not covered by tests


class ParentBeforeChild(Exception):
msg: str = "Parent node must be added before child node."
27 changes: 21 additions & 6 deletions hugr-py/src/hugr/_hugr.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
Iterable,
Iterator,
Protocol,
Sequence,
TypeVar,
cast,
overload,
Expand All @@ -19,15 +18,16 @@
from typing_extensions import Self

from hugr._ops import Op
from hugr._tys import Type
from hugr._tys import TypeRow
from hugr.serialization.ops import OpType as SerialOp
from hugr.serialization.serial_hugr import SerialHugr
from hugr.utils import BiMap

from ._exceptions import ParentBeforeChild

if TYPE_CHECKING:
from ._dfg import Dfg
from ._dfg import DfBase, DP
from ._cfg import Cfg

Check warning on line 30 in hugr-py/src/hugr/_hugr.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_hugr.py#L29-L30

Added lines #L29 - L30 were not covered by tests


class Direction(Enum):
Expand Down Expand Up @@ -337,17 +337,32 @@
)
return mapping

def add_dfg(self, input_types: Sequence[Type], output_types: Sequence[Type]) -> Dfg:
from ._dfg import Dfg
def add_dfg(self, root_op: DP) -> DfBase[DP]:
from ._dfg import DfBase

dfg = Dfg(input_types, output_types)
dfg = DfBase(root_op)
mapping = self.insert_hugr(dfg.hugr, self.root)
dfg.hugr = self
dfg.input_node = mapping[dfg.input_node]
dfg.output_node = mapping[dfg.output_node]
dfg.root = mapping[dfg.root]
return dfg

def add_cfg(self, input_types: TypeRow, output_types: TypeRow) -> Cfg:
from ._cfg import Cfg

cfg = Cfg(input_types, output_types)
mapping = self.insert_hugr(cfg.hugr, self.root)
cfg.hugr = self
cfg._entry_block.root = mapping[cfg.entry]
cfg._entry_block.input_node = mapping[cfg._entry_block.input_node]
cfg._entry_block.output_node = mapping[cfg._entry_block.output_node]
cfg._entry_block.hugr = self
cfg.exit = mapping[cfg.exit]
cfg.root = mapping[cfg.root]
# TODO this is horrible
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed in #1194

Comment on lines +354 to +363
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It might be nicer to add an optional hugr argument to Cfg.__init__? If supplied, the CFG is inserted into the given Hugr, otherwise a new one is created.

Alternatively, it could be a classmethod Cfg.new_nested(hugr, input_types, output_types) or something similar?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like the class method idea, I'll try it out on #1194

return cfg

def to_serial(self) -> SerialHugr:
node_it = (node for node in self._nodes if node is not None)

Expand Down
Loading