Skip to content

Commit

Permalink
feat(hugr-py): context manager style nested building (#1276)
Browse files Browse the repository at this point in the history
Currently the context doesn't do anything much, just offers syntactical
nesting to make it easier to write HUGRs

The exception being Conditional, where we use the exit as an opportunity
to check all cases have been built.
This can be extended to other cases (e.g. ensure expected outputs have
been set when exiting a dfg).

Closes #1243
  • Loading branch information
ss2165 authored Jul 8, 2024
1 parent 8ccd3aa commit 6b32734
Show file tree
Hide file tree
Showing 6 changed files with 134 additions and 97 deletions.
31 changes: 20 additions & 11 deletions hugr-py/src/hugr/cfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@

from __future__ import annotations

from contextlib import AbstractContextManager
from dataclasses import dataclass
from typing import TYPE_CHECKING

from typing_extensions import Self

from hugr import ops, val

from .dfg import _DfBase
Expand Down Expand Up @@ -47,7 +50,7 @@ def _wire_up_port(self, node: Node, offset: int, p: Wire) -> Type:


@dataclass
class Cfg(ParentBuilder[ops.CFG]):
class Cfg(ParentBuilder[ops.CFG], AbstractContextManager):
"""Builder class for a HUGR control flow graph, with the HUGR root node
being a :class:`CFG <hugr.ops.CFG>`.
Expand Down Expand Up @@ -83,6 +86,12 @@ def _init_impl(self: Cfg, hugr: Hugr, root: Node, input_types: TypeRow) -> None:

self.exit = self.hugr.add_node(ops.ExitBlock(), self.parent_node)

def __enter__(self) -> Self:
return self

def __exit__(self, *args) -> None:
return None

@classmethod
def new_nested(
cls,
Expand Down Expand Up @@ -158,8 +167,8 @@ def add_block(self, *input_types: Type) -> Block:
Examples:
>>> cfg = Cfg(tys.Bool)
>>> b = cfg.add_block(tys.Unit)
>>> b.set_single_succ_outputs(*b.inputs())
>>> with cfg.add_block(tys.Unit) as b:\
b.set_single_succ_outputs(*b.inputs())
"""
new_block = Block.new_nested(
ops.DataflowBlock(list(input_types)),
Expand All @@ -183,10 +192,10 @@ def add_successor(self, pred: Wire) -> Block:
Examples:
>>> cfg = Cfg(tys.Bool)
>>> entry = cfg.add_entry()
>>> entry.set_single_succ_outputs()
>>> b = cfg.add_successor(entry[0])
>>> b.set_single_succ_outputs(*b.inputs())
>>> with cfg.add_entry() as entry:\
entry.set_single_succ_outputs()
>>> with cfg.add_successor(entry[0]) as b:\
b.set_single_succ_outputs(*b.inputs())
"""
b = self.add_block(*self._nth_outputs(pred))

Expand All @@ -207,8 +216,8 @@ def branch(self, src: Wire, dst: ToNode) -> None:
Examples:
>>> cfg = Cfg(tys.Bool)
>>> entry = cfg.add_entry()
>>> entry.set_single_succ_outputs()
>>> with cfg.add_entry() as entry:\
entry.set_single_succ_outputs()
>>> b = cfg.add_block(tys.Unit)
>>> cfg.branch(entry[0], b)
"""
Expand All @@ -226,8 +235,8 @@ def branch_exit(self, src: Wire) -> None:
Examples:
>>> cfg = Cfg(tys.Bool)
>>> entry = cfg.add_entry()
>>> entry.set_single_succ_outputs()
>>> with cfg.add_entry() as entry:\
entry.set_single_succ_outputs()
>>> cfg.branch_exit(entry[0])
"""
src = src.out_port()
Expand Down
34 changes: 24 additions & 10 deletions hugr-py/src/hugr/cond_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,12 @@

from __future__ import annotations

from contextlib import AbstractContextManager
from dataclasses import dataclass
from typing import TYPE_CHECKING

from typing_extensions import Self

from hugr import ops

from .dfg import _DfBase
Expand Down Expand Up @@ -47,6 +50,11 @@ def _parent_conditional(self) -> Conditional:
raise ConditionalError(msg)
return self._parent_cond

@property
def conditional_node(self) -> Node:
"""The node that represents the parent conditional."""
return self._parent_conditional().parent_node


class If(_IfElse):
"""Build the 'if' branch of a conditional branching on a boolean value.
Expand All @@ -59,7 +67,7 @@ class If(_IfElse):
>>> if_.set_outputs(if_.input_node[0])
>>> else_= if_.add_else()
>>> else_.set_outputs(else_.input_node[0])
>>> dfg.hugr[else_.finish()].op
>>> dfg.hugr[else_.conditional_node].op
Conditional(sum_ty=Bool, other_inputs=[Qubit])
"""

Expand All @@ -75,16 +83,13 @@ class Else(_IfElse):
"""

def finish(self) -> Node:
"""Finish building the if/else.
Returns:
The node that represents the parent conditional.
"""
return self._parent_conditional().parent_node
"""Deprecated, use `conditional_node` property."""
# TODO remove in 0.4.0
return self.conditional_node # pragma: no cover


@dataclass
class Conditional(ParentBuilder[ops.Conditional]):
class Conditional(ParentBuilder[ops.Conditional], AbstractContextManager):
"""Build a conditional branching on a sum type.
Args:
Expand All @@ -111,6 +116,15 @@ def _init_impl(self: Conditional, hugr: Hugr, root: Node, n_cases: int) -> None:
self.parent_node = root
self.cases = {i: None for i in range(n_cases)}

def __enter__(self) -> Self:
return self

def __exit__(self, *args) -> None:
if any(c is None for c in self.cases.values()):
msg = "All cases must be added before exiting context."
raise ConditionalError(msg)
return None

@classmethod
def new_nested(
cls,
Expand Down Expand Up @@ -164,8 +178,8 @@ def add_case(self, case_id: int) -> Case:
Examples:
>>> cond = Conditional(tys.Bool, [tys.Qubit])
>>> case = cond.add_case(0)
>>> case.set_outputs(*case.inputs())
>>> with cond.add_case(0) as case:\
case.set_outputs(*case.inputs())
"""
if case_id not in self.cases:
msg = f"Case {case_id} out of possible range."
Expand Down
17 changes: 12 additions & 5 deletions hugr-py/src/hugr/dfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

from contextlib import AbstractContextManager
from dataclasses import dataclass, field, replace
from typing import (
TYPE_CHECKING,
Expand All @@ -27,7 +28,7 @@


@dataclass()
class _DfBase(ParentBuilder[DP]):
class _DfBase(ParentBuilder[DP], AbstractContextManager):
"""Base class for dataflow graph builders.
Args:
Expand Down Expand Up @@ -56,6 +57,12 @@ def _init_io_nodes(self, parent_op: DP):
)
self.output_node = self.hugr.add_node(ops.Output(), self.parent_node)

def __enter__(self) -> Self:
return self

def __exit__(self, *args) -> None:
return None

@classmethod
def new_nested(
cls, parent_op: DP, hugr: Hugr, parent: ToNode | None = None
Expand Down Expand Up @@ -179,8 +186,8 @@ def add_nested(
Example:
>>> dfg = Dfg(tys.Bool)
>>> dfg2 = dfg.add_nested(dfg.inputs()[0])
>>> dfg2.parent_node
>>> with dfg.add_nested(dfg.inputs()[0]) as dfg2:\
dfg2.parent_node
Node(3)
"""
from .dfg import Dfg
Expand All @@ -207,8 +214,8 @@ def add_cfg(
Example:
>>> dfg = Dfg(tys.Bool)
>>> cfg = dfg.add_cfg(dfg.inputs()[0])
>>> cfg.parent_op
>>> with dfg.add_cfg(dfg.inputs()[0]) as cfg:\
cfg.parent_op
CFG(inputs=[Bool])
"""
from .cfg import Cfg
Expand Down
45 changes: 22 additions & 23 deletions hugr-py/tests/test_cfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,8 @@


def build_basic_cfg(cfg: Cfg) -> None:
entry = cfg.add_entry()

entry.set_single_succ_outputs(*entry.inputs())
with cfg.add_entry() as entry:
entry.set_single_succ_outputs(*entry.inputs())
cfg.branch(entry[0], cfg.exit)


Expand Down Expand Up @@ -49,16 +48,17 @@ def test_nested_cfg() -> None:

def test_dom_edge() -> None:
cfg = Cfg(tys.Bool, tys.Unit, INT_T)
entry = cfg.add_entry()
b, u, i = entry.inputs()
entry.set_block_outputs(b, i)
with cfg.add_entry() as entry:
b, u, i = entry.inputs()
entry.set_block_outputs(b, i)

# entry dominates both middles so Unit type can be used as inter-graph
# value between basic blocks
middle_1 = cfg.add_successor(entry[0])
middle_1.set_block_outputs(u, *middle_1.inputs())
middle_2 = cfg.add_successor(entry[1])
middle_2.set_block_outputs(u, *middle_2.inputs())
with cfg.add_successor(entry[0]) as middle_1:
middle_1.set_block_outputs(u, *middle_1.inputs())

with cfg.add_successor(entry[1]) as middle_2:
middle_2.set_block_outputs(u, *middle_2.inputs())

cfg.branch_exit(middle_1[0])
cfg.branch_exit(middle_2[0])
Expand All @@ -68,22 +68,21 @@ def test_dom_edge() -> None:

def test_asymm_types() -> None:
# test different types going to entry block's susccessors
cfg = Cfg()
entry = cfg.add_entry()

int_load = entry.load(IntVal(34))
with Cfg() as cfg:
with cfg.add_entry() as entry:
int_load = entry.load(IntVal(34))

sum_ty = tys.Sum([[INT_T], [tys.Bool]])
tagged_int = entry.add(ops.Tag(0, sum_ty)(int_load))
entry.set_block_outputs(tagged_int)
sum_ty = tys.Sum([[INT_T], [tys.Bool]])
tagged_int = entry.add(ops.Tag(0, sum_ty)(int_load))
entry.set_block_outputs(tagged_int)

middle = cfg.add_successor(entry[0])
# discard the int and return the bool from entry
middle.set_single_succ_outputs(middle.load(val.TRUE))
with cfg.add_successor(entry[0]) as middle:
# discard the int and return the bool from entry
middle.set_single_succ_outputs(middle.load(val.TRUE))

# middle expects an int and exit expects a bool
cfg.branch_exit(entry[1])
cfg.branch_exit(middle[0])
# middle expects an int and exit expects a bool
cfg.branch_exit(entry[1])
cfg.branch_exit(middle[0])

validate(cfg.hugr)

Expand Down
Loading

0 comments on commit 6b32734

Please sign in to comment.