Skip to content

Commit

Permalink
refactor: validate ExtensionRegistry when built, not as we build it (#…
Browse files Browse the repository at this point in the history
…701)

Closes #676 
* Replace PolyFunc::new_validated with ::new to allow (a) building
without an ExtensionRegistry and (b) PolyFuncTypes with free variables
(necessary for nested instances)
* Update a bunch of code in std_extensions code and elsewhere using a
`temp_reg` or similar, this is no longer needed
* Change `[Extension1, Extension2, ...].into()` into
`ExtensionRegistry::try_new` which does validation and can fail with an
`(ExtensionId, SignatureError)`
* This required updating the test_quantum extension's registry to
include float_types
* Some commoning up of registries, and move `test_registry` to the one
file that actually used it (it doesn't seem so distinguished now we have
registries all over the place!)
  • Loading branch information
acl-cqc authored Nov 20, 2023
1 parent 41e15da commit 4a8d190
Show file tree
Hide file tree
Showing 14 changed files with 200 additions and 249 deletions.
43 changes: 29 additions & 14 deletions src/extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,33 +35,38 @@ pub use prelude::{PRELUDE, PRELUDE_REGISTRY};
pub struct ExtensionRegistry(BTreeMap<ExtensionId, Extension>);

impl ExtensionRegistry {
/// Makes a new (empty) registry.
pub const fn new() -> Self {
Self(BTreeMap::new())
}

/// Gets the Extension with the given name
pub fn get(&self, name: &str) -> Option<&Extension> {
self.0.get(name)
}
}

/// An Extension Registry containing no extensions.
pub const EMPTY_REG: ExtensionRegistry = ExtensionRegistry::new();

impl<T: IntoIterator<Item = Extension>> From<T> for ExtensionRegistry {
fn from(value: T) -> Self {
let mut reg = Self::new();
/// Makes a new ExtensionRegistry, validating all the extensions in it
pub fn try_new(
value: impl IntoIterator<Item = Extension>,
) -> Result<Self, (ExtensionId, SignatureError)> {
let mut exts = BTreeMap::new();
for ext in value.into_iter() {
let prev = reg.0.insert(ext.name.clone(), ext);
let prev = exts.insert(ext.name.clone(), ext);
if let Some(prev) = prev {
panic!("Multiple extensions with same name: {}", prev.name)
};
}
reg
// Note this potentially asks extensions to validate themselves against other extensions that
// may *not* be valid themselves yet. It'd be better to order these respecting dependencies,
// or at least to validate the types first - which we don't do at all yet:
// TODO https://github.com/CQCL/hugr/issues/624. However, parametrized types could be
// cyclically dependent, so there is no perfect solution, and this is at least simple.
let res = ExtensionRegistry(exts);
for ext in res.0.values() {
ext.validate(&res).map_err(|e| (ext.name().clone(), e))?;
}
Ok(res)
}
}

/// An Extension Registry containing no extensions.
pub const EMPTY_REG: ExtensionRegistry = ExtensionRegistry(BTreeMap::new());

/// An error that can occur in computing the signature of a node.
/// TODO: decide on failure modes
#[derive(Debug, Clone, Error, PartialEq, Eq)]
Expand Down Expand Up @@ -290,6 +295,16 @@ impl Extension {
let op_def = self.get_op(op_name).expect("Op not found.");
ExtensionOp::new(op_def.clone(), args, ext_reg)
}

// Validates against a registry, which we can assume includes this extension itself.
// (TODO deal with the registry itself containing invalid extensions!)
fn validate(&self, all_exts: &ExtensionRegistry) -> Result<(), SignatureError> {
// We should validate TypeParams of TypeDefs too - https://github.com/CQCL/hugr/issues/624
for op_def in self.operations.values() {
op_def.validate(all_exts)?;
}
Ok(())
}
}

impl PartialEq for Extension {
Expand Down
28 changes: 16 additions & 12 deletions src/extension/op_def.rs
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,15 @@ impl OpDef {
SignatureFunc::CustomFunc { static_params, .. } => static_params,
}
}

pub(super) fn validate(&self, exts: &ExtensionRegistry) -> Result<(), SignatureError> {
// TODO https://github.com/CQCL/hugr/issues/624 validate declared TypeParams
// for both type scheme and custom binary
if let SignatureFunc::TypeScheme(ts) = &self.signature_func {
ts.validate(exts, &[])?;
}
Ok(())
}
}

impl Extension {
Expand Down Expand Up @@ -356,7 +365,7 @@ mod test {

use crate::builder::{DFGBuilder, Dataflow, DataflowHugr};
use crate::extension::prelude::USIZE_T;
use crate::extension::PRELUDE;
use crate::extension::{ExtensionRegistry, PRELUDE};
use crate::ops::custom::ExternalOp;
use crate::ops::LeafOp;
use crate::std_extensions::collections::{EXTENSION, LIST_TYPENAME};
Expand All @@ -370,34 +379,29 @@ mod test {

#[test]
fn op_def_with_type_scheme() -> Result<(), Box<dyn std::error::Error>> {
let reg1 = [PRELUDE.to_owned(), EXTENSION.to_owned()].into();
let list_def = EXTENSION.get_type(&LIST_TYPENAME).unwrap();
let mut e = Extension::new(EXT_ID);
const TP: TypeParam = TypeParam::Type(TypeBound::Any);
let list_of_var =
Type::new_extension(list_def.instantiate(vec![TypeArg::new_var_use(0, TP)])?);
const OP_NAME: SmolStr = SmolStr::new_inline("Reverse");
let type_scheme = PolyFuncType::new_validated(
vec![TP],
FunctionType::new_endo(vec![list_of_var]),
&reg1,
)?;
let type_scheme = PolyFuncType::new(vec![TP], FunctionType::new_endo(vec![list_of_var]));
e.add_op_type_scheme(OP_NAME, "".into(), Default::default(), vec![], type_scheme)?;
let reg =
ExtensionRegistry::try_new([PRELUDE.to_owned(), EXTENSION.to_owned(), e]).unwrap();
let e = reg.get(&EXT_ID).unwrap();

let list_usize =
Type::new_extension(list_def.instantiate(vec![TypeArg::Type { ty: USIZE_T }])?);
let mut dfg = DFGBuilder::new(FunctionType::new_endo(vec![list_usize]))?;
let rev = dfg.add_dataflow_op(
LeafOp::from(ExternalOp::Extension(
e.instantiate_extension_op(&OP_NAME, vec![TypeArg::Type { ty: USIZE_T }], &reg1)
e.instantiate_extension_op(&OP_NAME, vec![TypeArg::Type { ty: USIZE_T }], &reg)
.unwrap(),
)),
dfg.input_wires(),
)?;
dfg.finish_hugr_with_outputs(
rev.outputs(),
&[PRELUDE.to_owned(), EXTENSION.to_owned(), e].into(),
)?;
dfg.finish_hugr_with_outputs(rev.outputs(), &reg)?;

Ok(())
}
Expand Down
3 changes: 2 additions & 1 deletion src/extension/prelude.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,8 @@ lazy_static! {
prelude
};
/// An extension registry containing only the prelude
pub static ref PRELUDE_REGISTRY: ExtensionRegistry = [PRELUDE_DEF.to_owned()].into();
pub static ref PRELUDE_REGISTRY: ExtensionRegistry =
ExtensionRegistry::try_new([PRELUDE_DEF.to_owned()]).unwrap();

/// Prelude extension
pub static ref PRELUDE: &'static Extension = PRELUDE_REGISTRY.get(&PRELUDE_ID).unwrap();
Expand Down
4 changes: 3 additions & 1 deletion src/hugr/rewrite/replace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,9 @@ mod test {

#[test]
fn cfg() -> Result<(), Box<dyn std::error::Error>> {
let reg: ExtensionRegistry = [PRELUDE.to_owned(), collections::EXTENSION.to_owned()].into();
let reg =
ExtensionRegistry::try_new([PRELUDE.to_owned(), collections::EXTENSION.to_owned()])
.unwrap();
let listy = Type::new_extension(
collections::EXTENSION
.get_type(collections::LIST_TYPENAME.as_str())
Expand Down
2 changes: 1 addition & 1 deletion src/hugr/validate/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -672,7 +672,7 @@ fn invalid_types() {
TypeDefBound::Explicit(TypeBound::Any),
)
.unwrap();
let reg: ExtensionRegistry = [e, PRELUDE.to_owned()].into();
let reg = ExtensionRegistry::try_new([e, PRELUDE.to_owned()]).unwrap();

let validate_to_sig_error = |t: CustomType| {
let (h, def) = identity_hugr_with_type(Type::new_extension(t));
Expand Down
3 changes: 2 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,8 @@
//! lazy_static! {
//! /// Quantum extension definition.
//! pub static ref EXTENSION: Extension = extension();
//! static ref REG: ExtensionRegistry = [EXTENSION.to_owned(), PRELUDE.to_owned()].into();
//! static ref REG: ExtensionRegistry =
//! ExtensionRegistry::try_new([EXTENSION.to_owned(), PRELUDE.to_owned()]).unwrap();
//!
//! }
//! fn get_gate(gate_name: &str) -> LeafOp {
Expand Down
9 changes: 6 additions & 3 deletions src/ops/constant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,11 +126,10 @@ mod test {
builder::{BuildError, DFGBuilder, Dataflow, DataflowHugr},
extension::{
prelude::{ConstUsize, USIZE_T},
ExtensionId, ExtensionSet,
ExtensionId, ExtensionRegistry, ExtensionSet, PRELUDE,
},
std_extensions::arithmetic::float_types::{ConstF64, FLOAT64_TYPE},
std_extensions::arithmetic::float_types::{self, ConstF64, FLOAT64_TYPE},
type_row,
types::test::test_registry,
types::type_param::TypeArg,
types::{CustomCheckFailure, CustomType, FunctionType, Type, TypeBound, TypeRow},
values::{
Expand All @@ -143,6 +142,10 @@ mod test {

use super::*;

fn test_registry() -> ExtensionRegistry {
ExtensionRegistry::try_new([PRELUDE.to_owned(), float_types::extension()]).unwrap()
}

#[test]
fn test_tuple_sum() -> Result<(), BuildError> {
use crate::builder::Container;
Expand Down
6 changes: 3 additions & 3 deletions src/ops/leaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -192,8 +192,8 @@ impl DataflowOpTrait for LeafOp {
mod test {
use crate::builder::{BuildError, DFGBuilder, Dataflow, DataflowHugr};
use crate::extension::prelude::BOOL_T;
use crate::extension::SignatureError;
use crate::extension::{prelude::USIZE_T, PRELUDE};
use crate::extension::{ExtensionRegistry, SignatureError};
use crate::hugr::ValidationError;
use crate::ops::handle::NodeHandle;
use crate::std_extensions::collections::EXTENSION;
Expand All @@ -206,7 +206,7 @@ mod test {

#[test]
fn hugr_with_type_apply() -> Result<(), Box<dyn std::error::Error>> {
let reg = [PRELUDE.to_owned(), EXTENSION.to_owned()].into();
let reg = ExtensionRegistry::try_new([PRELUDE.to_owned(), EXTENSION.to_owned()]).unwrap();
let pf_in = nested_func();
let pf_out = pf_in.instantiate(&[USIZE_TA], &reg)?;
let mut dfg = DFGBuilder::new(FunctionType::new(
Expand All @@ -225,7 +225,7 @@ mod test {

#[test]
fn bad_type_apply() -> Result<(), Box<dyn std::error::Error>> {
let reg = [PRELUDE.to_owned(), EXTENSION.to_owned()].into();
let reg = ExtensionRegistry::try_new([PRELUDE.to_owned(), EXTENSION.to_owned()]).unwrap();
let pf = nested_func();
let pf_usz = pf.instantiate_poly(&[USIZE_TA], &reg)?;
let pf_bool = pf.instantiate_poly(&[TypeArg::Type { ty: BOOL_T }], &reg)?;
Expand Down
49 changes: 17 additions & 32 deletions src/std_extensions/arithmetic/conversions.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
//! Conversions between integer and floating-point values.
use crate::{
extension::{
prelude::sum_with_error, ExtensionId, ExtensionRegistry, ExtensionSet, SignatureError,
PRELUDE,
},
extension::{prelude::sum_with_error, ExtensionId, ExtensionSet},
type_row,
types::{FunctionType, PolyFuncType},
Extension,
Expand All @@ -16,62 +13,50 @@ use super::{float_types::FLOAT64_TYPE, int_types::LOG_WIDTH_TYPE_PARAM};
/// The extension identifier.
pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("arithmetic.conversions");

fn ftoi_sig(temp_reg: &ExtensionRegistry) -> Result<PolyFuncType, SignatureError> {
let body = FunctionType::new(
type_row![FLOAT64_TYPE],
vec![sum_with_error(int_type_var(0))],
/// Extension for basic arithmetic operations.
pub fn extension() -> Extension {
let ftoi_sig = PolyFuncType::new(
vec![LOG_WIDTH_TYPE_PARAM],
FunctionType::new(
type_row![FLOAT64_TYPE],
vec![sum_with_error(int_type_var(0))],
),
);

PolyFuncType::new_validated(vec![LOG_WIDTH_TYPE_PARAM], body, temp_reg)
}

fn itof_sig(temp_reg: &ExtensionRegistry) -> Result<PolyFuncType, SignatureError> {
let body = FunctionType::new(vec![int_type_var(0)], type_row![FLOAT64_TYPE]);

PolyFuncType::new_validated(vec![LOG_WIDTH_TYPE_PARAM], body, temp_reg)
}
let itof_sig = PolyFuncType::new(
vec![LOG_WIDTH_TYPE_PARAM],
FunctionType::new(vec![int_type_var(0)], type_row![FLOAT64_TYPE]),
);

/// Extension for basic arithmetic operations.
pub fn extension() -> Extension {
let mut extension = Extension::new_with_reqs(
EXTENSION_ID,
ExtensionSet::from_iter(vec![
super::int_types::EXTENSION_ID,
super::float_types::EXTENSION_ID,
]),
);
let temp_reg: ExtensionRegistry = [
super::int_types::EXTENSION.to_owned(),
super::float_types::extension(),
PRELUDE.to_owned(),
]
.into();
extension
.add_op_type_scheme_simple(
"trunc_u".into(),
"float to unsigned int".to_owned(),
ftoi_sig(&temp_reg).unwrap(),
ftoi_sig.clone(),
)
.unwrap();
extension
.add_op_type_scheme_simple(
"trunc_s".into(),
"float to signed int".to_owned(),
ftoi_sig(&temp_reg).unwrap(),
)
.add_op_type_scheme_simple("trunc_s".into(), "float to signed int".to_owned(), ftoi_sig)
.unwrap();
extension
.add_op_type_scheme_simple(
"convert_u".into(),
"unsigned int to float".to_owned(),
itof_sig(&temp_reg).unwrap(),
itof_sig.clone(),
)
.unwrap();
extension
.add_op_type_scheme_simple(
"convert_s".into(),
"signed int to float".to_owned(),
itof_sig(&temp_reg).unwrap(),
itof_sig,
)
.unwrap();

Expand Down
Loading

0 comments on commit 4a8d190

Please sign in to comment.