Skip to content

Commit

Permalink
Merge branch 'main' into ab/json-helpers
Browse files Browse the repository at this point in the history
  • Loading branch information
aborgna-q authored Aug 12, 2024
2 parents dddf813 + aa81c9a commit e322a60
Show file tree
Hide file tree
Showing 16 changed files with 179 additions and 31 deletions.
4 changes: 2 additions & 2 deletions hugr-core/src/builder/build_traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -777,8 +777,8 @@ fn wire_up<T: Dataflow + ?Sized>(
});
};

if !OpTag::BasicBlock.is_superset(base.get_optype(src).tag())
&& !OpTag::BasicBlock.is_superset(base.get_optype(src_sibling).tag())
if !OpTag::ControlFlowChild.is_superset(base.get_optype(src).tag())
&& !OpTag::ControlFlowChild.is_superset(base.get_optype(src_sibling).tag())
{
// Add a state order constraint unless one of the nodes is a CFG BasicBlock
base.add_other_edge(src, src_sibling);
Expand Down
2 changes: 1 addition & 1 deletion hugr-core/src/hugr/views/root_checked.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ mod test {
r,
Err(HugrError::InvalidTag {
required: OpTag::Dfg,
actual: ops::OpTag::BasicBlock
actual: ops::OpTag::DataflowBlock
})
);
// That didn't do anything:
Expand Down
2 changes: 1 addition & 1 deletion hugr-core/src/ops/controlflow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ impl NamedOp for ExitBlock {
}

impl StaticTag for DataflowBlock {
const TAG: OpTag = OpTag::BasicBlock;
const TAG: OpTag = OpTag::DataflowBlock;
}

impl StaticTag for ExitBlock {
Expand Down
2 changes: 1 addition & 1 deletion hugr-core/src/ops/handle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ impl_nodehandle!(ModuleRootID, OpTag::ModuleRoot);
impl_nodehandle!(ModuleID, OpTag::ModuleOp);
impl_nodehandle!(ConstID, OpTag::Const);

impl_nodehandle!(BasicBlockID, OpTag::BasicBlock);
impl_nodehandle!(BasicBlockID, OpTag::DataflowBlock);

impl<const DEF: bool> NodeHandle for FuncID<DEF> {
const TAG: OpTag = OpTag::Function;
Expand Down
20 changes: 12 additions & 8 deletions hugr-core/src/ops/tag.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,8 @@ pub enum OpTag {
/// A leaf operation.
Leaf,

/// A control flow basic block.
BasicBlock,
/// A control flow basic block defining a dataflow graph.
DataflowBlock,
/// A control flow exit node.
BasicBlockExit,
}
Expand Down Expand Up @@ -113,8 +113,8 @@ impl OpTag {
OpTag::Function => &[OpTag::ModuleOp, OpTag::StaticOutput],
OpTag::Alias => &[OpTag::ScopedDefn],
OpTag::FuncDefn => &[OpTag::Function, OpTag::ScopedDefn, OpTag::DataflowParent],
OpTag::BasicBlock => &[OpTag::ControlFlowChild, OpTag::DataflowParent],
OpTag::BasicBlockExit => &[OpTag::BasicBlock],
OpTag::DataflowBlock => &[OpTag::ControlFlowChild, OpTag::DataflowParent],
OpTag::BasicBlockExit => &[OpTag::ControlFlowChild],
OpTag::Case => &[OpTag::Any, OpTag::DataflowParent],
OpTag::ModuleRoot => &[OpTag::Any],
OpTag::Const => &[OpTag::ScopedDefn, OpTag::StaticOutput],
Expand Down Expand Up @@ -148,7 +148,7 @@ impl OpTag {
OpTag::Input => "Input node",
OpTag::Output => "Output node",
OpTag::FuncDefn => "Function definition",
OpTag::BasicBlock => "Basic block",
OpTag::DataflowBlock => "Basic block containing a dataflow graph",
OpTag::BasicBlockExit => "Exit basic block node",
OpTag::Case => "Case",
OpTag::ModuleRoot => "Module root node",
Expand Down Expand Up @@ -213,16 +213,20 @@ mod test {
assert!(OpTag::None.is_superset(OpTag::None));
assert!(OpTag::ModuleOp.is_superset(OpTag::ModuleOp));
assert!(OpTag::DataflowChild.is_superset(OpTag::DataflowChild));
assert!(OpTag::BasicBlock.is_superset(OpTag::BasicBlock));
assert!(OpTag::ControlFlowChild.is_superset(OpTag::ControlFlowChild));

assert!(OpTag::Any.is_superset(OpTag::None));
assert!(OpTag::Any.is_superset(OpTag::ModuleOp));
assert!(OpTag::Any.is_superset(OpTag::DataflowChild));
assert!(OpTag::Any.is_superset(OpTag::BasicBlock));
assert!(OpTag::Any.is_superset(OpTag::ControlFlowChild));

assert!(!OpTag::None.is_superset(OpTag::Any));
assert!(!OpTag::None.is_superset(OpTag::ModuleOp));
assert!(!OpTag::None.is_superset(OpTag::DataflowChild));
assert!(!OpTag::None.is_superset(OpTag::BasicBlock));
assert!(!OpTag::None.is_superset(OpTag::ControlFlowChild));

// Other specific checks
assert!(!OpTag::DataflowParent.is_superset(OpTag::BasicBlockExit));
assert!(!OpTag::DataflowParent.is_superset(OpTag::Cfg));
}
}
2 changes: 1 addition & 1 deletion hugr-core/src/ops/validate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ impl ValidateOp for super::CFG {
fn validity_flags(&self) -> OpValidityFlags {
OpValidityFlags {
allowed_children: OpTag::ControlFlowChild,
allowed_first_child: OpTag::BasicBlock,
allowed_first_child: OpTag::DataflowBlock,
allowed_second_child: OpTag::BasicBlockExit,
requires_children: true,
requires_dag: false,
Expand Down
17 changes: 15 additions & 2 deletions hugr-core/src/std_extensions/logic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,13 @@ impl ConstFold for NaryLogic {
(res || inps.len() as u64 == num_args)
.then_some(vec![(0.into(), ops::Value::from_bool(res))])
}
Self::Eq => {
let inps = read_inputs(consts)?;
let res = inps.iter().copied().reduce(|a, b| a == b)?;
// If we have only some inputs, we can still fold to false, but not to true
(!res || inps.len() as u64 == num_args)
.then_some(vec![(0.into(), ops::Value::from_bool(res))])
}
}
}
}
Expand All @@ -57,6 +64,7 @@ impl ConstFold for NaryLogic {
pub enum NaryLogic {
And,
Or,
Eq,
}

impl MakeOpDef for NaryLogic {
Expand All @@ -68,6 +76,7 @@ impl MakeOpDef for NaryLogic {
match self {
NaryLogic::And => "logical 'and'",
NaryLogic::Or => "logical 'or'",
NaryLogic::Eq => "test if bools are equal",
}
.to_string()
}
Expand Down Expand Up @@ -275,7 +284,7 @@ pub(crate) mod test {
fn test_logic_extension() {
let r: Extension = extension();
assert_eq!(r.name() as &str, "logic");
assert_eq!(r.operations().count(), 3);
assert_eq!(r.operations().count(), 4);

for op in NaryLogic::iter() {
assert_eq!(
Expand All @@ -287,7 +296,7 @@ pub(crate) mod test {

#[test]
fn test_conversions() {
for def in [NaryLogic::And, NaryLogic::Or] {
for def in [NaryLogic::And, NaryLogic::Or, NaryLogic::Eq] {
let o = def.with_n_inputs(3);
let ext_op = o.clone().to_extension_op().unwrap();
let custom_op: CustomOp = ext_op.into();
Expand Down Expand Up @@ -331,6 +340,8 @@ pub(crate) mod test {
#[case(NaryLogic::Or, [], false)]
#[case(NaryLogic::Or, [false, false, true], true)]
#[case(NaryLogic::Or, [false, false, false], false)]
#[case(NaryLogic::Eq, [true, true, false, true], false)]
#[case(NaryLogic::Eq, [false, false], true)]
fn nary_const_fold(
#[case] op: NaryLogic,
#[case] ins: impl IntoIterator<Item = bool>,
Expand All @@ -355,6 +366,8 @@ pub(crate) mod test {
#[case(NaryLogic::And, [Some(false), None], Some(false))]
#[case(NaryLogic::Or, [None, Some(false)], None)]
#[case(NaryLogic::Or, [None, Some(true)], Some(true))]
#[case(NaryLogic::Eq, [None, Some(true), Some(true)], None)]
#[case(NaryLogic::Eq, [None, Some(false), Some(true)], Some(false))]
fn nary_partial_const_fold(
#[case] op: NaryLogic,
#[case] ins: impl IntoIterator<Item = Option<bool>>,
Expand Down
1 change: 1 addition & 0 deletions hugr-py/docs/api-docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
"image_light": "_static/Quantinuum_logo_black.png",
"image_dark": "_static/Quantinuum_logo_white.png",
},
"show_navbar_depth": 2,
}

html_static_path = ["../_static"]
Expand Down
14 changes: 13 additions & 1 deletion hugr-py/src/hugr/cfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from typing_extensions import Self

from hugr import ops, val
from hugr import ops, tys, val

from .dfg import _DfBase
from .exceptions import MismatchedExit, NoSiblingAncestor, NotInSameCfg
Expand All @@ -22,6 +22,15 @@
class Block(_DfBase[ops.DataflowBlock]):
"""Builder class for a basic block in a HUGR control flow graph."""

def set_outputs(self, *outputs: Wire) -> None:
super().set_outputs(*outputs)

assert len(outputs) > 0
branching = outputs[0]
branch_type = self.hugr.port_type(branching.out_port())
assert isinstance(branch_type, tys.Sum)
self._set_parent_output_count(len(branch_type.variant_rows))

def set_block_outputs(self, branching: Wire, *other_outputs: Wire) -> None:
self.set_outputs(branching, *other_outputs)

Expand Down Expand Up @@ -249,3 +258,6 @@ def branch_exit(self, src: Wire) -> None:
else:
self._exit_op._cfg_outputs = out_types
self.parent_op._outputs = out_types
self.parent_node = self.hugr._update_node_outs(
self.parent_node, len(out_types)
)
13 changes: 12 additions & 1 deletion hugr-py/src/hugr/cond_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,14 @@
from typing_extensions import Self

from hugr import ops
from hugr.tys import Sum

from .dfg import _DfBase
from .hugr import Hugr, ParentBuilder

if TYPE_CHECKING:
from .node_port import Node, ToNode, Wire
from .tys import Sum, TypeRow
from .tys import TypeRow


class Case(_DfBase[ops.Case]):
Expand Down Expand Up @@ -215,6 +216,16 @@ def __init__(self, just_inputs: TypeRow, rest: TypeRow) -> None:
root_op = ops.TailLoop(just_inputs, rest)
super().__init__(root_op)

def set_outputs(self, *outputs: Wire) -> None:
super().set_outputs(*outputs)

assert len(outputs) > 0
sum_wire = outputs[0]
sum_type = self.hugr.port_type(sum_wire.out_port())
assert isinstance(sum_type, Sum)
assert len(sum_type.variant_rows) == 2
self._set_parent_output_count(len(sum_type.variant_rows[1]) + len(outputs) - 1)

def set_loop_outputs(self, sum_wire: Wire, *rest: Wire) -> None:
"""Set the outputs of the loop body. The first wire must be the sum type
that controls loop termination.
Expand Down
45 changes: 36 additions & 9 deletions hugr-py/src/hugr/dfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,25 +42,29 @@ def define_function(
name: str,
input_types: TypeRow,
type_params: list[TypeParam] | None = None,
parent: ToNode | None = None,
) -> Function:
"""Start building a function definition in the graph.
Args:
name: The name of the function.
input_types: The input types for the function.
type_params: The type parameters for the function, if polymorphic.
parent: The parent node of the constant. Defaults to the root node.
Returns:
The new function builder.
"""
parent_node = parent or self.hugr.root
parent_op = ops.FuncDefn(name, input_types, type_params or [])
return Function.new_nested(parent_op, self.hugr)
return Function.new_nested(parent_op, self.hugr, parent_node)

def add_const(self, value: val.Value) -> Node:
def add_const(self, value: val.Value, parent: ToNode | None = None) -> Node:
"""Add a static constant to the graph.
Args:
value: The constant value to add.
parent: The parent node of the constant. Defaults to the root node.
Returns:
The node holding the :class:`Const <hugr.ops.Const>` operation.
Expand All @@ -71,11 +75,13 @@ def add_const(self, value: val.Value) -> Node:
>>> dfg.hugr[const_n].op
Const(TRUE)
"""
return self.hugr.add_node(ops.Const(value), self.hugr.root)
parent_node = parent or self.hugr.root
return self.hugr.add_node(ops.Const(value), parent_node)

def add_alias_defn(self, name: str, ty: Type) -> Node:
def add_alias_defn(self, name: str, ty: Type, parent: ToNode | None = None) -> Node:
"""Add a type alias definition."""
return self.hugr.add_node(ops.AliasDefn(name, ty), self.hugr.root)
parent_node = parent or self.hugr.root
return self.hugr.add_node(ops.AliasDefn(name, ty), parent_node)


DP = TypeVar("DP", bound=ops.DfParentOp)
Expand Down Expand Up @@ -466,6 +472,18 @@ def set_outputs(self, *args: Wire) -> None:
self._wire_up(self.output_node, args)
self.parent_op._set_out_types(self._output_op().types)

def _set_parent_output_count(self, count: int) -> None:
"""Set the final number of output ports on the parent operation.
Args:
count: The number of output ports.
Example:
>>> dfg = Dfg(tys.Bool)
>>> dfg._set_parent_output_count(2)
"""
self.parent_node = self.hugr._update_node_outs(self.parent_node, count)

def add_state_order(self, src: Node, dst: Node) -> None:
"""Add a state order link between two nodes.
Expand All @@ -482,13 +500,17 @@ def add_state_order(self, src: Node, dst: Node) -> None:
# adds edge to the right of all existing edges
self.hugr.add_link(src.out(-1), dst.inp(-1))

def load(self, const: ToNode | val.Value) -> Node:
def load(
self, const: ToNode | val.Value, const_parent: ToNode | None = None
) -> Node:
"""Load a constant into the graph as a dataflow value.
Args:
const: The constant to load, either a Value that will be added as a
child Const node then loaded, or a node corresponding to an existing
Const.
child Const node then loaded, or a node corresponding to an
existing Const.
const_parent: If `const` is a Value, the parent node for the new
constant definition. Defaults to the current dataflow container.
Returns:
The node holding the :class:`LoadConst <hugr.ops.LoadConst>`
Expand All @@ -503,7 +525,8 @@ def load(self, const: ToNode | val.Value) -> Node:
LoadConst(Bool)
"""
if isinstance(const, val.Value):
const = self.add_const(const)
const_parent = const_parent or self.parent_node
const = self.add_const(const, parent=const_parent)
const_op = self.hugr._get_typed_op(const, ops.Const)
load_op = ops.LoadConst(const_op.val.type_())

Expand Down Expand Up @@ -620,6 +643,10 @@ def __init__(self, *input_types: tys.Type) -> None:
parent_op = ops.DFG(list(input_types), None)
super().__init__(parent_op)

def set_outputs(self, *outputs: Wire) -> None:
super().set_outputs(*outputs)
self._set_parent_output_count(len(outputs))


def _ancestral_sibling(h: Hugr, src: Node, tgt: Node) -> Node | None:
"""Find the ancestor of `tgt` that is a sibling of `src`, if one exists."""
Expand Down
Loading

0 comments on commit e322a60

Please sign in to comment.