Skip to content

Commit

Permalink
feat: Share Extensions under Arcs
Browse files Browse the repository at this point in the history
  • Loading branch information
aborgna-q committed Nov 25, 2024
1 parent 1be116e commit 143a155
Show file tree
Hide file tree
Showing 23 changed files with 138 additions and 98 deletions.
4 changes: 3 additions & 1 deletion hugr-cli/tests/validate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
//! calling the CLI binary, which Miri doesn't support.
#![cfg(all(test, not(miri)))]

use std::sync::Arc;

use assert_cmd::Command;
use assert_fs::{fixture::FileWriteStr, NamedTempFile};
use hugr::builder::{DFGBuilder, DataflowSubContainer, ModuleBuilder};
Expand Down Expand Up @@ -49,7 +51,7 @@ fn test_package(#[default(BOOL_T)] id_type: Type) -> Package {
let hugr = module.hugr().clone(); // unvalidated

let rdr = std::fs::File::open(FLOAT_EXT_FILE).unwrap();
let float_ext: hugr::Extension = serde_json::from_reader(rdr).unwrap();
let float_ext: Arc<hugr::Extension> = serde_json::from_reader(rdr).unwrap();
Package::new(vec![hugr], vec![float_ext]).unwrap()
}

Expand Down
61 changes: 37 additions & 24 deletions hugr-core/src/extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,11 @@ pub mod declarative;

/// Extension Registries store extensions to be looked up e.g. during validation.
#[derive(Clone, Debug, PartialEq)]
pub struct ExtensionRegistry(BTreeMap<ExtensionId, Extension>);
pub struct ExtensionRegistry(BTreeMap<ExtensionId, Arc<Extension>>);

impl ExtensionRegistry {
/// Gets the Extension with the given name
pub fn get(&self, name: &str) -> Option<&Extension> {
pub fn get(&self, name: &str) -> Option<&Arc<Extension>> {
self.0.get(name)
}

Expand All @@ -51,9 +51,9 @@ impl ExtensionRegistry {
self.0.contains_key(name)
}

/// Makes a new ExtensionRegistry, validating all the extensions in it
/// Makes a new [ExtensionRegistry], validating all the extensions in it.
pub fn try_new(
value: impl IntoIterator<Item = Extension>,
value: impl IntoIterator<Item = Arc<Extension>>,
) -> Result<Self, ExtensionRegistryError> {
let mut res = ExtensionRegistry(BTreeMap::new());

Expand All @@ -70,20 +70,28 @@ impl ExtensionRegistry {
ext.validate(&res)
.map_err(|e| ExtensionRegistryError::InvalidSignature(ext.name().clone(), e))?;
}

Ok(res)
}

/// Registers a new extension to the registry.
///
/// Returns a reference to the registered extension if successful.
pub fn register(&mut self, extension: Extension) -> Result<&Extension, ExtensionRegistryError> {
pub fn register(
&mut self,
extension: impl Into<Arc<Extension>>,
) -> Result<(), ExtensionRegistryError> {
let extension = extension.into();
match self.0.entry(extension.name().clone()) {
btree_map::Entry::Occupied(prev) => Err(ExtensionRegistryError::AlreadyRegistered(
extension.name().clone(),
prev.get().version().clone(),
extension.version().clone(),
)),
btree_map::Entry::Vacant(ve) => Ok(ve.insert(extension)),
btree_map::Entry::Vacant(ve) => {
ve.insert(extension);
Ok(())
}
}
}

Expand All @@ -93,21 +101,24 @@ impl ExtensionRegistry {
/// If versions match, the original extension is kept.
/// Returns a reference to the registered extension if successful.
///
/// Avoids cloning the extension unless required. For a reference version see
/// Takes an Arc to the extension. To avoid cloning Arcs unless necessary, see
/// [`ExtensionRegistry::register_updated_ref`].
pub fn register_updated(
&mut self,
extension: Extension,
) -> Result<&Extension, ExtensionRegistryError> {
extension: impl Into<Arc<Extension>>,
) -> Result<(), ExtensionRegistryError> {
let extension = extension.into();
match self.0.entry(extension.name().clone()) {
btree_map::Entry::Occupied(mut prev) => {
if prev.get().version() < extension.version() {
*prev.get_mut() = extension;
}
Ok(prev.into_mut())
}
btree_map::Entry::Vacant(ve) => Ok(ve.insert(extension)),
btree_map::Entry::Vacant(ve) => {
ve.insert(extension);
}
}
Ok(())
}

/// Registers a new extension to the registry, keeping most up to date if
Expand All @@ -117,21 +128,23 @@ impl ExtensionRegistry {
/// If versions match, the original extension is kept. Returns a reference
/// to the registered extension if successful.
///
/// Clones the extension if required. For no-cloning version see
/// Clones the Arc only when required. For no-cloning version see
/// [`ExtensionRegistry::register_updated`].
pub fn register_updated_ref(
&mut self,
extension: &Extension,
) -> Result<&Extension, ExtensionRegistryError> {
extension: &Arc<Extension>,
) -> Result<(), ExtensionRegistryError> {
match self.0.entry(extension.name().clone()) {
btree_map::Entry::Occupied(mut prev) => {
if prev.get().version() < extension.version() {
*prev.get_mut() = extension.clone();
}
Ok(prev.into_mut())
}
btree_map::Entry::Vacant(ve) => Ok(ve.insert(extension.clone())),
btree_map::Entry::Vacant(ve) => {
ve.insert(extension.clone());
}
}
Ok(())
}

/// Returns the number of extensions in the registry.
Expand All @@ -145,20 +158,20 @@ impl ExtensionRegistry {
}

/// Returns an iterator over the extensions in the registry.
pub fn iter(&self) -> impl Iterator<Item = (&ExtensionId, &Extension)> {
pub fn iter(&self) -> impl Iterator<Item = (&ExtensionId, &Arc<Extension>)> {
self.0.iter()
}

/// Delete an extension from the registry and return it if it was present.
pub fn remove_extension(&mut self, name: &ExtensionId) -> Option<Extension> {
pub fn remove_extension(&mut self, name: &ExtensionId) -> Option<Arc<Extension>> {
self.0.remove(name)
}
}

impl IntoIterator for ExtensionRegistry {
type Item = (ExtensionId, Extension);
type Item = (ExtensionId, Arc<Extension>);

type IntoIter = <BTreeMap<ExtensionId, Extension> as IntoIterator>::IntoIter;
type IntoIter = <BTreeMap<ExtensionId, Arc<Extension>> as IntoIterator>::IntoIter;

fn into_iter(self) -> Self::IntoIter {
self.0.into_iter()
Expand Down Expand Up @@ -646,10 +659,10 @@ pub mod test {

let ext_1_id = ExtensionId::new("ext1").unwrap();
let ext_2_id = ExtensionId::new("ext2").unwrap();
let ext1 = Extension::new(ext_1_id.clone(), Version::new(1, 0, 0));
let ext1_1 = Extension::new(ext_1_id.clone(), Version::new(1, 1, 0));
let ext1_2 = Extension::new(ext_1_id.clone(), Version::new(0, 2, 0));
let ext2 = Extension::new(ext_2_id, Version::new(1, 0, 0));
let ext1 = Arc::new(Extension::new(ext_1_id.clone(), Version::new(1, 0, 0)));
let ext1_1 = Arc::new(Extension::new(ext_1_id.clone(), Version::new(1, 1, 0)));
let ext1_2 = Arc::new(Extension::new(ext_1_id.clone(), Version::new(0, 2, 0)));
let ext2 = Arc::new(Extension::new(ext_2_id, Version::new(1, 0, 0)));

reg.register(ext1.clone()).unwrap();
reg_ref.register(ext1.clone()).unwrap();
Expand Down
7 changes: 4 additions & 3 deletions hugr-core/src/extension/declarative.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,8 @@ impl ExtensionSetDeclaration {
registry,
};
let ext = decl.make_extension(&self.imports, ctx)?;
let ext = registry.register(ext)?;
scope.insert(ext.name())
scope.insert(ext.name());
registry.register(ext)?;
}

Ok(())
Expand Down Expand Up @@ -272,6 +272,7 @@ mod test {
use itertools::Itertools;
use rstest::rstest;
use std::path::PathBuf;
use std::sync::Arc;

use crate::extension::PRELUDE_REGISTRY;
use crate::std_extensions;
Expand Down Expand Up @@ -406,7 +407,7 @@ extensions:
fn new_extensions<'a>(
reg: &'a ExtensionRegistry,
dependencies: &'a ExtensionRegistry,
) -> impl Iterator<Item = (&'a ExtensionId, &'a Extension)> {
) -> impl Iterator<Item = (&'a ExtensionId, &'a Arc<Extension>)> {
reg.iter()
.filter(move |(id, _)| !dependencies.contains(id) && *id != &PRELUDE_ID)
}
Expand Down
2 changes: 1 addition & 1 deletion hugr-core/src/extension/op_def.rs
Original file line number Diff line number Diff line change
Expand Up @@ -617,7 +617,7 @@ pub(super) mod test {
assert_eq!(def.misc.len(), 1);

let reg =
ExtensionRegistry::try_new([PRELUDE.to_owned(), EXTENSION.to_owned(), e]).unwrap();
ExtensionRegistry::try_new([PRELUDE.clone(), EXTENSION.clone(), e.into()]).unwrap();
let e = reg.get(&EXT_ID).unwrap();

let list_usize =
Expand Down
11 changes: 7 additions & 4 deletions hugr-core/src/extension/prelude.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
//! Prelude extension - available in all contexts, defining common types,
//! operations and constants.
use std::sync::Arc;

use itertools::Itertools;
use lazy_static::lazy_static;

Expand Down Expand Up @@ -38,7 +40,7 @@ pub const PRELUDE_ID: ExtensionId = ExtensionId::new_unchecked("prelude");
/// Extension version.
pub const VERSION: semver::Version = semver::Version::new(0, 1, 0);
lazy_static! {
static ref PRELUDE_DEF: Extension = {
static ref PRELUDE_DEF: Arc<Extension> = {
let mut prelude = Extension::new(PRELUDE_ID, VERSION);
prelude
.add_type(
Expand Down Expand Up @@ -106,14 +108,15 @@ lazy_static! {
LiftDef.add_to_extension(&mut prelude).unwrap();
array::ArrayOpDef::load_all_ops(&mut prelude).unwrap();
array::ArrayScanDef.add_to_extension(&mut prelude).unwrap();
prelude

Arc::new(prelude)
};
/// An extension registry containing only the prelude
pub static ref PRELUDE_REGISTRY: ExtensionRegistry =
ExtensionRegistry::try_new([PRELUDE_DEF.to_owned()]).unwrap();
ExtensionRegistry::try_new([PRELUDE_DEF.clone()]).unwrap();

/// Prelude extension
pub static ref PRELUDE: &'static Extension = PRELUDE_REGISTRY.get(&PRELUDE_ID).unwrap();
pub static ref PRELUDE: Arc<Extension> = PRELUDE_REGISTRY.get(&PRELUDE_ID).unwrap().clone();

}

Expand Down
8 changes: 5 additions & 3 deletions hugr-core/src/extension/simple_op.rs
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,8 @@ impl<T: MakeRegisteredOp> From<T> for OpType {

#[cfg(test)]
mod test {
use std::sync::Arc;

use crate::{const_extension_ids, type_row, types::Signature};

use super::*;
Expand Down Expand Up @@ -313,13 +315,13 @@ mod test {
}

lazy_static! {
static ref EXT: Extension = {
static ref EXT: Arc<Extension> = {
let mut e = Extension::new_test(EXT_ID.clone());
DummyEnum::Dumb.add_to_extension(&mut e).unwrap();
e
Arc::new(e)
};
static ref DUMMY_REG: ExtensionRegistry =
ExtensionRegistry::try_new([EXT.to_owned()]).unwrap();
ExtensionRegistry::try_new([EXT.clone()]).unwrap();
}
impl MakeRegisteredOp for DummyEnum {
fn extension_id(&self) -> ExtensionId {
Expand Down
6 changes: 3 additions & 3 deletions hugr-core/src/hugr/rewrite/inline_dfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -257,9 +257,9 @@ mod test {
let [q, p] = swap.outputs_arr();
let cx = h.add_dataflow_op(test_quantum_extension::cx_gate(), [q, p])?;
let reg = ExtensionRegistry::try_new([
test_quantum_extension::EXTENSION.to_owned(),
PRELUDE.to_owned(),
float_types::EXTENSION.to_owned(),
test_quantum_extension::EXTENSION.clone(),
PRELUDE.clone(),
float_types::EXTENSION.clone(),
])
.unwrap();

Expand Down
8 changes: 4 additions & 4 deletions hugr-core/src/hugr/validate/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,7 @@ fn invalid_types() {
TypeDefBound::any(),
)
.unwrap();
let reg = ExtensionRegistry::try_new([e, PRELUDE.to_owned()]).unwrap();
let reg = ExtensionRegistry::try_new([e.into(), PRELUDE.clone()]).unwrap();

let validate_to_sig_error = |t: CustomType| {
let (h, def) = identity_hugr_with_type(Type::new_extension(t));
Expand Down Expand Up @@ -643,7 +643,7 @@ fn instantiate_row_variables() -> Result<(), Box<dyn std::error::Error>> {
let eval2 = dfb.add_dataflow_op(eval2, [par_func, a, b])?;
dfb.finish_hugr_with_outputs(
eval2.outputs(),
&ExtensionRegistry::try_new([PRELUDE.to_owned(), e]).unwrap(),
&ExtensionRegistry::try_new([PRELUDE.clone(), e.into()]).unwrap(),
)?;
Ok(())
}
Expand Down Expand Up @@ -683,7 +683,7 @@ fn row_variables() -> Result<(), Box<dyn std::error::Error>> {
let par_func = fb.add_dataflow_op(par, [func_arg, id_usz])?;
fb.finish_hugr_with_outputs(
par_func.outputs(),
&ExtensionRegistry::try_new([PRELUDE.to_owned(), e]).unwrap(),
&ExtensionRegistry::try_new([PRELUDE.clone(), e.into()]).unwrap(),
)?;
Ok(())
}
Expand Down Expand Up @@ -763,7 +763,7 @@ fn test_polymorphic_call() -> Result<(), Box<dyn std::error::Error>> {
f.finish_with_outputs([tup])?
};

let reg = ExtensionRegistry::try_new([e, PRELUDE.to_owned()])?;
let reg = ExtensionRegistry::try_new([e.into(), PRELUDE.clone()])?;
let [func, tup] = d.input_wires_arr();
let call = d.call(
f.handle(),
Expand Down
6 changes: 3 additions & 3 deletions hugr-core/src/ops/custom.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ impl ExtensionOp {
args: impl Into<Vec<TypeArg>>,
exts: &ExtensionRegistry,
) -> Result<Self, SignatureError> {
let args = args.into();
let args: Vec<TypeArg> = args.into();
let signature = def.compute_signature(&args, exts)?;
Ok(Self {
def,
Expand All @@ -62,7 +62,7 @@ impl ExtensionOp {
opaque: &OpaqueOp,
exts: &ExtensionRegistry,
) -> Result<Self, SignatureError> {
let args = args.into();
let args: Vec<TypeArg> = args.into();
// TODO skip computation depending on config
// see https://github.com/CQCL/hugr/issues/1363
let signature = match def.compute_signature(&args, exts) {
Expand Down Expand Up @@ -421,7 +421,7 @@ mod test {
SignatureFunc::MissingComputeFunc,
)
.unwrap();
let registry = ExtensionRegistry::try_new([ext]).unwrap();
let registry = ExtensionRegistry::try_new([ext.into()]).unwrap();
let opaque_val = OpaqueOp::new(
ext_id.clone(),
val_name,
Expand Down
9 changes: 5 additions & 4 deletions hugr-core/src/package.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
use derive_more::{Display, Error, From};
use std::collections::HashMap;
use std::path::Path;
use std::sync::Arc;
use std::{fs, io, mem};

use crate::builder::{Container, Dataflow, DataflowSubContainer, ModuleBuilder};
Expand All @@ -19,7 +20,7 @@ pub struct Package {
/// Module HUGRs included in the package.
pub modules: Vec<Hugr>,
/// Extensions to validate against.
pub extensions: Vec<Extension>,
pub extensions: Vec<Arc<Extension>>,
}

impl Package {
Expand All @@ -32,7 +33,7 @@ impl Package {
/// Returns an error if any of the HUGRs does not have a `Module` root.
pub fn new(
modules: impl IntoIterator<Item = Hugr>,
extensions: impl IntoIterator<Item = Extension>,
extensions: impl IntoIterator<Item = Arc<Extension>>,
) -> Result<Self, PackageError> {
let modules: Vec<Hugr> = modules.into_iter().collect();
for (idx, module) in modules.iter().enumerate() {
Expand Down Expand Up @@ -62,7 +63,7 @@ impl Package {
/// Returns an error if any of the HUGRs cannot be wrapped in a module.
pub fn from_hugrs(
modules: impl IntoIterator<Item = Hugr>,
extensions: impl IntoIterator<Item = Extension>,
extensions: impl IntoIterator<Item = Arc<Extension>>,
) -> Result<Self, PackageError> {
let modules: Vec<Hugr> = modules
.into_iter()
Expand Down Expand Up @@ -378,7 +379,7 @@ mod test {

Package {
modules: vec![hugr0, hugr1],
extensions: vec![ext1, ext2],
extensions: vec![ext1.into(), ext2.into()],
}
}

Expand Down
Loading

0 comments on commit 143a155

Please sign in to comment.