Skip to content

Commit

Permalink
Replace ExtensionRegistry::try_new with ::validate and cache the …
Browse files Browse the repository at this point in the history
…result
  • Loading branch information
aborgna-q committed Dec 9, 2024
1 parent 95e10e3 commit 3220bf6
Show file tree
Hide file tree
Showing 27 changed files with 153 additions and 133 deletions.
109 changes: 72 additions & 37 deletions hugr-core/src/extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use std::collections::btree_map;
use std::collections::{BTreeMap, BTreeSet};
use std::fmt::Debug;
use std::mem;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Weak};

use derive_more::Display;
Expand Down Expand Up @@ -43,14 +44,36 @@ pub use type_def::{TypeDef, TypeDefBound};
pub mod declarative;

/// Extension Registries store extensions to be looked up e.g. during validation.
#[derive(Clone, Debug, Display, Default, PartialEq)]
#[display("ExtensionRegistry[{}]", _0.keys().join(", "))]
pub struct ExtensionRegistry(BTreeMap<ExtensionId, Arc<Extension>>);
#[derive(Debug, Display, Default)]
#[display("ExtensionRegistry[{}]", exts.keys().join(", "))]
pub struct ExtensionRegistry {
/// The extensions in the registry.
exts: BTreeMap<ExtensionId, Arc<Extension>>,
/// A flag indicating whether the current set of extensions has been
/// validated.
///
/// This is used to avoid re-validating the extensions every time they are
/// used, and is set to `false` whenever a new extension is added.
valid: AtomicBool,
}

impl PartialEq for ExtensionRegistry {
fn eq(&self, other: &Self) -> bool {
self.exts == other.exts
}
}

impl Clone for ExtensionRegistry {
fn clone(&self) -> Self {
Self {
exts: self.exts.clone(),
valid: self.valid.load(Ordering::Relaxed).into(),
}
}
}

impl ExtensionRegistry {
/// Create a new empty extension registry.
///
/// For a version that checks the validity of the extensions, see [`ExtensionRegistry::try_new`].
pub fn new(extensions: impl IntoIterator<Item = Arc<Extension>>) -> Self {
let mut res = Self::default();
for ext in extensions.into_iter() {
Expand All @@ -61,33 +84,32 @@ impl ExtensionRegistry {

/// Gets the Extension with the given name
pub fn get(&self, name: &str) -> Option<&Arc<Extension>> {
self.0.get(name)
self.exts.get(name)
}

/// Returns `true` if the registry contains an extension with the given name.
pub fn contains(&self, name: &str) -> bool {
self.0.contains_key(name)
self.exts.contains_key(name)
}

/// Makes a new [ExtensionRegistry], validating all the extensions in it.
/// Validate the set of extensions, ensuring that each extension requirements are also in the registry.
///
/// For an unvalidated version, see [`ExtensionRegistry::new`].
pub fn try_new(
value: impl IntoIterator<Item = Arc<Extension>>,
) -> Result<Self, ExtensionRegistryError> {
let res = ExtensionRegistry::new(value);

// Note this potentially asks extensions to validate themselves against other extensions that
// may *not* be valid themselves yet. It'd be better to order these respecting dependencies,
// or at least to validate the types first - which we don't do at all yet:
// TODO https://github.com/CQCL/hugr/issues/624. However, parametrized types could be
// cyclically dependent, so there is no perfect solution, and this is at least simple.
for ext in res.0.values() {
ext.validate(&res)
/// Note this potentially asks extensions to validate themselves against other extensions that
/// may *not* be valid themselves yet. It'd be better to order these respecting dependencies,
/// or at least to validate the types first - which we don't do at all yet:
//
// TODO https://github.com/CQCL/hugr/issues/624. However, parametrized types could be
// cyclically dependent, so there is no perfect solution, and this is at least simple.
pub fn validate(&self) -> Result<(), ExtensionRegistryError> {
if self.valid.load(Ordering::Relaxed) {
return Ok(());
}
for ext in self.exts.values() {
ext.validate(self)
.map_err(|e| ExtensionRegistryError::InvalidSignature(ext.name().clone(), e))?;
}

Ok(res)
self.valid.store(true, Ordering::Relaxed);
Ok(())
}

/// Registers a new extension to the registry.
Expand All @@ -98,14 +120,17 @@ impl ExtensionRegistry {
extension: impl Into<Arc<Extension>>,
) -> Result<(), ExtensionRegistryError> {
let extension = extension.into();
match self.0.entry(extension.name().clone()) {
match self.exts.entry(extension.name().clone()) {
btree_map::Entry::Occupied(prev) => Err(ExtensionRegistryError::AlreadyRegistered(
extension.name().clone(),
prev.get().version().clone(),
extension.version().clone(),
)),
btree_map::Entry::Vacant(ve) => {
ve.insert(extension);
// Clear the valid flag so that the registry is re-validated.
self.valid.store(false, Ordering::Relaxed);

Ok(())
}
}
Expand All @@ -122,7 +147,7 @@ impl ExtensionRegistry {
/// see [`ExtensionRegistry::register_updated_ref`].
pub fn register_updated(&mut self, extension: impl Into<Arc<Extension>>) {
let extension = extension.into();
match self.0.entry(extension.name().clone()) {
match self.exts.entry(extension.name().clone()) {
btree_map::Entry::Occupied(mut prev) => {
if prev.get().version() < extension.version() {
*prev.get_mut() = extension;
Expand All @@ -132,6 +157,8 @@ impl ExtensionRegistry {
ve.insert(extension);
}
}
// Clear the valid flag so that the registry is re-validated.
self.valid.store(false, Ordering::Relaxed);
}

/// Registers a new extension to the registry, keeping the one most up to
Expand All @@ -144,7 +171,7 @@ impl ExtensionRegistry {
/// Clones the Arc only when required. For no-cloning version see
/// [`ExtensionRegistry::register_updated`].
pub fn register_updated_ref(&mut self, extension: &Arc<Extension>) {
match self.0.entry(extension.name().clone()) {
match self.exts.entry(extension.name().clone()) {
btree_map::Entry::Occupied(mut prev) => {
if prev.get().version() < extension.version() {
*prev.get_mut() = extension.clone();
Expand All @@ -154,31 +181,36 @@ impl ExtensionRegistry {
ve.insert(extension.clone());
}
}
// Clear the valid flag so that the registry is re-validated.
self.valid.store(false, Ordering::Relaxed);
}

/// Returns the number of extensions in the registry.
pub fn len(&self) -> usize {
self.0.len()
self.exts.len()
}

/// Returns `true` if the registry contains no extensions.
pub fn is_empty(&self) -> bool {
self.0.is_empty()
self.exts.is_empty()
}

/// Returns an iterator over the extensions in the registry.
pub fn iter(&self) -> <&Self as IntoIterator>::IntoIter {
self.0.values()
self.exts.values()
}

/// Returns an iterator over the extensions ids in the registry.
pub fn ids(&self) -> impl Iterator<Item = &ExtensionId> {
self.0.keys()
self.exts.keys()
}

/// Delete an extension from the registry and return it if it was present.
pub fn remove_extension(&mut self, name: &ExtensionId) -> Option<Arc<Extension>> {
self.0.remove(name)
// Clear the valid flag so that the registry is re-validated.
self.valid.store(false, Ordering::Relaxed);

self.exts.remove(name)
}
}

Expand All @@ -188,7 +220,7 @@ impl IntoIterator for ExtensionRegistry {
type IntoIter = std::collections::btree_map::IntoValues<ExtensionId, Arc<Extension>>;

fn into_iter(self) -> Self::IntoIter {
self.0.into_values()
self.exts.into_values()
}
}

Expand All @@ -198,7 +230,7 @@ impl<'a> IntoIterator for &'a ExtensionRegistry {
type IntoIter = std::collections::btree_map::Values<'a, ExtensionId, Arc<Extension>>;

fn into_iter(self) -> Self::IntoIter {
self.0.values()
self.exts.values()
}
}

Expand Down Expand Up @@ -235,13 +267,16 @@ impl Serialize for ExtensionRegistry {
where
S: serde::Serializer,
{
let extensions: Vec<Arc<Extension>> = self.0.values().cloned().collect();
let extensions: Vec<Arc<Extension>> = self.exts.values().cloned().collect();
extensions.serialize(serializer)
}
}

/// An Extension Registry containing no extensions.
pub const EMPTY_REG: ExtensionRegistry = ExtensionRegistry(BTreeMap::new());
pub static EMPTY_REG: ExtensionRegistry = ExtensionRegistry {
exts: BTreeMap::new(),
valid: AtomicBool::new(true),
};

/// An error that can occur in computing the signature of a node.
/// TODO: decide on failure modes
Expand Down Expand Up @@ -815,8 +850,8 @@ pub mod test {
fn test_register_update() {
// Two registers that should remain the same.
// We use them to test both `register_updated` and `register_updated_ref`.
let mut reg = ExtensionRegistry::try_new([]).unwrap();
let mut reg_ref = ExtensionRegistry::try_new([]).unwrap();
let mut reg = ExtensionRegistry::default();
let mut reg_ref = ExtensionRegistry::default();

let ext_1_id = ExtensionId::new("ext1").unwrap();
let ext_2_id = ExtensionId::new("ext2").unwrap();
Expand Down
3 changes: 2 additions & 1 deletion hugr-core/src/extension/op_def.rs
Original file line number Diff line number Diff line change
Expand Up @@ -658,7 +658,8 @@ pub(super) mod test {
Ok(())
})?;

let reg = ExtensionRegistry::try_new([PRELUDE.clone(), EXTENSION.clone(), ext]).unwrap();
let reg = ExtensionRegistry::new([PRELUDE.clone(), EXTENSION.clone(), ext]);
reg.validate()?;
let e = reg.get(&EXT_ID).unwrap();

let list_usize =
Expand Down
3 changes: 1 addition & 2 deletions hugr-core/src/extension/prelude.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,7 @@ lazy_static! {
};

/// An extension registry containing only the prelude
pub static ref PRELUDE_REGISTRY: ExtensionRegistry =
ExtensionRegistry::try_new([PRELUDE.clone()]).unwrap();
pub static ref PRELUDE_REGISTRY: ExtensionRegistry = ExtensionRegistry::new([PRELUDE.clone()]);
}

pub(crate) fn usize_custom_t(extension_ref: &Weak<Extension>) -> CustomType {
Expand Down
3 changes: 1 addition & 2 deletions hugr-core/src/extension/simple_op.rs
Original file line number Diff line number Diff line change
Expand Up @@ -358,8 +358,7 @@ mod test {
.unwrap();
})
};
static ref DUMMY_REG: ExtensionRegistry =
ExtensionRegistry::try_new([EXT.clone()]).unwrap();
static ref DUMMY_REG: ExtensionRegistry = ExtensionRegistry::new([EXT.clone()]);
}
impl MakeRegisteredOp for DummyEnum {
fn extension_id(&self) -> ExtensionId {
Expand Down
18 changes: 9 additions & 9 deletions hugr-core/src/hugr/rewrite/inline_dfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -169,12 +169,12 @@ mod test {
fn inline_add_load_const(#[case] nonlocal: bool) -> Result<(), Box<dyn std::error::Error>> {
use crate::extension::prelude::Lift;

let reg = ExtensionRegistry::try_new([
let reg = ExtensionRegistry::new([
PRELUDE.to_owned(),
int_ops::EXTENSION.to_owned(),
int_types::EXTENSION.to_owned(),
])
.unwrap();
]);
reg.validate()?;
let int_ty = &int_types::INT_TYPES[6];

let mut outer = DFGBuilder::new(inout_sig(vec![int_ty.clone(); 2], vec![int_ty.clone()]))?;
Expand Down Expand Up @@ -256,12 +256,12 @@ mod test {
};
let [q, p] = swap.outputs_arr();
let cx = h.add_dataflow_op(test_quantum_extension::cx_gate(), [q, p])?;
let reg = ExtensionRegistry::try_new([
let reg = ExtensionRegistry::new([
test_quantum_extension::EXTENSION.clone(),
PRELUDE.clone(),
float_types::EXTENSION.clone(),
])
.unwrap();
]);
reg.validate()?;

let mut h = h.finish_hugr_with_outputs(cx.outputs(), &reg)?;
assert_eq!(find_dfgs(&h), vec![h.root(), swap.node()]);
Expand Down Expand Up @@ -333,12 +333,12 @@ mod test {
* CX
*/
// Extension inference here relies on quantum ops not requiring their own test_quantum_extension
let reg = ExtensionRegistry::try_new([
let reg = ExtensionRegistry::new([
test_quantum_extension::EXTENSION.to_owned(),
float_types::EXTENSION.to_owned(),
PRELUDE.to_owned(),
])
.unwrap();
]);
reg.validate()?;
let mut outer = DFGBuilder::new(endo_sig(vec![qb_t(), qb_t()]))?;
let [a, b] = outer.input_wires_arr();
let h_a = outer.add_dataflow_op(test_quantum_extension::h_gate(), [a])?;
Expand Down
5 changes: 2 additions & 3 deletions hugr-core/src/hugr/rewrite/replace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -470,9 +470,8 @@ mod test {
#[test]
#[ignore] // FIXME: This needs a rewrite now that `pop` returns an optional value -.-'
fn cfg() -> Result<(), Box<dyn std::error::Error>> {
let reg =
ExtensionRegistry::try_new([PRELUDE.to_owned(), collections::EXTENSION.to_owned()])
.unwrap();
let reg = ExtensionRegistry::new([PRELUDE.to_owned(), collections::EXTENSION.to_owned()]);
reg.validate()?;
let listy = list_type(usize_t());
let pop: ExtensionOp = ListOp::pop
.with_type(usize_t())
Expand Down
13 changes: 8 additions & 5 deletions hugr-core/src/hugr/validate/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,8 @@ fn invalid_types() {
)
.unwrap();
});
let reg = ExtensionRegistry::try_new([ext.clone(), PRELUDE.clone()]).unwrap();
let reg = ExtensionRegistry::new([ext.clone(), PRELUDE.clone()]);
reg.validate().unwrap();

let validate_to_sig_error = |t: CustomType| {
let (h, def) = identity_hugr_with_type(Type::new_extension(t));
Expand Down Expand Up @@ -569,7 +570,8 @@ fn no_polymorphic_consts() -> Result<(), Box<dyn std::error::Error>> {
.unwrap()
.instantiate(vec![TypeArg::new_var_use(0, BOUND)])?,
);
let reg = ExtensionRegistry::try_new([collections::EXTENSION.to_owned()]).unwrap();
let reg = ExtensionRegistry::new([collections::EXTENSION.to_owned()]);
reg.validate()?;
let mut def = FunctionBuilder::new(
"myfunc",
PolyFuncType::new(
Expand Down Expand Up @@ -653,7 +655,7 @@ fn instantiate_row_variables() -> Result<(), Box<dyn std::error::Error>> {
let eval2 = dfb.add_dataflow_op(eval2, [par_func, a, b])?;
dfb.finish_hugr_with_outputs(
eval2.outputs(),
&ExtensionRegistry::try_new([PRELUDE.clone(), e]).unwrap(),
&ExtensionRegistry::new([PRELUDE.clone(), e]),
)?;
Ok(())
}
Expand Down Expand Up @@ -693,7 +695,7 @@ fn row_variables() -> Result<(), Box<dyn std::error::Error>> {
let par_func = fb.add_dataflow_op(par, [func_arg, id_usz])?;
fb.finish_hugr_with_outputs(
par_func.outputs(),
&ExtensionRegistry::try_new([PRELUDE.clone(), e]).unwrap(),
&ExtensionRegistry::new([PRELUDE.clone(), e]),
)?;
Ok(())
}
Expand Down Expand Up @@ -780,7 +782,8 @@ fn test_polymorphic_call() -> Result<(), Box<dyn std::error::Error>> {
f.finish_with_outputs([tup])?
};

let reg = ExtensionRegistry::try_new([e, PRELUDE.clone()])?;
let reg = ExtensionRegistry::new([e, PRELUDE.clone()]);
reg.validate()?;
let [func, tup] = d.input_wires_arr();
let call = d.call(
f.handle(),
Expand Down
Loading

0 comments on commit 3220bf6

Please sign in to comment.