Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat!: Make some Container methods infallible #872

Merged
merged 3 commits into from
Mar 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 4 additions & 6 deletions quantinuum-hugr/src/algorithm/const_fold.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ fn const_graph(consts: Vec<Const>, reg: &ExtensionRegistry) -> Hugr {

let outputs = consts
.into_iter()
.map(|c| b.add_load_const(c).unwrap())
.map(|c| b.add_load_const(c))
.collect_vec();

b.finish_hugr_with_outputs(outputs, reg).unwrap()
Expand Down Expand Up @@ -265,9 +265,7 @@ mod test {
let mut build =
DFGBuilder::new(FunctionType::new(type_row![], vec![sum_type.clone()])).unwrap();

let tup = build
.add_load_const(Const::new_tuple([f2c(5.6), f2c(3.2)]))
.unwrap();
let tup = build.add_load_const(Const::new_tuple([f2c(5.6), f2c(3.2)]));

let unpack = build
.add_dataflow_op(
Expand Down Expand Up @@ -320,7 +318,7 @@ mod test {
) -> Result<(), Box<dyn std::error::Error>> {
let mut build = DFGBuilder::new(FunctionType::new(type_row![], vec![BOOL_T])).unwrap();

let ins = ins.map(|b| build.add_load_const(Const::from_bool(b)).unwrap());
let ins = ins.map(|b| build.add_load_const(Const::from_bool(b)));
let logic_op = build.add_dataflow_op(op.with_n_inputs(ins.len() as u64), ins)?;

let reg =
Expand Down Expand Up @@ -350,7 +348,7 @@ mod test {
))
.unwrap();

let list_wire = build.add_load_const(list.clone())?;
let list_wire = build.add_load_const(list.clone());

let pop = build.add_dataflow_op(
ListOp::Pop.with_type(BOOL_T).to_extension_op(&reg).unwrap(),
Expand Down
14 changes: 7 additions & 7 deletions quantinuum-hugr/src/algorithm/nest_cfgs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -605,8 +605,8 @@ pub(crate) mod test {
// \-> right -/ \-<--<-/
let mut cfg_builder = CFGBuilder::new(FunctionType::new(type_row![NAT], type_row![NAT]))?;

let pred_const = cfg_builder.add_constant(Const::unit_sum(0, 2))?; // Nothing here cares which
let const_unit = cfg_builder.add_constant(Const::unary_unit_sum())?;
let pred_const = cfg_builder.add_constant(Const::unit_sum(0, 2));
let const_unit = cfg_builder.add_constant(Const::unary_unit_sum());

let entry = n_identity(
cfg_builder.simple_entry_builder(type_row![NAT], 1, ExtensionSet::new())?,
Expand Down Expand Up @@ -813,7 +813,7 @@ pub(crate) mod test {
pred_const: &ConstID,
) -> Result<T::ContainerHandle, BuildError> {
let w = dataflow_builder.input_wires();
let u = dataflow_builder.load_const(pred_const)?;
let u = dataflow_builder.load_const(pred_const);
dataflow_builder.finish_with_outputs([u].into_iter().chain(w))
}

Expand Down Expand Up @@ -887,8 +887,8 @@ pub(crate) mod test {
separate: bool,
) -> Result<(Hugr, BasicBlockID, BasicBlockID), BuildError> {
let mut cfg_builder = CFGBuilder::new(FunctionType::new(type_row![NAT], type_row![NAT]))?;
let pred_const = cfg_builder.add_constant(Const::unit_sum(0, 2))?; // Nothing here cares which
let const_unit = cfg_builder.add_constant(Const::unary_unit_sum())?;
let pred_const = cfg_builder.add_constant(Const::unit_sum(0, 2));
let const_unit = cfg_builder.add_constant(Const::unary_unit_sum());

let entry = n_identity(
cfg_builder.simple_entry_builder(type_row![NAT], 2, ExtensionSet::new())?,
Expand Down Expand Up @@ -929,8 +929,8 @@ pub(crate) mod test {
cfg_builder: &mut CFGBuilder<T>,
separate_headers: bool,
) -> Result<(BasicBlockID, BasicBlockID), BuildError> {
let pred_const = cfg_builder.add_constant(Const::unit_sum(0, 2))?; // Nothing here cares which
let const_unit = cfg_builder.add_constant(Const::unary_unit_sum())?;
let pred_const = cfg_builder.add_constant(Const::unit_sum(0, 2));
let const_unit = cfg_builder.add_constant(Const::unary_unit_sum());

let entry = n_identity(
cfg_builder.simple_entry_builder(type_row![NAT], 1, ExtensionSet::new())?,
Expand Down
80 changes: 34 additions & 46 deletions quantinuum-hugr/src/builder/build_traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,23 +46,23 @@ pub trait Container {
/// Immutable reference to HUGR being built
fn hugr(&self) -> &Hugr;
/// Add an [`OpType`] as the final child of the container.
fn add_child_op(&mut self, op: impl Into<OpType>) -> Result<Node, BuildError> {
fn add_child_op(&mut self, op: impl Into<OpType>) -> Node {
let parent = self.container_node();
Ok(self.hugr_mut().add_node_with_parent(parent, op))
self.hugr_mut().add_node_with_parent(parent, op)
}
/// Add a [`NodeType`] as the final child of the container.
fn add_child_node(&mut self, node: NodeType) -> Result<Node, BuildError> {
fn add_child_node(&mut self, node: NodeType) -> Node {
let parent = self.container_node();
Ok(self.hugr_mut().add_node_with_parent(parent, node))
self.hugr_mut().add_node_with_parent(parent, node)
}

/// Adds a non-dataflow edge between two nodes. The kind is given by the operation's [`other_inputs`] or [`other_outputs`]
///
/// [`other_inputs`]: crate::ops::OpTrait::other_input
/// [`other_outputs`]: crate::ops::OpTrait::other_output
fn add_other_wire(&mut self, src: Node, dst: Node) -> Result<Wire, BuildError> {
fn add_other_wire(&mut self, src: Node, dst: Node) -> Wire {
let (src_port, _) = self.hugr_mut().add_other_edge(src, dst);
Ok(Wire::new(src, src_port))
Wire::new(src, src_port)
}

/// Add a constant value to the container and return a handle to it.
Expand All @@ -71,10 +71,9 @@ pub trait Container {
///
/// This function will return an error if there is an error in adding the
/// [`OpType::Const`] node.
fn add_constant(&mut self, constant: impl Into<ops::Const>) -> Result<ConstID, BuildError> {
let const_n = self.add_child_node(NodeType::new_pure(constant.into()))?;

Ok(const_n.into())
fn add_constant(&mut self, constant: impl Into<ops::Const>) -> ConstID {
self.add_child_node(NodeType::new_pure(constant.into()))
.into()
}

/// Add a [`ops::FuncDefn`] node and returns a builder to define the function
Expand All @@ -93,23 +92,23 @@ pub trait Container {
let f_node = self.add_child_node(NodeType::new_pure(ops::FuncDefn {
name: name.into(),
signature,
}))?;
}));

let db =
DFGBuilder::create_with_io(self.hugr_mut(), f_node, body, Some(ExtensionSet::new()))?;
Ok(FunctionBuilder::from_dfg_builder(db))
}

/// Insert a HUGR as a child of the container.
fn add_hugr(&mut self, child: Hugr) -> Result<InsertionResult, BuildError> {
fn add_hugr(&mut self, child: Hugr) -> InsertionResult {
let parent = self.container_node();
Ok(self.hugr_mut().insert_hugr(parent, child))
self.hugr_mut().insert_hugr(parent, child)
}

/// Insert a copy of a HUGR as a child of the container.
fn add_hugr_view(&mut self, child: &impl HugrView) -> Result<InsertionResult, BuildError> {
fn add_hugr_view(&mut self, child: &impl HugrView) -> InsertionResult {
let parent = self.container_node();
Ok(self.hugr_mut().insert_from_view(parent, child))
self.hugr_mut().insert_from_view(parent, child)
}

/// Add metadata to the container node.
Expand All @@ -127,9 +126,8 @@ pub trait Container {
child: Node,
key: impl AsRef<str>,
meta: impl Into<NodeMetadata>,
) -> Result<(), BuildError> {
) {
self.hugr_mut().set_metadata(child, key, meta);
Ok(())
}
}

Expand Down Expand Up @@ -228,7 +226,7 @@ pub trait Dataflow: Container {
input_wires: impl IntoIterator<Item = Wire>,
) -> Result<BuildHandle<DataflowOpID>, BuildError> {
let num_outputs = hugr.get_optype(hugr.root()).value_output_count();
let node = self.add_hugr(hugr)?.new_root;
let node = self.add_hugr(hugr).new_root;

let inputs = input_wires.into_iter().collect();
wire_up_inputs(inputs, node, self)?;
Expand All @@ -248,7 +246,7 @@ pub trait Dataflow: Container {
hugr: &impl HugrView,
input_wires: impl IntoIterator<Item = Wire>,
) -> Result<BuildHandle<DataflowOpID>, BuildError> {
let node = self.add_hugr_view(hugr)?.new_root;
let node = self.add_hugr_view(hugr).new_root;
let num_outputs = hugr.get_optype(hugr.root()).value_output_count();

let inputs = input_wires.into_iter().collect();
Expand Down Expand Up @@ -341,10 +339,7 @@ pub trait Dataflow: Container {

/// Load a static constant and return the local dataflow wire for that constant.
/// Adds a [`OpType::LoadConstant`] node.
/// # Errors
///
/// This function will return an error if there is an error when adding the node.
fn load_const(&mut self, cid: &ConstID) -> Result<Wire, BuildError> {
fn load_const(&mut self, cid: &ConstID) -> Wire {
let const_node = cid.node();
let nodetype = self.hugr().get_nodetype(const_node);
let op: ops::Const = nodetype
Expand All @@ -353,24 +348,23 @@ pub trait Dataflow: Container {
.try_into()
.expect("ConstID does not refer to Const op.");

let load_n = self.add_dataflow_op(
ops::LoadConstant {
datatype: op.const_type().clone(),
},
// Constant wire from the constant value node
vec![Wire::new(const_node, OutgoingPort::from(0))],
)?;
let load_n = self
.add_dataflow_op(
ops::LoadConstant {
datatype: op.const_type().clone(),
},
// Constant wire from the constant value node
vec![Wire::new(const_node, OutgoingPort::from(0))],
)
.expect("The constant type should match the LoadConstant type.");

Ok(load_n.out_wire(0))
load_n.out_wire(0)
}

/// Load a static constant and return the local dataflow wire for that constant.
/// Adds a [`ops::LoadConstant`] node.
/// # Errors
///
/// This function will return an error if there is an error when adding the node.
fn add_load_const(&mut self, constant: impl Into<ops::Const>) -> Result<Wire, BuildError> {
let cid = self.add_constant(constant)?;
/// Adds a [`ops::Const`] and a [`ops::LoadConstant`] node.
fn add_load_const(&mut self, constant: impl Into<ops::Const>) -> Wire {
let cid = self.add_constant(constant);
self.load_const(&cid)
}

Expand Down Expand Up @@ -455,14 +449,8 @@ pub trait Dataflow: Container {

/// Add an order edge from `before` to `after`. Assumes any additional edges
/// to both nodes will be Order kind.
fn set_order(
&mut self,
before: &impl NodeHandle,
after: &impl NodeHandle,
) -> Result<(), BuildError> {
self.add_other_wire(before.node(), after.node())?;

Ok(())
fn set_order(&mut self, before: &impl NodeHandle, after: &impl NodeHandle) {
self.add_other_wire(before.node(), after.node());
}

/// Get the type of a Value [`Wire`]. If not valid port or of Value kind, returns None.
Expand Down Expand Up @@ -620,7 +608,7 @@ fn add_node_with_wires<T: Dataflow + ?Sized>(
) -> Result<(Node, usize), BuildError> {
let nodetype: NodeType = nodetype.into();
let num_outputs = nodetype.op().value_output_count();
let op_node = data_builder.add_child_node(nodetype)?;
let op_node = data_builder.add_child_node(nodetype);

wire_up_inputs(inputs, op_node, data_builder)?;

Expand Down
20 changes: 10 additions & 10 deletions quantinuum-hugr/src/builder/cfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ use crate::{
/// ops::Const::tuple_sum(0,
/// Value::tuple([prelude::ConstUsize::new(42).into()]),
/// sum_variants.clone())?;
/// let sum = entry_b.add_load_const(left_42)?;
/// let sum = entry_b.add_load_const(left_42);
///
/// entry_b.finish_with_outputs(sum, [inw])?
/// };
Expand All @@ -92,7 +92,7 @@ use crate::{
/// )?;
/// let successor_a = {
/// // This block has one successor. The choice is denoted by a unary sum.
/// let sum_unary = successor_builder.add_load_const(ops::Const::unary_unit_sum())?;
/// let sum_unary = successor_builder.add_load_const(ops::Const::unary_unit_sum());
///
/// // The input wires of a node start with the data embedded in the variant
/// // which selected this block.
Expand All @@ -104,7 +104,7 @@ use crate::{
/// let mut successor_builder =
/// cfg_builder.simple_block_builder(FunctionType::new(type_row![NAT], type_row![NAT]), 1)?;
/// let successor_b = {
/// let sum_unary = successor_builder.add_load_const(ops::Const::unary_unit_sum())?;
/// let sum_unary = successor_builder.add_load_const(ops::Const::unary_unit_sum());
/// let [in_wire] = successor_builder.input_wires_arr();
/// successor_builder.finish_with_outputs(sum_unary, [in_wire])?
/// };
Expand Down Expand Up @@ -469,7 +469,7 @@ pub(crate) mod test {
let mut middle_b = cfg_builder
.simple_block_builder(FunctionType::new(type_row![NAT], type_row![NAT]), 1)?;
let middle = {
let c = middle_b.add_load_const(ops::Const::unary_unit_sum())?;
let c = middle_b.add_load_const(ops::Const::unary_unit_sum());
let [inw] = middle_b.input_wires_arr();
middle_b.finish_with_outputs(c, [inw])?
};
Expand All @@ -482,21 +482,21 @@ pub(crate) mod test {
#[test]
fn test_dom_edge() -> Result<(), BuildError> {
let mut cfg_builder = CFGBuilder::new(FunctionType::new(type_row![NAT], type_row![NAT]))?;
let sum_tuple_const = cfg_builder.add_constant(ops::Const::unary_unit_sum())?;
let sum_tuple_const = cfg_builder.add_constant(ops::Const::unary_unit_sum());
let sum_variants = vec![type_row![]];

let mut entry_b =
cfg_builder.entry_builder(sum_variants.clone(), type_row![], ExtensionSet::new())?;
let [inw] = entry_b.input_wires_arr();
let entry = {
let sum = entry_b.load_const(&sum_tuple_const)?;
let sum = entry_b.load_const(&sum_tuple_const);

entry_b.finish_with_outputs(sum, [])?
};
let mut middle_b =
cfg_builder.simple_block_builder(FunctionType::new(type_row![], type_row![NAT]), 1)?;
let middle = {
let c = middle_b.load_const(&sum_tuple_const)?;
let c = middle_b.load_const(&sum_tuple_const);
middle_b.finish_with_outputs(c, [inw])?
};
let exit = cfg_builder.exit_block();
Expand All @@ -510,20 +510,20 @@ pub(crate) mod test {
#[test]
fn test_non_dom_edge() -> Result<(), BuildError> {
let mut cfg_builder = CFGBuilder::new(FunctionType::new(type_row![NAT], type_row![NAT]))?;
let sum_tuple_const = cfg_builder.add_constant(ops::Const::unary_unit_sum())?;
let sum_tuple_const = cfg_builder.add_constant(ops::Const::unary_unit_sum());
let sum_variants = vec![type_row![]];
let mut middle_b = cfg_builder
.simple_block_builder(FunctionType::new(type_row![NAT], type_row![NAT]), 1)?;
let [inw] = middle_b.input_wires_arr();
let middle = {
let c = middle_b.load_const(&sum_tuple_const)?;
let c = middle_b.load_const(&sum_tuple_const);
middle_b.finish_with_outputs(c, [inw])?
};

let mut entry_b =
cfg_builder.entry_builder(sum_variants.clone(), type_row![NAT], ExtensionSet::new())?;
let entry = {
let sum = entry_b.load_const(&sum_tuple_const)?;
let sum = entry_b.load_const(&sum_tuple_const);
// entry block uses wire from middle block even though middle block
// does not dominate entry
entry_b.finish_with_outputs(sum, [inw])?
Expand Down
6 changes: 3 additions & 3 deletions quantinuum-hugr/src/builder/conditional.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ impl<B: AsMut<Hugr> + AsRef<Hugr>> ConditionalBuilder<B> {
if let Some(&sibling_node) = self.case_nodes[case + 1..].iter().flatten().next() {
self.hugr_mut().add_node_before(sibling_node, case_op)
} else {
self.add_child_op(case_op)?
self.add_child_op(case_op)
};

self.case_nodes[case] = Some(case_node);
Expand Down Expand Up @@ -242,9 +242,9 @@ mod test {
"main",
FunctionType::new(type_row![NAT], type_row![NAT]).into(),
)?;
let tru_const = fbuild.add_constant(Const::true_val())?;
let tru_const = fbuild.add_constant(Const::true_val());
let _fdef = {
let const_wire = fbuild.load_const(&tru_const)?;
let const_wire = fbuild.load_const(&tru_const);
let [int] = fbuild.input_wires_arr();
let conditional_id = {
let other_inputs = vec![(NAT, int)];
Expand Down
2 changes: 1 addition & 1 deletion quantinuum-hugr/src/builder/dataflow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,7 @@ pub(crate) mod test {
let [i1] = f_build.input_wires_arr();
let dfg = f_build.add_hugr_with_wires(dfg_hugr, [i1])?;
let f = f_build.finish_with_outputs([dfg.out_wire(0)])?;
module_builder.set_child_metadata(f.node(), "x", "hi")?;
module_builder.set_child_metadata(f.node(), "x", "hi");
(dfg.node(), f.node())
};

Expand Down
6 changes: 3 additions & 3 deletions quantinuum-hugr/src/builder/module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ impl<T: AsMut<Hugr> + AsRef<Hugr>> ModuleBuilder<T> {
let declare_n = self.add_child_node(NodeType::new_pure(ops::FuncDecl {
signature,
name: name.into(),
}))?;
}));

Ok(declare_n.into())
}
Expand All @@ -136,7 +136,7 @@ impl<T: AsMut<Hugr> + AsRef<Hugr>> ModuleBuilder<T> {
let node = self.add_child_op(ops::AliasDefn {
name: name.clone(),
definition: typ,
})?;
});

Ok(AliasID::new(node, name, bound))
}
Expand All @@ -154,7 +154,7 @@ impl<T: AsMut<Hugr> + AsRef<Hugr>> ModuleBuilder<T> {
let node = self.add_child_op(ops::AliasDecl {
name: name.clone(),
bound,
})?;
});

Ok(AliasID::new(node, name, bound))
}
Expand Down
Loading