Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(hugr-py): add builders for Conditional and TailLoop #1210

Merged
merged 12 commits into from
Jun 21, 2024
16 changes: 8 additions & 8 deletions hugr-py/src/hugr/_cfg.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from __future__ import annotations

from dataclasses import dataclass, replace
from dataclasses import dataclass

import hugr._ops as ops

from ._dfg import _DfBase
from ._exceptions import NoSiblingAncestor, NotInSameCfg, MismatchedExit
from ._hugr import Hugr, Node, ParentBuilder, ToNode, Wire
from ._tys import FunctionType, TypeRow, Type
from ._tys import TypeRow, Type
import hugr._val as val


Expand All @@ -16,7 +16,7 @@ def set_block_outputs(self, branching: Wire, *other_outputs: Wire) -> None:
self.set_outputs(branching, *other_outputs)

def set_single_succ_outputs(self, *outputs: Wire) -> None:
u = self.add_load_const(val.Unit)
u = self.load(val.Unit)
self.set_outputs(u, *outputs)

def _wire_up_port(self, node: Node, offset: int, p: Wire) -> Type:
Expand Down Expand Up @@ -47,7 +47,7 @@ class Cfg(ParentBuilder[ops.CFG]):
exit: Node

def __init__(self, input_types: TypeRow) -> None:
root_op = ops.CFG(FunctionType(input=input_types, output=[]))
root_op = ops.CFG(inputs=input_types)
hugr = Hugr(root_op)
self._init_impl(hugr, hugr.root, input_types)

Expand All @@ -68,7 +68,7 @@ def new_nested(
) -> Cfg:
new = cls.__new__(cls)
root = hugr.add_node(
ops.CFG(FunctionType(input=input_types, output=[])),
ops.CFG(inputs=input_types),
parent or hugr.root,
)
new._init_impl(hugr, root, input_types)
Expand Down Expand Up @@ -97,6 +97,8 @@ def add_block(self, input_types: TypeRow) -> Block:
)
return new_block

# TODO insert_block
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TODO now or later?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

later - I'm not convinced it's a common enough use case to do immediately


def add_successor(self, pred: Wire) -> Block:
b = self.add_block(self._nth_outputs(pred))

Expand Down Expand Up @@ -125,6 +127,4 @@ def branch_exit(self, src: Wire) -> None:
raise MismatchedExit(src.node.idx)
else:
self._exit_op._cfg_outputs = out_types
self.parent_op.signature = replace(
self.parent_op.signature, output=out_types
)
self.parent_op._outputs = out_types
99 changes: 99 additions & 0 deletions hugr-py/src/hugr/_cond_loop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
from __future__ import annotations

from dataclasses import dataclass

import hugr._ops as ops

from ._dfg import _DfBase
from ._hugr import Hugr, Node, ParentBuilder, ToNode, Wire
from ._tys import Sum, TypeRow


class Case(_DfBase[ops.Case]):
_parent_cond: Conditional | None = None

def set_outputs(self, *outputs: Wire) -> None:
super().set_outputs(*outputs)
if self._parent_cond is not None:
self._parent_cond._update_outputs(self._wire_types(outputs))


@dataclass
class _IfElse(Case):
def __init__(self, case: Case) -> None:
self.hugr = case.hugr
self.parent_node = case.parent_node
self.input_node = case.input_node
self.output_node = case.output_node
self._parent_cond = case._parent_cond

def _parent_conditional(self) -> Conditional:
assert self._parent_cond is not None, "If must have a parent conditional."
return self._parent_cond


class If(_IfElse):
def add_else(self) -> Else:
return Else(self._parent_conditional().add_case(0))


class Else(_IfElse):
def finish(self) -> Node:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be nice if it were possible to finish an _IfElse with a trivial Else without having to explicitly add the Else.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I thought about this but unfortunately the form of the "trivial else" depends on linearity of types, so instead of making too many assumptions decided to leave it explicit for now.

return self._parent_conditional().parent_node


@dataclass
class Conditional(ParentBuilder[ops.Conditional]):
cases: dict[int, Node | None]

def __init__(self, sum_ty: Sum, other_inputs: TypeRow) -> None:
root_op = ops.Conditional(sum_ty, other_inputs)
hugr = Hugr(root_op)
self._init_impl(hugr, hugr.root, len(sum_ty.variant_rows))

def _init_impl(self: Conditional, hugr: Hugr, root: Node, n_cases: int) -> None:
self.hugr = hugr
self.parent_node = root
self.cases = {i: None for i in range(n_cases)}

@classmethod
def new_nested(
cls,
sum_ty: Sum,
other_inputs: TypeRow,
hugr: Hugr,
parent: ToNode | None = None,
) -> Conditional:
new = cls.__new__(cls)
root = hugr.add_node(
ops.Conditional(sum_ty, other_inputs),
parent or hugr.root,
)
new._init_impl(hugr, root, len(sum_ty.variant_rows))
return new

def _update_outputs(self, outputs: TypeRow) -> None:
if self.parent_op._outputs is None:
self.parent_op._outputs = outputs
else:
assert outputs == self.parent_op._outputs, "Mismatched case outputs."

def add_case(self, case_id: int) -> Case:
assert case_id in self.cases, f"Case {case_id} out of possible range."
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think some of these asserts should be raises instead (maybe ValueError?), since they occur through use of the API beyond our control.

input_types = self.parent_op.nth_inputs(case_id)
new_case = Case.new_nested(
ops.Case(input_types),
self.hugr,
self.parent_node,
)
new_case._parent_cond = self
self.cases[case_id] = new_case.parent_node
return new_case

# TODO insert_case
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TODO now or later?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

later - I'm not convinced it's a common enough use case to do immediately



@dataclass
class TailLoop(_DfBase[ops.TailLoop]):
def set_loop_outputs(self, sum_wire: Wire, *rest: Wire) -> None:
self.set_outputs(sum_wire, *rest)
77 changes: 57 additions & 20 deletions hugr-py/src/hugr/_dfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,22 @@
from typing import (
TYPE_CHECKING,
Iterable,
Sequence,
TypeVar,
)

from typing_extensions import Self

import hugr._ops as ops
import hugr._val as val
from hugr._tys import Type, TypeRow
from hugr._tys import Type, TypeRow, get_first_sum

from ._exceptions import NoSiblingAncestor
from ._hugr import Hugr, Node, OutPort, ParentBuilder, ToNode, Wire

if TYPE_CHECKING:
from ._cfg import Cfg
from ._cond_loop import Conditional, If, TailLoop


DP = TypeVar("DP", bound=ops.DfParentOp)
Expand Down Expand Up @@ -72,39 +74,74 @@ def add_op(self, op: ops.DataflowOp, /, *args: Wire) -> Node:
def add(self, com: ops.Command) -> Node:
return self.add_op(com.op, *com.incoming)

def _insert_nested_impl(self, builder: ParentBuilder, *args: Wire) -> Node:
mapping = self.hugr.insert_hugr(builder.hugr, self.parent_node)
self._wire_up(mapping[builder.parent_node], args)
return mapping[builder.parent_node]

def insert_nested(self, dfg: Dfg, *args: Wire) -> Node:
mapping = self.hugr.insert_hugr(dfg.hugr, self.parent_node)
self._wire_up(mapping[dfg.parent_node], args)
return mapping[dfg.parent_node]
return self._insert_nested_impl(dfg, *args)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this delegated to an internal method?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the internal method is generic, the surface ones are concrete for cases we know
about. Allows special casing in future if necessary.


def add_nested(
self,
*args: Wire,
) -> Dfg:
from ._dfg import Dfg

input_types = [self._get_dataflow_type(w) for w in args]

parent_op = ops.DFG(list(input_types))
parent_op = ops.DFG(self._wire_types(args))
dfg = Dfg.new_nested(parent_op, self.hugr, self.parent_node)
self._wire_up(dfg.parent_node, args)
return dfg

def _wire_types(self, args: Iterable[Wire]) -> TypeRow:
return [self._get_dataflow_type(w) for w in args]

def add_cfg(
self,
input_types: TypeRow,
*args: Wire,
) -> Cfg:
from ._cfg import Cfg

cfg = Cfg.new_nested(input_types, self.hugr, self.parent_node)
cfg = Cfg.new_nested(self._wire_types(args), self.hugr, self.parent_node)
self._wire_up(cfg.parent_node, args)
return cfg

def insert_cfg(self, cfg: Cfg, *args: Wire) -> Node:
mapping = self.hugr.insert_hugr(cfg.hugr, self.parent_node)
self._wire_up(mapping[cfg.parent_node], args)
return mapping[cfg.parent_node]
return self._insert_nested_impl(cfg, *args)

def add_conditional(self, cond: Wire, *args: Wire) -> Conditional:
from ._cond_loop import Conditional

args = (cond, *args)
(sum_, other_inputs) = get_first_sum(self._wire_types(args))
cond = Conditional.new_nested(sum_, other_inputs, self.hugr, self.parent_node)
self._wire_up(cond.parent_node, args)
return cond

def insert_conditional(self, cond: Conditional, *args: Wire) -> Node:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a test?

return self._insert_nested_impl(cond, *args)

def add_if(self, cond: Wire, *args: Wire) -> If:
from ._cond_loop import If

conditional = self.add_conditional(cond, *args)
return If(conditional.add_case(1))

def add_tail_loop(
self, just_inputs: Sequence[Wire], rest: Sequence[Wire]
) -> TailLoop:
from ._cond_loop import TailLoop

rest = rest or []
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this necessary?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

vestigial, will remove

just_input_types = self._wire_types(just_inputs)
rest_types = self._wire_types(rest)
parent_op = ops.TailLoop(just_input_types, rest_types)
tl = TailLoop.new_nested(parent_op, self.hugr, self.parent_node)
self._wire_up(tl.parent_node, (*just_inputs, *rest))
return tl

def insert_tail_loop(self, tl: TailLoop, *args: Wire) -> Node:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a test?

return self._insert_nested_impl(tl, *args)

def set_outputs(self, *args: Wire) -> None:
self._wire_up(self.output_node, args)
Expand All @@ -117,22 +154,22 @@ def add_state_order(self, src: Node, dst: Node) -> None:
def add_const(self, val: val.Value) -> Node:
return self.hugr.add_const(val, self.parent_node)

def load_const(self, const_node: ToNode) -> Node:
const_op = self.hugr._get_typed_op(const_node, ops.Const)
def load(self, const: ToNode | val.Value) -> Node:
if isinstance(const, val.Value):
const = self.add_const(const)
const_op = self.hugr._get_typed_op(const, ops.Const)
load_op = ops.LoadConst(const_op.val.type_())

load = self.add(load_op())
self.hugr.add_link(const_node.out_port(), load.inp(0))
self.hugr.add_link(const.out_port(), load.inp(0))

return load

def add_load_const(self, val: val.Value) -> Node:
return self.load_const(self.add_const(val))

def _wire_up(self, node: Node, ports: Iterable[Wire]):
def _wire_up(self, node: Node, ports: Iterable[Wire]) -> TypeRow:
tys = [self._wire_up_port(node, i, p) for i, p in enumerate(ports)]
if isinstance(op := self.hugr[node].op, ops.PartialOp):
op.set_in_types(tys)
return tys

def _get_dataflow_type(self, wire: Wire) -> Type:
port = wire.out_port()
Expand All @@ -141,7 +178,7 @@ def _get_dataflow_type(self, wire: Wire) -> Type:
raise ValueError(f"Port {port} is not a dataflow port.")
return ty

def _wire_up_port(self, node: Node, offset: int, p: Wire) -> Type:
def _wire_up_port(self, node: Node, offset: int, p: Wire):
src = p.out_port()
node_ancestor = _ancestral_sibling(self.hugr, src.node, node)
if node_ancestor is None:
Expand Down
Loading
Loading