Skip to content

Commit

Permalink
feat!: Remove ExtensionRegistry args in UnwrapBuilder and ListOp (#1785)
Browse files Browse the repository at this point in the history
Last two remaining usages of `ExtensionRegistry` as redundant
parameters. These should have been included in #1784.

BREAKING CHANGE: `UnwrapBuilder` no longer requires an
`ExtensionRegistry` as argument.
BREAKING CHANGE: `ListOpInst` no longer requires an `ExtensionRegistry`
to produce an `ExtensionOp`.
  • Loading branch information
aborgna-q authored Dec 16, 2024
1 parent b517dc3 commit 7cf7bb6
Show file tree
Hide file tree
Showing 7 changed files with 39 additions and 89 deletions.
21 changes: 7 additions & 14 deletions hugr-core/src/extension/prelude/unwrap_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,20 @@ use std::iter;

use crate::{
builder::{BuildError, BuildHandle, Dataflow, DataflowSubContainer, SubContainer},
extension::{
prelude::{ConstError, PANIC_OP_ID, PRELUDE_ID},
ExtensionRegistry,
},
extension::prelude::{ConstError, PANIC_OP_ID},
ops::handle::DataflowOpID,
types::{SumType, Type, TypeArg, TypeRow},
Wire,
};
use itertools::{zip_eq, Itertools as _};

use super::PRELUDE;

/// Extend dataflow builders with methods for building unwrap operations.
pub trait UnwrapBuilder: Dataflow {
/// Add a panic operation to the dataflow with the given error.
fn add_panic(
&mut self,
reg: &ExtensionRegistry,
err: ConstError,
output_row: impl IntoIterator<Item = Type>,
inputs: impl IntoIterator<Item = (Wire, Type)>,
Expand All @@ -33,8 +31,7 @@ pub trait UnwrapBuilder: Dataflow {
.map(<TypeArg as From<_>>::from)
.collect_vec()
.into();
let prelude = reg.get(&PRELUDE_ID).unwrap();
let op = prelude.instantiate_extension_op(&PANIC_OP_ID, [input_arg, output_arg])?;
let op = PRELUDE.instantiate_extension_op(&PANIC_OP_ID, [input_arg, output_arg])?;
let err = self.add_load_value(err);
self.add_dataflow_op(op, iter::once(err).chain(input_wires))
}
Expand All @@ -43,7 +40,6 @@ pub trait UnwrapBuilder: Dataflow {
/// or panic if the tag is not the expected value.
fn build_unwrap_sum<const N: usize>(
&mut self,
reg: &ExtensionRegistry,
tag: usize,
sum_type: SumType,
input: Wire,
Expand All @@ -70,7 +66,7 @@ pub trait UnwrapBuilder: Dataflow {
let inputs = zip_eq(case.input_wires(), variant.iter().cloned());
let err =
ConstError::new(1, format!("Expected variant {} but got variant {}", tag, i));
let outputs = case.add_panic(reg, err, output_row, inputs)?.outputs();
let outputs = case.add_panic(err, output_row, inputs)?.outputs();
case.finish_with_outputs(outputs)?;
}
}
Expand All @@ -85,10 +81,7 @@ mod tests {
use super::*;
use crate::{
builder::{DFGBuilder, DataflowHugr},
extension::{
prelude::{bool_t, option_type},
PRELUDE_REGISTRY,
},
extension::prelude::{bool_t, option_type},
types::Signature,
};

Expand All @@ -102,7 +95,7 @@ mod tests {
let [opt] = builder.input_wires_arr();

let [res] = builder
.build_unwrap_sum(&PRELUDE_REGISTRY, 1, option_type(bool_t()), opt)
.build_unwrap_sum(1, option_type(bool_t()), opt)
.unwrap();
builder.finish_hugr_with_outputs([res]).unwrap();
}
Expand Down
4 changes: 2 additions & 2 deletions hugr-core/src/hugr/rewrite/replace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -471,11 +471,11 @@ mod test {
let listy = list::list_type(usize_t());
let pop: ExtensionOp = list::ListOp::pop
.with_type(usize_t())
.to_extension_op(&reg)
.to_extension_op()
.unwrap();
let push: ExtensionOp = list::ListOp::push
.with_type(usize_t())
.to_extension_op(&reg)
.to_extension_op()
.unwrap();
let just_list = TypeRow::from(vec![listy.clone()]);
let intermed = TypeRow::from(vec![listy.clone(), usize_t()]);
Expand Down
33 changes: 8 additions & 25 deletions hugr-core/src/std_extensions/collections/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ use crate::types::{TypeName, TypeRowRV};
use crate::{
extension::{
simple_op::{MakeExtensionOp, OpLoadError},
ExtensionId, ExtensionRegistry, ExtensionSet, SignatureError, TypeDef, TypeDefBound,
ExtensionId, ExtensionSet, SignatureError, TypeDef, TypeDefBound,
},
ops::constant::CustomConst,
ops::{custom::ExtensionOp, NamedOp},
Expand Down Expand Up @@ -373,20 +373,8 @@ impl MakeExtensionOp for ListOpInst {
impl ListOpInst {
/// Convert this list operation to an [`ExtensionOp`] by providing a
/// registry to validate the element type against.
pub fn to_extension_op(self, elem_type_registry: &ExtensionRegistry) -> Option<ExtensionOp> {
let registry = ExtensionRegistry::new(
elem_type_registry
.clone()
.into_iter()
// ignore self if already in registry
.filter(|ext| ext.name() != EXTENSION.name())
.chain(std::iter::once(EXTENSION.to_owned())),
);
ExtensionOp::new(
registry.get(&EXTENSION_ID)?.get_op(&self.name())?.clone(),
self.type_args(),
)
.ok()
pub fn to_extension_op(self) -> Option<ExtensionOp> {
ExtensionOp::new(EXTENSION.get_op(&self.name())?.clone(), self.type_args()).ok()
}
}

Expand All @@ -398,14 +386,10 @@ mod test {
const_fail_tuple, const_none, const_ok_tuple, const_some_tuple,
};
use crate::ops::OpTrait;
use crate::std_extensions::STD_REG;
use crate::PortIndex;
use crate::{
extension::{
prelude::{qb_t, usize_t, ConstUsize},
PRELUDE,
},
std_extensions::arithmetic::float_types::{self, float64_type, ConstF64},
extension::prelude::{qb_t, usize_t, ConstUsize},
std_extensions::arithmetic::float_types::{float64_type, ConstF64},
types::TypeRow,
};

Expand Down Expand Up @@ -443,9 +427,8 @@ mod test {

#[test]
fn test_list_ops() {
let reg = ExtensionRegistry::new([PRELUDE.to_owned(), float_types::EXTENSION.to_owned()]);
let pop_op = ListOp::pop.with_type(qb_t());
let pop_ext = pop_op.clone().to_extension_op(&reg).unwrap();
let pop_ext = pop_op.clone().to_extension_op().unwrap();
assert_eq!(ListOpInst::from_extension_op(&pop_ext).unwrap(), pop_op);
let pop_sig = pop_ext.dataflow_signature().unwrap();

Expand All @@ -457,7 +440,7 @@ mod test {
assert_eq!(pop_sig.output(), &both_row);

let push_op = ListOp::push.with_type(float64_type());
let push_ext = push_op.clone().to_extension_op(&reg).unwrap();
let push_ext = push_op.clone().to_extension_op().unwrap();
assert_eq!(ListOpInst::from_extension_op(&push_ext).unwrap(), push_op);
let push_sig = push_ext.dataflow_signature().unwrap();

Expand Down Expand Up @@ -531,7 +514,7 @@ mod test {

let res = op
.with_type(usize_t())
.to_extension_op(&STD_REG)
.to_extension_op()
.unwrap()
.constant_fold(&consts)
.unwrap();
Expand Down
24 changes: 11 additions & 13 deletions hugr-llvm/src/extension/collections/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -831,7 +831,7 @@ mod test {
let hugr = SimpleHugrConfig::new()
.with_outs(usize_t())
.with_extensions(exec_registry())
.finish_with_exts(|mut builder, reg| {
.finish(|mut builder| {
let us0 = builder.add_load_value(ConstUsize::new(0));
let us1 = builder.add_load_value(ConstUsize::new(1));
let i1 = builder.add_load_value(ConstInt::new_u(3, 1).unwrap());
Expand Down Expand Up @@ -872,13 +872,13 @@ mod test {
let [arr_0] = {
let r = builder.add_array_get(int_ty.clone(), 2, arr, us0).unwrap();
builder
.build_unwrap_sum(reg, 1, option_type(int_ty.clone()), r)
.build_unwrap_sum(1, option_type(int_ty.clone()), r)
.unwrap()
};
let [arr_1] = {
let r = builder.add_array_get(int_ty.clone(), 2, arr, us1).unwrap();
builder
.build_unwrap_sum(reg, 1, option_type(int_ty.clone()), r)
.build_unwrap_sum(1, option_type(int_ty.clone()), r)
.unwrap()
};
let elem_eq = builder.add_ieq(3, elem, expected_elem).unwrap();
Expand Down Expand Up @@ -943,7 +943,7 @@ mod test {
let hugr = SimpleHugrConfig::new()
.with_outs(usize_t())
.with_extensions(exec_registry())
.finish_with_exts(|mut builder, reg| {
.finish(|mut builder| {
let us0 = builder.add_load_value(ConstUsize::new(0));
let us1 = builder.add_load_value(ConstUsize::new(1));
let i1 = builder.add_load_value(ConstInt::new_u(3, 1).unwrap());
Expand Down Expand Up @@ -983,13 +983,13 @@ mod test {
let elem_0 = {
let r = builder.add_array_get(int_ty.clone(), 2, arr, us0).unwrap();
builder
.build_unwrap_sum::<1>(reg, 1, option_type(int_ty.clone()), r)
.build_unwrap_sum::<1>(1, option_type(int_ty.clone()), r)
.unwrap()[0]
};
let elem_1 = {
let r = builder.add_array_get(int_ty.clone(), 2, arr, us1).unwrap();
builder
.build_unwrap_sum::<1>(reg, 1, option_type(int_ty), r)
.build_unwrap_sum::<1>(1, option_type(int_ty), r)
.unwrap()[0]
};
let expected_elem_0 =
Expand Down Expand Up @@ -1052,7 +1052,7 @@ mod test {
let hugr = SimpleHugrConfig::new()
.with_outs(int_ty.clone())
.with_extensions(exec_registry())
.finish_with_exts(|mut builder, reg| {
.finish(|mut builder| {
let mut r = builder.add_load_value(ConstInt::new_u(6, 0).unwrap());
let new_array_args = array_contents
.iter()
Expand All @@ -1074,7 +1074,6 @@ mod test {
};
let [elem, new_arr] = builder
.build_unwrap_sum(
reg,
1,
option_type(vec![
int_ty.clone(),
Expand Down Expand Up @@ -1117,7 +1116,7 @@ mod test {
let hugr = SimpleHugrConfig::new()
.with_outs(int_ty.clone())
.with_extensions(exec_registry())
.finish_with_exts(|mut builder, reg| {
.finish(|mut builder| {
let mut func = builder
.define_function(
"foo",
Expand All @@ -1138,7 +1137,7 @@ mod test {
.add_array_get(int_ty.clone(), size, arr, idx_v)
.unwrap();
let [elem] = builder
.build_unwrap_sum(reg, 1, option_type(vec![int_ty.clone()]), get_res)
.build_unwrap_sum(1, option_type(vec![int_ty.clone()]), get_res)
.unwrap();
builder.finish_with_outputs([elem]).unwrap()
});
Expand All @@ -1164,7 +1163,7 @@ mod test {
let hugr = SimpleHugrConfig::new()
.with_outs(int_ty.clone())
.with_extensions(exec_registry())
.finish_with_exts(|mut builder, reg| {
.finish(|mut builder| {
let mut r = builder.add_load_value(ConstInt::new_u(6, 0).unwrap());
let new_array_args = (0..size)
.map(|i| builder.add_load_value(ConstInt::new_u(6, i).unwrap()))
Expand Down Expand Up @@ -1204,7 +1203,6 @@ mod test {
.unwrap();
let [elem, new_arr] = builder
.build_unwrap_sum(
reg,
1,
option_type(vec![
int_ty.clone(),
Expand Down Expand Up @@ -1240,7 +1238,7 @@ mod test {
let hugr = SimpleHugrConfig::new()
.with_outs(int_ty.clone())
.with_extensions(exec_registry())
.finish_with_exts(|mut builder, _reg| {
.finish(|mut builder| {
let new_array_args = (0..size)
.map(|i| builder.add_load_value(ConstInt::new_u(6, i).unwrap()))
.collect_vec();
Expand Down
29 changes: 6 additions & 23 deletions hugr-llvm/src/utils/array_op_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,10 +118,7 @@ pub mod test {
use hugr_core::std_extensions::collections::array::{self, array_type};
use hugr_core::{
builder::{DFGBuilder, HugrBuilder},
extension::{
prelude::{either_type, option_type, usize_t, ConstUsize, UnwrapBuilder as _},
PRELUDE_REGISTRY,
},
extension::prelude::{either_type, option_type, usize_t, ConstUsize, UnwrapBuilder as _},
types::Signature,
Hugr,
};
Expand Down Expand Up @@ -149,15 +146,13 @@ pub mod test {
let array_type = array_type(2, usize_t());
either_type(array_type.clone(), array_type)
};
builder
.build_unwrap_sum(&PRELUDE_REGISTRY, 1, res_sum_ty, r)
.unwrap()
builder.build_unwrap_sum(1, res_sum_ty, r).unwrap()
};

let [elem_0] = {
let r = builder.add_array_get(usize_t(), 2, arr, us0).unwrap();
builder
.build_unwrap_sum(&PRELUDE_REGISTRY, 1, option_type(usize_t()), r)
.build_unwrap_sum(1, option_type(usize_t()), r)
.unwrap()
};

Expand All @@ -169,31 +164,19 @@ pub mod test {
let row = vec![usize_t(), array_type(2, usize_t())];
either_type(row.clone(), row)
};
builder
.build_unwrap_sum(&PRELUDE_REGISTRY, 1, res_sum_ty, r)
.unwrap()
builder.build_unwrap_sum(1, res_sum_ty, r).unwrap()
};

let [_elem_left, arr] = {
let r = builder.add_array_pop_left(usize_t(), 2, arr).unwrap();
builder
.build_unwrap_sum(
&PRELUDE_REGISTRY,
1,
option_type(vec![usize_t(), array_type(1, usize_t())]),
r,
)
.build_unwrap_sum(1, option_type(vec![usize_t(), array_type(1, usize_t())]), r)
.unwrap()
};
let [_elem_right, arr] = {
let r = builder.add_array_pop_right(usize_t(), 1, arr).unwrap();
builder
.build_unwrap_sum(
&PRELUDE_REGISTRY,
1,
option_type(vec![usize_t(), array_type(0, usize_t())]),
r,
)
.build_unwrap_sum(1, option_type(vec![usize_t(), array_type(0, usize_t())]), r)
.unwrap()
};

Expand Down
11 changes: 2 additions & 9 deletions hugr-passes/src/const_fold/test.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use std::collections::hash_map::RandomState;
use std::collections::HashSet;

use hugr_core::std_extensions::STD_REG;
use itertools::Itertools;
use lazy_static::lazy_static;
use rstest::rstest;
Expand Down Expand Up @@ -184,10 +183,7 @@ fn test_list_ops() -> Result<(), Box<dyn std::error::Error>> {

let [list, maybe_elem] = build
.add_dataflow_op(
ListOp::pop
.with_type(bool_t())
.to_extension_op(&STD_REG)
.unwrap(),
ListOp::pop.with_type(bool_t()).to_extension_op().unwrap(),
[list],
)?
.outputs_arr();
Expand All @@ -197,10 +193,7 @@ fn test_list_ops() -> Result<(), Box<dyn std::error::Error>> {

let [list] = build
.add_dataflow_op(
ListOp::push
.with_type(bool_t())
.to_extension_op(&STD_REG)
.unwrap(),
ListOp::push.with_type(bool_t()).to_extension_op().unwrap(),
[list, elem],
)?
.outputs_arr();
Expand Down
Loading

0 comments on commit 7cf7bb6

Please sign in to comment.