From 344a7e49658895b6358bdd45f3fe2867f27ec9ae Mon Sep 17 00:00:00 2001 From: Lukas Heidemann Date: Wed, 27 Nov 2024 09:40:42 +0000 Subject: [PATCH 1/3] feat: Lists and extension sets with splicing (#1657) This PR allows lists and extension sets in `hugr-model` to splice lists and extension sets, e.g. `[0 xs ... 1 2 3]`. This is used to import and export rows and extension sets with variables. Closes #1609. --- hugr-core/src/export.rs | 167 ++++++++---------- hugr-core/src/import.rs | 145 +++++++++------ hugr-core/src/types.rs | 2 +- .../snapshots/model__roundtrip_call.snap | 2 +- .../tests/snapshots/model__roundtrip_cfg.snap | 6 +- hugr-model/capnp/hugr-v0.capnp | 20 ++- hugr-model/src/v0/binary/read.rs | 41 +++-- hugr-model/src/v0/binary/write.rs | 27 ++- hugr-model/src/v0/mod.rs | 44 +++-- hugr-model/src/v0/text/hugr.pest | 8 +- hugr-model/src/v0/text/parse.rs | 53 +++--- hugr-model/src/v0/text/print.rs | 73 +++++--- hugr-model/tests/binary.rs | 5 + hugr-model/tests/fixtures/model-call.edn | 2 +- hugr-model/tests/fixtures/model-lists.edn | 21 +++ 15 files changed, 367 insertions(+), 249 deletions(-) create mode 100644 hugr-model/tests/fixtures/model-lists.edn diff --git a/hugr-core/src/export.rs b/hugr-core/src/export.rs index 093368b60..f433a3482 100644 --- a/hugr-core/src/export.rs +++ b/hugr-core/src/export.rs @@ -509,48 +509,24 @@ impl<'a> Context<'a> { /// like for the other nodes since the ports are control flow ports. pub fn export_block_signature(&mut self, block: &DataflowBlock) -> model::TermId { let inputs = { - let mut inputs = BumpVec::with_capacity_in(block.inputs.len(), self.bump); - for input in block.inputs.iter() { - inputs.push(self.export_type(input)); - } - let inputs = self.make_term(model::Term::List { - items: inputs.into_bump_slice(), - tail: None, - }); + let inputs = self.export_type_row(&block.inputs); let inputs = self.make_term(model::Term::Control { values: inputs }); self.make_term(model::Term::List { - items: self.bump.alloc_slice_copy(&[inputs]), - tail: None, + parts: self.bump.alloc_slice_copy(&[model::ListPart::Item(inputs)]), }) }; - let tail = { - let mut tail = BumpVec::with_capacity_in(block.other_outputs.len(), self.bump); - for other_output in block.other_outputs.iter() { - tail.push(self.export_type(other_output)); - } - self.make_term(model::Term::List { - items: tail.into_bump_slice(), - tail: None, - }) - }; + let tail = self.export_type_row(&block.other_outputs); let outputs = { let mut outputs = BumpVec::with_capacity_in(block.sum_rows.len(), self.bump); for sum_row in block.sum_rows.iter() { - let mut variant = BumpVec::with_capacity_in(sum_row.len(), self.bump); - for typ in sum_row.iter() { - variant.push(self.export_type(typ)); - } - let variant = self.make_term(model::Term::List { - items: variant.into_bump_slice(), - tail: Some(tail), - }); - outputs.push(self.make_term(model::Term::Control { values: variant })); + let variant = self.export_type_row_with_tail(sum_row, Some(tail)); + let control = self.make_term(model::Term::Control { values: variant }); + outputs.push(model::ListPart::Item(control)); } self.make_term(model::Term::List { - items: outputs.into_bump_slice(), - tail: None, + parts: outputs.into_bump_slice(), }) }; @@ -772,10 +748,12 @@ impl<'a> Context<'a> { TypeArg::String { arg } => self.make_term(model::Term::Str(self.bump.alloc_str(arg))), TypeArg::Sequence { elems } => { // For now we assume that the sequence is meant to be a list. - let items = self - .bump - .alloc_slice_fill_iter(elems.iter().map(|elem| self.export_type_arg(elem))); - self.make_term(model::Term::List { items, tail: None }) + let parts = self.bump.alloc_slice_fill_iter( + elems + .iter() + .map(|elem| model::ListPart::Item(self.export_type_arg(elem))), + ); + self.make_term(model::Term::List { parts }) } TypeArg::Extensions { es } => self.export_ext_set(es), TypeArg::Variable { v } => self.export_type_arg_var(v), @@ -798,32 +776,53 @@ impl<'a> Context<'a> { pub fn export_sum_type(&mut self, t: &SumType) -> model::TermId { match t { SumType::Unit { size } => { - let items = self.bump.alloc_slice_fill_iter((0..*size).map(|_| { - self.make_term(model::Term::List { - items: &[], - tail: None, - }) + let parts = self.bump.alloc_slice_fill_iter((0..*size).map(|_| { + model::ListPart::Item(self.make_term(model::Term::List { parts: &[] })) })); - let list = model::Term::List { items, tail: None }; - let variants = self.make_term(list); + let variants = self.make_term(model::Term::List { parts }); self.make_term(model::Term::Adt { variants }) } SumType::General { rows } => { - let items = self - .bump - .alloc_slice_fill_iter(rows.iter().map(|row| self.export_type_row(row))); - let list = model::Term::List { items, tail: None }; + let parts = self.bump.alloc_slice_fill_iter( + rows.iter() + .map(|row| model::ListPart::Item(self.export_type_row(row))), + ); + let list = model::Term::List { parts }; let variants = { self.make_term(list) }; self.make_term(model::Term::Adt { variants }) } } } - pub fn export_type_row(&mut self, t: &TypeRowBase) -> model::TermId { - let mut items = BumpVec::with_capacity_in(t.len(), self.bump); - items.extend(t.iter().map(|row| self.export_type(row))); - let items = items.into_bump_slice(); - self.make_term(model::Term::List { items, tail: None }) + #[inline] + pub fn export_type_row(&mut self, row: &TypeRowBase) -> model::TermId { + self.export_type_row_with_tail(row, None) + } + + pub fn export_type_row_with_tail( + &mut self, + row: &TypeRowBase, + tail: Option, + ) -> model::TermId { + let mut parts = BumpVec::with_capacity_in(row.len() + tail.is_some() as usize, self.bump); + + for t in row.iter() { + match t.as_type_enum() { + TypeEnum::RowVar(var) => { + parts.push(model::ListPart::Splice(self.export_row_var(var.as_rv()))); + } + _ => { + parts.push(model::ListPart::Item(self.export_type(t))); + } + } + } + + if let Some(tail) = tail { + parts.push(model::ListPart::Splice(tail)); + } + + let parts = parts.into_bump_slice(); + self.make_term(model::Term::List { parts }) } /// Exports a `TypeParam` to a term. @@ -855,12 +854,12 @@ impl<'a> Context<'a> { self.make_term(model::Term::ListType { item_type }) } TypeParam::Tuple { params } => { - let items = self.bump.alloc_slice_fill_iter( + let parts = self.bump.alloc_slice_fill_iter( params .iter() - .map(|param| self.export_type_param(param, None)), + .map(|param| model::ListPart::Item(self.export_type_param(param, None))), ); - let types = self.make_term(model::Term::List { items, tail: None }); + let types = self.make_term(model::Term::List { parts }); self.make_term(model::Term::ApplyFull { global: model::GlobalRef::Named(TERM_PARAM_TUPLE), args: self.bump.alloc_slice_copy(&[types]), @@ -873,54 +872,26 @@ impl<'a> Context<'a> { } } - pub fn export_ext_set(&mut self, t: &ExtensionSet) -> model::TermId { - // Extension sets with variables are encoded using a hack: a variable in the - // extension set is represented by converting its index into a string. - // Until we have a better representation for extension sets, we therefore - // need to try and parse each extension as a number to determine if it is - // a variable or an extension. - - // NOTE: This overprovisions the capacity since some of the entries of the row - // may be variables. Since we panic when there is more than one variable, this - // may at most waste one slot. That is way better than having to allocate - // a temporary vector. - // - // Also `ExtensionSet` has no way of reporting its size, so we have to count - // the elements by iterating over them... - let capacity = t.iter().count(); - let mut extensions = BumpVec::with_capacity_in(capacity, self.bump); - let mut rest = None; - - for ext in t.iter() { - if let Ok(index) = ext.parse::() { - // Extension sets in the model support at most one variable. This is a - // deliberate limitation so that extension sets behave like polymorphic rows. - // The type theory of such rows and how to apply them to model (co)effects - // is well understood. - // - // Extension sets in `hugr-core` at this point have no such restriction. - // However, it appears that so far we never actually use extension sets with - // multiple variables, except for extension sets that are generated through - // property testing. - if rest.is_some() { - // TODO: We won't need this anymore once we have a core representation - // that ensures that extension sets have at most one variable. - panic!("Extension set with multiple variables") - } + pub fn export_ext_set(&mut self, ext_set: &ExtensionSet) -> model::TermId { + let capacity = ext_set.iter().size_hint().0; + let mut parts = BumpVec::with_capacity_in(capacity, self.bump); - let node = self.local_scope.expect("local variable out of scope"); - rest = Some( - self.module - .insert_term(model::Term::Var(model::LocalRef::Index(node, index as _))), - ); - } else { - extensions.push(self.bump.alloc_str(ext) as &str); + for ext in ext_set.iter() { + // `ExtensionSet`s represent variables by extension names that parse to integers. + match ext.parse::() { + Ok(var) => { + let node = self.local_scope.expect("local variable out of scope"); + let local_ref = model::LocalRef::Index(node, var); + let term = self.make_term(model::Term::Var(local_ref)); + parts.push(model::ExtSetPart::Splice(term)); + } + Err(_) => parts.push(model::ExtSetPart::Extension(self.bump.alloc_str(ext))), } } - let extensions = extensions.into_bump_slice(); - - self.make_term(model::Term::ExtSet { extensions, rest }) + self.make_term(model::Term::ExtSet { + parts: parts.into_bump_slice(), + }) } pub fn export_node_metadata( diff --git a/hugr-core/src/import.rs b/hugr-core/src/import.rs index 7619ad44a..f3009cc17 100644 --- a/hugr-core/src/import.rs +++ b/hugr-core/src/import.rs @@ -13,9 +13,9 @@ use crate::{ CFG, DFG, }, types::{ - type_param::TypeParam, type_row::TypeRowBase, CustomType, FuncTypeBase, MaybeRV, NoRV, + type_param::TypeParam, type_row::TypeRowBase, CustomType, FuncTypeBase, MaybeRV, PolyFuncType, PolyFuncTypeBase, RowVariable, Signature, Type, TypeArg, TypeBase, TypeBound, - TypeRow, + TypeEnum, TypeRow, }, Direction, Hugr, HugrView, Node, Port, }; @@ -1038,32 +1038,39 @@ impl<'a> Context<'a> { &mut self, term_id: model::TermId, ) -> Result { - match self.get_term(term_id)? { - model::Term::Wildcard => Err(error_uninferred!("wildcard")), + let mut es = ExtensionSet::new(); + let mut stack = vec![term_id]; - model::Term::Var(var) => { - let mut es = ExtensionSet::new(); - let (index, _) = self.resolve_local_ref(var)?; - es.insert_type_var(index); - Ok(es) - } - - model::Term::ExtSet { extensions, rest } => { - let mut es = match rest { - Some(rest) => self.import_extension_set(*rest)?, - None => ExtensionSet::new(), - }; + while let Some(term_id) = stack.pop() { + match self.get_term(term_id)? { + model::Term::Wildcard => return Err(error_uninferred!("wildcard")), - for ext in extensions.iter() { - let ext_ident = IdentList::new(*ext) - .map_err(|_| model::ModelError::MalformedName(ext.to_smolstr()))?; - es.insert(&ext_ident); + model::Term::Var(var) => { + let (index, _) = self.resolve_local_ref(var)?; + es.insert_type_var(index); } - Ok(es) + model::Term::ExtSet { parts } => { + for part in *parts { + match part { + model::ExtSetPart::Extension(ext) => { + let ext_ident = IdentList::new(*ext).map_err(|_| { + model::ModelError::MalformedName(ext.to_smolstr()) + })?; + es.insert(&ext_ident); + } + model::ExtSetPart::Splice(term_id) => { + // The order in an extension set does not matter. + stack.push(*term_id); + } + } + } + } + _ => return Err(model::ModelError::TypeError(term_id).into()), } - _ => Err(model::ModelError::TypeError(term_id).into()), } + + Ok(es) } /// Import a `Type` from a term that represents a runtime type. @@ -1103,7 +1110,7 @@ impl<'a> Context<'a> { } model::Term::FuncType { .. } => { - let func_type = self.import_func_type::(term_id)?; + let func_type = self.import_func_type::(term_id)?; Ok(TypeBase::new_function(func_type)) } @@ -1157,39 +1164,45 @@ impl<'a> Context<'a> { term_id: model::TermId, ) -> Result, ImportError> { let (inputs, outputs, extensions) = self.get_func_type(term_id)?; - let inputs = self.import_type_row::(inputs)?; - let outputs = self.import_type_row::(outputs)?; + let inputs = self.import_type_row(inputs)?; + let outputs = self.import_type_row(outputs)?; let extensions = self.import_extension_set(extensions)?; Ok(FuncTypeBase::new(inputs, outputs).with_extension_delta(extensions)) } fn import_closed_list( &mut self, - mut term_id: model::TermId, + term_id: model::TermId, ) -> Result, ImportError> { - // PERFORMANCE: We currently allocate a Vec here to collect list items - // into, in order to handle the case where the tail of the list is another - // list. We should avoid this. - let mut list_items = Vec::new(); - - loop { - match self.get_term(term_id)? { - model::Term::Var(_) => return Err(error_unsupported!("open lists")), - model::Term::List { items, tail } => { - list_items.extend(items.iter()); - - match tail { - Some(tail) => term_id = *tail, - None => break, + fn import_into( + ctx: &mut Context, + term_id: model::TermId, + types: &mut Vec, + ) -> Result<(), ImportError> { + match ctx.get_term(term_id)? { + model::Term::List { parts } => { + types.reserve(parts.len()); + + for part in *parts { + match part { + model::ListPart::Item(term_id) => { + types.push(*term_id); + } + model::ListPart::Splice(term_id) => { + import_into(ctx, *term_id, types)?; + } + } } } - _ => { - return Err(model::ModelError::TypeError(term_id).into()); - } + _ => return Err(model::ModelError::TypeError(term_id).into()), } + + Ok(()) } - Ok(list_items) + let mut types = Vec::new(); + import_into(self, term_id, &mut types)?; + Ok(types) } fn import_type_rows( @@ -1197,8 +1210,8 @@ impl<'a> Context<'a> { term_id: model::TermId, ) -> Result>, ImportError> { self.import_closed_list(term_id)? - .iter() - .map(|row| self.import_type_row::(*row)) + .into_iter() + .map(|term_id| self.import_type_row::(term_id)) .collect() } @@ -1206,13 +1219,41 @@ impl<'a> Context<'a> { &mut self, term_id: model::TermId, ) -> Result, ImportError> { - let items = self - .import_closed_list(term_id)? - .iter() - .map(|item| self.import_type(*item)) - .collect::, _>>()?; + fn import_into( + ctx: &mut Context, + term_id: model::TermId, + types: &mut Vec>, + ) -> Result<(), ImportError> { + match ctx.get_term(term_id)? { + model::Term::List { parts } => { + types.reserve(parts.len()); + + for item in *parts { + match item { + model::ListPart::Item(term_id) => { + types.push(ctx.import_type::(*term_id)?); + } + model::ListPart::Splice(term_id) => { + import_into(ctx, *term_id, types)?; + } + } + } + } + model::Term::Var(var) => { + let (index, _) = ctx.resolve_local_ref(var)?; + let var = RV::try_from_rv(RowVariable(index, TypeBound::Any)) + .map_err(|_| model::ModelError::TypeError(term_id))?; + types.push(TypeBase::new(TypeEnum::RowVar(var))); + } + _ => return Err(model::ModelError::TypeError(term_id).into()), + } + + Ok(()) + } - Ok(items.into()) + let mut types = Vec::new(); + import_into(self, term_id, &mut types)?; + Ok(types.into()) } fn import_custom_name( diff --git a/hugr-core/src/types.rs b/hugr-core/src/types.rs index 72e752614..c61149bff 100644 --- a/hugr-core/src/types.rs +++ b/hugr-core/src/types.rs @@ -364,7 +364,7 @@ impl TypeBase { Self::new(TypeEnum::Alias(alias)) } - fn new(type_e: TypeEnum) -> Self { + pub(crate) fn new(type_e: TypeEnum) -> Self { let bound = type_e.least_upper_bound(); Self(type_e, bound) } diff --git a/hugr-core/tests/snapshots/model__roundtrip_call.snap b/hugr-core/tests/snapshots/model__roundtrip_call.snap index 460d8f4c0..5ddc4eb32 100644 --- a/hugr-core/tests/snapshots/model__roundtrip_call.snap +++ b/hugr-core/tests/snapshots/model__roundtrip_call.snap @@ -8,7 +8,7 @@ expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-call (forall ?0 ext-set) [(@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] - (ext arithmetic.int . ?0) + (ext ?0 ... arithmetic.int) (meta doc.description (@ prelude.json "\"This is a function declaration.\"")) (meta doc.title (@ prelude.json "\"Callee\""))) diff --git a/hugr-core/tests/snapshots/model__roundtrip_cfg.snap b/hugr-core/tests/snapshots/model__roundtrip_cfg.snap index f3c0f0acc..41a8f0d62 100644 --- a/hugr-core/tests/snapshots/model__roundtrip_cfg.snap +++ b/hugr-core/tests/snapshots/model__roundtrip_cfg.snap @@ -1,6 +1,6 @@ --- source: hugr-core/tests/model.rs -expression: "roundtrip(include_str!(\"fixtures/model-cfg.edn\"))" +expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-cfg.edn\"))" --- (hugr 0) @@ -16,13 +16,13 @@ expression: "roundtrip(include_str!(\"fixtures/model-cfg.edn\"))" [%2] [%8] (signature (fn [?0] [?0] (ext))) (block [%2] [%5] - (signature (fn [(ctrl [?0])] [(ctrl [?0 . []])] (ext))) + (signature (fn [(ctrl [?0])] [(ctrl [?0])] (ext))) (dfg [%3] [%4] (signature (fn [?0] [(adt [[?0]])] (ext))) (tag 0 [%3] [%4] (signature (fn [?0] [(adt [[?0]])] (ext)))))) (block [%5] [%8] - (signature (fn [(ctrl [?0])] [(ctrl [?0 . []])] (ext))) + (signature (fn [(ctrl [?0])] [(ctrl [?0])] (ext))) (dfg [%6] [%7] (signature (fn [?0] [(adt [[?0]])] (ext))) diff --git a/hugr-model/capnp/hugr-v0.capnp b/hugr-model/capnp/hugr-v0.capnp index 94341beba..366de92eb 100644 --- a/hugr-model/capnp/hugr-v0.capnp +++ b/hugr-model/capnp/hugr-v0.capnp @@ -175,13 +175,25 @@ struct Term { } struct ListTerm { - items @0 :List(TermId); - tail @1 :OptionalTermId; + items @0 :List(ListPart); + } + + struct ListPart { + union { + item @0 :TermId; + splice @1 :TermId; + } } struct ExtSet { - extensions @0 :List(Text); - rest @1 :OptionalTermId; + items @0 :List(ExtSetPart); + } + + struct ExtSetPart { + union { + extension @0 :Text; + splice @1 :TermId; + } } struct FuncType { diff --git a/hugr-model/src/v0/binary/read.rs b/hugr-model/src/v0/binary/read.rs index 5381a7dc8..2dfe67efc 100644 --- a/hugr-model/src/v0/binary/read.rs +++ b/hugr-model/src/v0/binary/read.rs @@ -296,9 +296,8 @@ fn read_term<'a>(bump: &'a Bump, reader: hugr_capnp::term::Reader) -> ReadResult Which::List(reader) => { let reader = reader?; - let items = read_scalar_list!(bump, reader, get_items, model::TermId); - let tail = reader.get_tail().checked_sub(1).map(model::TermId); - model::Term::List { items, tail } + let parts = read_list!(bump, reader, get_items, read_list_part); + model::Term::List { parts } } Which::ListType(item_type) => model::Term::ListType { @@ -307,18 +306,8 @@ fn read_term<'a>(bump: &'a Bump, reader: hugr_capnp::term::Reader) -> ReadResult Which::ExtSet(reader) => { let reader = reader?; - - let extensions = { - let extensions_reader = reader.get_extensions()?; - let mut extensions = BumpVec::with_capacity_in(extensions_reader.len() as _, bump); - for extension_reader in extensions_reader.iter() { - extensions.push(bump.alloc_str(extension_reader?.to_str()?) as &str); - } - extensions.into_bump_slice() - }; - - let rest = reader.get_rest().checked_sub(1).map(model::TermId); - model::Term::ExtSet { extensions, rest } + let parts = read_list!(bump, reader, get_items, read_ext_set_part); + model::Term::ExtSet { parts } } Which::Adt(variants) => model::Term::Adt { @@ -356,6 +345,28 @@ fn read_meta_item<'a>( Ok(model::MetaItem { name, value }) } +fn read_list_part( + _: &Bump, + reader: hugr_capnp::term::list_part::Reader, +) -> ReadResult { + use hugr_capnp::term::list_part::Which; + Ok(match reader.which()? { + Which::Item(term) => model::ListPart::Item(model::TermId(term)), + Which::Splice(list) => model::ListPart::Splice(model::TermId(list)), + }) +} + +fn read_ext_set_part<'a>( + bump: &'a Bump, + reader: hugr_capnp::term::ext_set_part::Reader, +) -> ReadResult> { + use hugr_capnp::term::ext_set_part::Which; + Ok(match reader.which()? { + Which::Extension(ext) => model::ExtSetPart::Extension(bump.alloc_str(ext?.to_str()?)), + Which::Splice(list) => model::ExtSetPart::Splice(model::TermId(list)), + }) +} + fn read_param<'a>( bump: &'a Bump, reader: hugr_capnp::param::Reader, diff --git a/hugr-model/src/v0/binary/write.rs b/hugr-model/src/v0/binary/write.rs index f3a0a14d2..aa377e2ec 100644 --- a/hugr-model/src/v0/binary/write.rs +++ b/hugr-model/src/v0/binary/write.rs @@ -187,16 +187,14 @@ fn write_term(mut builder: hugr_capnp::term::Builder, term: &model::Term) { let _ = builder.set_args(model::TermId::unwrap_slice(args)); } - model::Term::List { items, tail } => { + model::Term::List { parts } => { let mut builder = builder.init_list(); - let _ = builder.set_items(model::TermId::unwrap_slice(items)); - builder.set_tail(tail.map_or(0, |t| t.0 + 1)); + write_list!(builder, init_items, write_list_item, parts); } - model::Term::ExtSet { extensions, rest } => { + model::Term::ExtSet { parts } => { let mut builder = builder.init_ext_set(); - let _ = builder.set_extensions(*extensions); - builder.set_rest(rest.map_or(0, |t| t.0 + 1)); + write_list!(builder, init_items, write_ext_set_item, parts); } model::Term::FuncType { @@ -215,3 +213,20 @@ fn write_term(mut builder: hugr_capnp::term::Builder, term: &model::Term) { } } } + +fn write_list_item(mut builder: hugr_capnp::term::list_part::Builder, item: &model::ListPart) { + match item { + model::ListPart::Item(term_id) => builder.set_item(term_id.0), + model::ListPart::Splice(term_id) => builder.set_splice(term_id.0), + } +} + +fn write_ext_set_item( + mut builder: hugr_capnp::term::ext_set_part::Builder, + item: &model::ExtSetPart, +) { + match item { + model::ExtSetPart::Extension(ext) => builder.set_extension(ext), + model::ExtSetPart::Splice(term_id) => builder.set_splice(term_id.0), + } +} diff --git a/hugr-model/src/v0/mod.rs b/hugr-model/src/v0/mod.rs index 16c7cb6c6..2b0dc1eaf 100644 --- a/hugr-model/src/v0/mod.rs +++ b/hugr-model/src/v0/mod.rs @@ -569,19 +569,10 @@ pub enum Term<'a> { r#type: TermId, }, - /// A list, with an optional tail. - /// - /// - `[ITEM-0 ... ITEM-n] : (list T)` where `T : static`, `ITEM-i : T`. - /// - `[ITEM-0 ... ITEM-n . TAIL] : (list item-type)` where `T : static`, `ITEM-i : T`, `TAIL : (list T)`. + /// A list. May include individual items or other lists to be spliced in. List { - /// The items in the list. - /// - /// `item-i : item-type` - items: &'a [TermId], - /// The tail of the list. - /// - /// `tail : (list item-type)` - tail: Option, + /// The parts of the list. + parts: &'a [ListPart], }, /// The type of lists, given a type for the items. @@ -615,14 +606,11 @@ pub enum Term<'a> { NatType, /// Extension set. - /// - /// - `(ext EXT-0 ... EXT-n) : ext-set` - /// - `(ext EXT-0 ... EXT-n . REST) : ext-set` where `REST : ext-set`. ExtSet { - /// The items in the extension set. - extensions: &'a [&'a str], - /// The rest of the extension set. - rest: Option, + /// The parts of the extension set. + /// + /// Since extension sets are unordered, the parts may occur in any order. + parts: &'a [ExtSetPart<'a>], }, /// The type of extension sets. @@ -676,6 +664,24 @@ pub enum Term<'a> { }, } +/// A part of a list term. +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub enum ListPart { + /// A single item. + Item(TermId), + /// A list to be spliced into the parent list. + Splice(TermId), +} + +/// A part of an extension set term. +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub enum ExtSetPart<'a> { + /// An extension. + Extension(&'a str), + /// An extension set to be spliced into the parent extension set. + Splice(TermId), +} + /// A parameter to a function or alias. /// /// Parameter names must be unique within a parameter list. diff --git a/hugr-model/src/v0/text/hugr.pest b/hugr-model/src/v0/text/hugr.pest index d05e3d774..fc52b8271 100644 --- a/hugr-model/src/v0/text/hugr.pest +++ b/hugr-model/src/v0/text/hugr.pest @@ -10,8 +10,6 @@ string_raw = @{ (!("\\" | "\"") ~ ANY)+ } string_escape = @{ "\\" ~ ("\"" | "\\" | "n" | "r" | "t") } string_unicode = @{ "\\u" ~ "{" ~ ASCII_HEX_DIGIT+ ~ "}" } -list_tail = { "." } - module = { "(" ~ "hugr" ~ "0" ~ ")" ~ meta* ~ node* ~ EOI } meta = { "(" ~ "meta" ~ symbol ~ term ~ ")" } @@ -103,16 +101,18 @@ term_var = { "?" ~ identifier } term_apply_full = { ("(" ~ "@" ~ symbol ~ term* ~ ")") } term_apply = { symbol | ("(" ~ symbol ~ term* ~ ")") } term_quote = { "(" ~ "quote" ~ term ~ ")" } -term_list = { "[" ~ term* ~ (list_tail ~ term)? ~ "]" } +term_list = { "[" ~ (spliced_term | term)* ~ "]" } term_list_type = { "(" ~ "list" ~ term ~ ")" } term_str = { string } term_str_type = { "str" } term_nat = { (ASCII_DIGIT)+ } term_nat_type = { "nat" } -term_ext_set = { "(" ~ "ext" ~ ext_name* ~ (list_tail ~ term)? ~ ")" } +term_ext_set = { "(" ~ "ext" ~ (spliced_term | ext_name)* ~ ")" } term_ext_set_type = { "ext-set" } term_adt = { "(" ~ "adt" ~ term ~ ")" } term_func_type = { "(" ~ "fn" ~ term ~ term ~ term ~ ")" } term_ctrl = { "(" ~ "ctrl" ~ term ~ ")" } term_ctrl_type = { "ctrl" } term_non_linear = { "(" ~ "nonlinear" ~ term ~ ")" } + +spliced_term = { term ~ "..." } diff --git a/hugr-model/src/v0/text/parse.rs b/hugr-model/src/v0/text/parse.rs index 370dbeac0..8527f1a00 100644 --- a/hugr-model/src/v0/text/parse.rs +++ b/hugr-model/src/v0/text/parse.rs @@ -1,4 +1,4 @@ -use bumpalo::{collections::String as BumpString, Bump}; +use bumpalo::{collections::String as BumpString, collections::Vec as BumpVec, Bump}; use pest::{ iterators::{Pair, Pairs}, Parser, RuleType, @@ -6,8 +6,9 @@ use pest::{ use thiserror::Error; use crate::v0::{ - AliasDecl, ConstructorDecl, FuncDecl, GlobalRef, LinkRef, LocalRef, MetaItem, Module, Node, - NodeId, Operation, OperationDecl, Param, ParamSort, Region, RegionId, RegionKind, Term, TermId, + AliasDecl, ConstructorDecl, ExtSetPart, FuncDecl, GlobalRef, LinkRef, ListPart, LocalRef, + MetaItem, Module, Node, NodeId, Operation, OperationDecl, Param, ParamSort, Region, RegionId, + RegionKind, Term, TermId, }; mod pest_parser { @@ -136,21 +137,21 @@ impl<'a> ParseContext<'a> { } Rule::term_list => { - let mut items = Vec::new(); - let mut tail = None; + let mut parts = BumpVec::with_capacity_in(inner.len(), self.bump); - for token in filter_rule(&mut inner, Rule::term) { - items.push(self.parse_term(token)?); - } - - if inner.next().is_some() { - let token = inner.next().unwrap(); - tail = Some(self.parse_term(token)?); + for token in inner { + match token.as_rule() { + Rule::term => parts.push(ListPart::Item(self.parse_term(token)?)), + Rule::spliced_term => { + let term_token = token.into_inner().next().unwrap(); + parts.push(ListPart::Splice(self.parse_term(term_token)?)) + } + _ => unreachable!(), + } } Term::List { - items: self.bump.alloc_slice_copy(&items), - tail, + parts: parts.into_bump_slice(), } } @@ -170,21 +171,23 @@ impl<'a> ParseContext<'a> { } Rule::term_ext_set => { - let mut extensions = Vec::new(); - let mut rest = None; + let mut parts = BumpVec::with_capacity_in(inner.len(), self.bump); - for token in filter_rule(&mut inner, Rule::ext_name) { - extensions.push(token.as_str()); - } - - if inner.next().is_some() { - let token = inner.next().unwrap(); - rest = Some(self.parse_term(token)?); + for token in inner { + match token.as_rule() { + Rule::ext_name => { + parts.push(ExtSetPart::Extension(self.bump.alloc_str(token.as_str()))) + } + Rule::spliced_term => { + let term_token = token.into_inner().next().unwrap(); + parts.push(ExtSetPart::Splice(self.parse_term(term_token)?)) + } + _ => unreachable!(), + } } Term::ExtSet { - extensions: self.bump.alloc_slice_copy(&extensions), - rest, + parts: parts.into_bump_slice(), } } diff --git a/hugr-model/src/v0/text/print.rs b/hugr-model/src/v0/text/print.rs index a320d4664..5c6c18f7a 100644 --- a/hugr-model/src/v0/text/print.rs +++ b/hugr-model/src/v0/text/print.rs @@ -2,8 +2,8 @@ use pretty::{Arena, DocAllocator, RefDoc}; use std::borrow::Cow; use crate::v0::{ - GlobalRef, LinkRef, LocalRef, MetaItem, ModelError, Module, NodeId, Operation, Param, - ParamSort, RegionId, RegionKind, Term, TermId, + ExtSetPart, GlobalRef, LinkRef, ListPart, LocalRef, MetaItem, ModelError, Module, NodeId, + Operation, Param, ParamSort, RegionId, RegionKind, Term, TermId, }; type PrintError = ModelError; @@ -521,16 +521,7 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { this.print_text("quote"); this.print_term(*r#type) }), - Term::List { items, tail } => self.print_brackets(|this| { - for item in items.iter() { - this.print_term(*item)?; - } - if let Some(tail) = tail { - this.print_text("."); - this.print_term(*tail)?; - } - Ok(()) - }), + Term::List { .. } => self.print_brackets(|this| this.print_list_parts(term_id)), Term::ListType { item_type } => self.print_parens(|this| { this.print_text("list"); this.print_term(*item_type) @@ -551,15 +542,9 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { self.print_text("nat"); Ok(()) } - Term::ExtSet { extensions, rest } => self.print_parens(|this| { + Term::ExtSet { .. } => self.print_parens(|this| { this.print_text("ext"); - for extension in *extensions { - this.print_text(*extension); - } - if let Some(rest) = rest { - this.print_text("."); - this.print_term(*rest)?; - } + this.print_ext_set_parts(term_id)?; Ok(()) }), Term::ExtSetType => { @@ -595,6 +580,54 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { } } + /// Prints the contents of a list. + /// + /// This is used so that spliced lists are merged into the parent list. + fn print_list_parts(&mut self, term_id: TermId) -> PrintResult<()> { + let term_data = self + .module + .get_term(term_id) + .ok_or_else(|| PrintError::TermNotFound(term_id))?; + + if let Term::List { parts } = term_data { + for part in *parts { + match part { + ListPart::Item(term) => self.print_term(*term)?, + ListPart::Splice(list) => self.print_list_parts(*list)?, + } + } + } else { + self.print_term(term_id)?; + self.print_text("..."); + } + + Ok(()) + } + + /// Prints the contents of an extension set. + /// + /// This is used so that spliced extension sets are merged into the parent extension set. + fn print_ext_set_parts(&mut self, term_id: TermId) -> PrintResult<()> { + let term_data = self + .module + .get_term(term_id) + .ok_or_else(|| PrintError::TermNotFound(term_id))?; + + if let Term::ExtSet { parts } = term_data { + for part in *parts { + match part { + ExtSetPart::Extension(ext) => self.print_text(*ext), + ExtSetPart::Splice(list) => self.print_ext_set_parts(*list)?, + } + } + } else { + self.print_term(term_id)?; + self.print_text("..."); + } + + Ok(()) + } + fn print_local_ref(&mut self, local_ref: LocalRef<'a>) -> PrintResult<()> { let name = match local_ref { LocalRef::Index(_, i) => { diff --git a/hugr-model/tests/binary.rs b/hugr-model/tests/binary.rs index 17609c9e4..80157c23e 100644 --- a/hugr-model/tests/binary.rs +++ b/hugr-model/tests/binary.rs @@ -58,3 +58,8 @@ pub fn test_decl_exts() { pub fn test_constraints() { binary_roundtrip(include_str!("fixtures/model-constraints.edn")); } + +#[test] +pub fn test_lists() { + binary_roundtrip(include_str!("fixtures/model-lists.edn")); +} diff --git a/hugr-model/tests/fixtures/model-call.edn b/hugr-model/tests/fixtures/model-call.edn index ce849a772..87c6f7a3a 100644 --- a/hugr-model/tests/fixtures/model-call.edn +++ b/hugr-model/tests/fixtures/model-call.edn @@ -2,7 +2,7 @@ (declare-func example.callee (forall ?ext ext-set) - [(@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext arithmetic.int . ?ext) + [(@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext arithmetic.int ?ext ...) (meta doc.title (prelude.json "\"Callee\"")) (meta doc.description (prelude.json "\"This is a function declaration.\""))) diff --git a/hugr-model/tests/fixtures/model-lists.edn b/hugr-model/tests/fixtures/model-lists.edn new file mode 100644 index 000000000..1385a0e2a --- /dev/null +++ b/hugr-model/tests/fixtures/model-lists.edn @@ -0,0 +1,21 @@ +(hugr 0) + +(declare-operation core.call-indirect + (forall ?inputs (list type)) + (forall ?outputs (list type)) + (forall ?exts ext-set) + (fn [(fn ?inputs ?outputs ?exts) ?inputs ...] ?outputs ?exts)) + +(declare-operation core.compose-parallel + (forall ?inputs-0 (list type)) + (forall ?inputs-1 (list type)) + (forall ?outputs-0 (list type)) + (forall ?outputs-1 (list type)) + (forall ?exts ext-set) + (fn + [(fn ?inputs-0 ?outputs-0 ?exts) + (fn ?inputs-1 ?outputs-1 ?exts) + ?inputs-0 ... + ?inputs-1 ...] + [?outputs-0 ... ?outputs-1 ...] + ?exts)) From 517fd3da00d4a2427fdb2f0d3ffdfcae25cbfd94 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Borgna?= <121866228+aborgna-q@users.noreply.github.com> Date: Wed, 27 Nov 2024 11:12:46 +0000 Subject: [PATCH 2/3] feat!: OpDefs and TypeDefs keep a reference to their extension (#1719) This change was extracted from the work towards #1613. Now `OpDef`s and `TypeDef`s keep a `Weak` reference to their extension's `Arc`. This way we will be able to automatically set the extension requirements when adding operations, so we can get rid of `update_validate` and the explicit registries when building hugrs. To implement this, the building interface for `Extension`s is sightly modified. Once an `Arc` is built it cannot be modified without doing internal mutation. But we need the `Arc`'s weak reference to define the ops and types. Thankfully, we can use `Arc::new_cyclic` which provides us with a `Weak` ref at build time so we are able to define things as needed. This is wrapped in a new `Extension::new_arc` method, so the user doesn't need to think about that. BREAKING CHANGE: Renamed `OpDef::extension` and `TypeDef::extension` to `extension_id`. `extension` now returns weak references to the `Extension` defining them. BREAKING CHANGE: `Extension::with_reqs` moved to `set_reqs`, which takes `&mut self` instead of `self`. BREAKING CHANGE: `Extension::add_type` and `Extension::add_op` now take an extra parameter. See docs for example usage. BREAKING CHANGE: `ExtensionRegistry::register_updated` and `register_updated_ref` are no longer fallible. --- hugr-cli/src/validate.rs | 4 +- hugr-core/src/builder/circuit.rs | 16 +- hugr-core/src/export.rs | 6 +- hugr-core/src/extension.rs | 153 +++++++-- hugr-core/src/extension/declarative.rs | 32 +- hugr-core/src/extension/declarative/ops.rs | 12 +- hugr-core/src/extension/declarative/types.rs | 7 + hugr-core/src/extension/op_def.rs | 305 ++++++++++-------- hugr-core/src/extension/prelude.rs | 149 ++++----- hugr-core/src/extension/prelude/array.rs | 11 +- hugr-core/src/extension/simple_op.rs | 36 ++- hugr-core/src/extension/type_def.rs | 47 ++- hugr-core/src/hugr/rewrite/replace.rs | 36 ++- hugr-core/src/hugr/validate/test.rs | 128 ++++---- hugr-core/src/ops/custom.rs | 42 +-- hugr-core/src/package.rs | 2 +- .../std_extensions/arithmetic/conversions.rs | 23 +- .../std_extensions/arithmetic/float_ops.rs | 15 +- .../std_extensions/arithmetic/float_types.rs | 23 +- .../src/std_extensions/arithmetic/int_ops.rs | 15 +- .../std_extensions/arithmetic/int_types.rs | 23 +- hugr-core/src/std_extensions/collections.rs | 39 +-- hugr-core/src/std_extensions/logic.rs | 20 +- hugr-core/src/std_extensions/ptr.rs | 25 +- hugr-core/src/types/poly_func.rs | 20 +- hugr-core/src/utils.rs | 96 +++--- hugr-llvm/src/custom/extension_op.rs | 2 +- hugr-passes/src/merge_bbs.rs | 42 +-- hugr/benches/benchmarks/hugr/examples.rs | 56 ++-- hugr/src/lib.rs | 20 +- 30 files changed, 837 insertions(+), 568 deletions(-) diff --git a/hugr-cli/src/validate.rs b/hugr-cli/src/validate.rs index 996799b77..c2db97539 100644 --- a/hugr-cli/src/validate.rs +++ b/hugr-cli/src/validate.rs @@ -2,7 +2,6 @@ use clap::Parser; use clap_verbosity_flag::Level; -use hugr::package::PackageValidationError; use hugr::{extension::ExtensionRegistry, Extension, Hugr}; use crate::{CliError, HugrArgs}; @@ -64,8 +63,7 @@ impl HugrArgs { for ext in &self.extensions { let f = std::fs::File::open(ext)?; let ext: Extension = serde_json::from_reader(f)?; - reg.register_updated(ext) - .map_err(PackageValidationError::Extension)?; + reg.register_updated(ext); } package.update_validate(&mut reg)?; diff --git a/hugr-core/src/builder/circuit.rs b/hugr-core/src/builder/circuit.rs index 112cb83fb..43c106209 100644 --- a/hugr-core/src/builder/circuit.rs +++ b/hugr-core/src/builder/circuit.rs @@ -243,7 +243,7 @@ mod test { use super::*; use cool_asserts::assert_matches; - use crate::extension::{ExtensionId, ExtensionSet}; + use crate::extension::{ExtensionId, ExtensionSet, PRELUDE_REGISTRY}; use crate::std_extensions::arithmetic::float_types::{self, ConstF64}; use crate::utils::test_quantum_extension::{ self, cx_gate, h_gate, measure, q_alloc, q_discard, rz_f64, @@ -298,8 +298,18 @@ mod test { #[test] fn with_nonlinear_and_outputs() { let my_ext_name: ExtensionId = "MyExt".try_into().unwrap(); - let mut my_ext = Extension::new_test(my_ext_name.clone()); - let my_custom_op = my_ext.simple_ext_op("MyOp", Signature::new(vec![QB, NAT], vec![QB])); + let my_ext = Extension::new_test_arc(my_ext_name.clone(), |ext, extension_ref| { + ext.add_op( + "MyOp".into(), + "".to_string(), + Signature::new(vec![QB, NAT], vec![QB]), + extension_ref, + ) + .unwrap(); + }); + let my_custom_op = my_ext + .instantiate_extension_op("MyOp", [], &PRELUDE_REGISTRY) + .unwrap(); let build_res = build_main( Signature::new(type_row![QB, QB, NAT], type_row![QB, QB, BOOL_T]) diff --git a/hugr-core/src/export.rs b/hugr-core/src/export.rs index f433a3482..b390362a6 100644 --- a/hugr-core/src/export.rs +++ b/hugr-core/src/export.rs @@ -443,10 +443,10 @@ impl<'a> Context<'a> { let poly_func_type = match opdef.signature_func() { SignatureFunc::PolyFuncType(poly_func_type) => poly_func_type, - _ => return self.make_named_global_ref(opdef.extension(), opdef.name()), + _ => return self.make_named_global_ref(opdef.extension_id(), opdef.name()), }; - let key = (opdef.extension().clone(), opdef.name().clone()); + let key = (opdef.extension_id().clone(), opdef.name().clone()); let entry = self.decl_operations.entry(key); let node = match entry { @@ -467,7 +467,7 @@ impl<'a> Context<'a> { }; let decl = self.with_local_scope(node, |this| { - let name = this.make_qualified_name(opdef.extension(), opdef.name()); + let name = this.make_qualified_name(opdef.extension_id(), opdef.name()); let (params, constraints, r#type) = this.export_poly_func_type(poly_func_type); let decl = this.bump.alloc(model::OperationDecl { name, diff --git a/hugr-core/src/extension.rs b/hugr-core/src/extension.rs index 4d22ba7f0..6d30d635e 100644 --- a/hugr-core/src/extension.rs +++ b/hugr-core/src/extension.rs @@ -7,7 +7,8 @@ pub use semver::Version; use std::collections::btree_map; use std::collections::{BTreeMap, BTreeSet}; use std::fmt::{Debug, Display, Formatter}; -use std::sync::Arc; +use std::mem; +use std::sync::{Arc, Weak}; use thiserror::Error; @@ -103,10 +104,7 @@ impl ExtensionRegistry { /// /// Takes an Arc to the extension. To avoid cloning Arcs unless necessary, see /// [`ExtensionRegistry::register_updated_ref`]. - pub fn register_updated( - &mut self, - extension: impl Into>, - ) -> Result<(), ExtensionRegistryError> { + pub fn register_updated(&mut self, extension: impl Into>) { let extension = extension.into(); match self.0.entry(extension.name().clone()) { btree_map::Entry::Occupied(mut prev) => { @@ -118,7 +116,6 @@ impl ExtensionRegistry { ve.insert(extension); } } - Ok(()) } /// Registers a new extension to the registry, keeping most up to date if @@ -130,10 +127,7 @@ impl ExtensionRegistry { /// /// Clones the Arc only when required. For no-cloning version see /// [`ExtensionRegistry::register_updated`]. - pub fn register_updated_ref( - &mut self, - extension: &Arc, - ) -> Result<(), ExtensionRegistryError> { + pub fn register_updated_ref(&mut self, extension: &Arc) { match self.0.entry(extension.name().clone()) { btree_map::Entry::Occupied(mut prev) => { if prev.get().version() < extension.version() { @@ -144,7 +138,6 @@ impl ExtensionRegistry { ve.insert(extension.clone()); } } - Ok(()) } /// Returns the number of extensions in the registry. @@ -335,6 +328,45 @@ impl ExtensionValue { pub type ExtensionId = IdentList; /// A extension is a set of capabilities required to execute a graph. +/// +/// These are normally defined once and shared across multiple graphs and +/// operations wrapped in [`Arc`]s inside [`ExtensionRegistry`]. +/// +/// # Example +/// +/// The following example demonstrates how to define a new extension with a +/// custom operation and a custom type. +/// +/// When using `arc`s, the extension can only be modified at creation time. The +/// defined operations and types keep a [`Weak`] reference to their extension. We provide a +/// helper method [`Extension::new_arc`] to aid their definition. +/// +/// ``` +/// # use hugr_core::types::Signature; +/// # use hugr_core::extension::{Extension, ExtensionId, Version}; +/// # use hugr_core::extension::{TypeDefBound}; +/// Extension::new_arc( +/// ExtensionId::new_unchecked("my.extension"), +/// Version::new(0, 1, 0), +/// |ext, extension_ref| { +/// // Add a custom type definition +/// ext.add_type( +/// "MyType".into(), +/// vec![], // No type parameters +/// "Some type".into(), +/// TypeDefBound::any(), +/// extension_ref, +/// ); +/// // Add a custom operation +/// ext.add_op( +/// "MyOp".into(), +/// "Some operation".into(), +/// Signature::new_endo(vec![]), +/// extension_ref, +/// ); +/// }, +/// ); +/// ``` #[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] pub struct Extension { /// Extension version, follows semver. @@ -361,6 +393,12 @@ pub struct Extension { impl Extension { /// Creates a new extension with the given name. + /// + /// In most cases extensions are contained inside an [`Arc`] so that they + /// can be shared across hugr instances and operation definitions. + /// + /// See [`Extension::new_arc`] for a more ergonomic way to create boxed + /// extensions. pub fn new(name: ExtensionId, version: Version) -> Self { Self { name, @@ -372,14 +410,63 @@ impl Extension { } } - /// Extend the requirements of this extension with another set of extensions. - pub fn with_reqs(self, extension_reqs: impl Into) -> Self { - Self { - extension_reqs: self.extension_reqs.union(extension_reqs.into()), - ..self + /// Creates a new extension wrapped in an [`Arc`]. + /// + /// The closure lets us use a weak reference to the arc while the extension + /// is being built. This is necessary for calling [`Extension::add_op`] and + /// [`Extension::add_type`]. + pub fn new_arc( + name: ExtensionId, + version: Version, + init: impl FnOnce(&mut Extension, &Weak), + ) -> Arc { + Arc::new_cyclic(|extension_ref| { + let mut ext = Self::new(name, version); + init(&mut ext, extension_ref); + ext + }) + } + + /// Creates a new extension wrapped in an [`Arc`], using a fallible + /// initialization function. + /// + /// The closure lets us use a weak reference to the arc while the extension + /// is being built. This is necessary for calling [`Extension::add_op`] and + /// [`Extension::add_type`]. + pub fn try_new_arc( + name: ExtensionId, + version: Version, + init: impl FnOnce(&mut Extension, &Weak) -> Result<(), E>, + ) -> Result, E> { + // Annoying hack around not having `Arc::try_new_cyclic` that can return + // a Result. + // https://github.com/rust-lang/rust/issues/75861#issuecomment-980455381 + // + // When there is an error, we store it in `error` and return it at the + // end instead of the partially-initialized extension. + let mut error = None; + let ext = Arc::new_cyclic(|extension_ref| { + let mut ext = Self::new(name, version); + match init(&mut ext, extension_ref) { + Ok(_) => ext, + Err(e) => { + error = Some(e); + ext + } + } + }); + match error { + Some(e) => Err(e), + None => Ok(ext), } } + /// Extend the requirements of this extension with another set of extensions. + pub fn add_requirements(&mut self, extension_reqs: impl Into) { + let reqs = mem::take(&mut self.extension_reqs); + self.extension_reqs = reqs.union(extension_reqs.into()); + } + /// Allows read-only access to the operations in this Extension pub fn get_op(&self, name: &OpNameRef) -> Option<&Arc> { self.operations.get(name) @@ -634,20 +721,22 @@ pub mod test { impl Extension { /// Create a new extension for testing, with a 0 version. - pub(crate) fn new_test(name: ExtensionId) -> Self { - Self::new(name, Version::new(0, 0, 0)) + pub(crate) fn new_test_arc( + name: ExtensionId, + init: impl FnOnce(&mut Extension, &Weak), + ) -> Arc { + Self::new_arc(name, Version::new(0, 0, 0), init) } - /// Add a simple OpDef to the extension and return an extension op for it. - /// No description, no type parameters. - pub(crate) fn simple_ext_op( - &mut self, - name: &str, - signature: impl Into, - ) -> ExtensionOp { - self.add_op(name.into(), "".to_string(), signature).unwrap(); - self.instantiate_extension_op(name, [], &PRELUDE_REGISTRY) - .unwrap() + /// Create a new extension for testing, with a 0 version. + pub(crate) fn try_new_test_arc( + name: ExtensionId, + init: impl FnOnce( + &mut Extension, + &Weak, + ) -> Result<(), Box>, + ) -> Result, Box> { + Self::try_new_arc(name, Version::new(0, 0, 0), init) } } @@ -680,14 +769,14 @@ pub mod test { ); // register with update works - reg_ref.register_updated_ref(&ext1_1).unwrap(); - reg.register_updated(ext1_1.clone()).unwrap(); + reg_ref.register_updated_ref(&ext1_1); + reg.register_updated(ext1_1.clone()); assert_eq!(reg.get("ext1").unwrap().version(), &Version::new(1, 1, 0)); assert_eq!(®, ®_ref); // register with lower version does not change version - reg_ref.register_updated_ref(&ext1_2).unwrap(); - reg.register_updated(ext1_2.clone()).unwrap(); + reg_ref.register_updated_ref(&ext1_2); + reg.register_updated(ext1_2.clone()); assert_eq!(reg.get("ext1").unwrap().version(), &Version::new(1, 1, 0)); assert_eq!(®, ®_ref); diff --git a/hugr-core/src/extension/declarative.rs b/hugr-core/src/extension/declarative.rs index c81414c9f..94d557895 100644 --- a/hugr-core/src/extension/declarative.rs +++ b/hugr-core/src/extension/declarative.rs @@ -29,6 +29,7 @@ mod types; use std::fs::File; use std::path::Path; +use std::sync::Arc; use crate::extension::prelude::PRELUDE_ID; use crate::ops::OpName; @@ -150,19 +151,24 @@ impl ExtensionDeclaration { &self, imports: &ExtensionSet, ctx: DeclarationContext<'_>, - ) -> Result { - let mut ext = Extension::new(self.name.clone(), crate::extension::Version::new(0, 0, 0)) - .with_reqs(imports.clone()); - - for t in &self.types { - t.register(&mut ext, ctx)?; - } - - for o in &self.operations { - o.register(&mut ext, ctx)?; - } - - Ok(ext) + ) -> Result, ExtensionDeclarationError> { + Extension::try_new_arc( + self.name.clone(), + // TODO: Get the version as a parameter. + crate::extension::Version::new(0, 0, 0), + |ext, extension_ref| { + for t in &self.types { + t.register(ext, ctx, extension_ref)?; + } + + for o in &self.operations { + o.register(ext, ctx, extension_ref)?; + } + ext.add_requirements(imports.clone()); + + Ok(()) + }, + ) } } diff --git a/hugr-core/src/extension/declarative/ops.rs b/hugr-core/src/extension/declarative/ops.rs index 8bd769e10..39e688a6b 100644 --- a/hugr-core/src/extension/declarative/ops.rs +++ b/hugr-core/src/extension/declarative/ops.rs @@ -8,6 +8,7 @@ //! [`ExtensionSetDeclaration`]: super::ExtensionSetDeclaration use std::collections::HashMap; +use std::sync::Weak; use serde::{Deserialize, Serialize}; use smol_str::SmolStr; @@ -55,10 +56,14 @@ pub(super) struct OperationDeclaration { impl OperationDeclaration { /// Register this operation in the given extension. + /// + /// Requires a [`Weak`] reference to the extension defining the operation. + /// This method is intended to be used inside the closure passed to [`Extension::new_arc`]. pub fn register<'ext>( &self, ext: &'ext mut Extension, ctx: DeclarationContext<'_>, + extension_ref: &Weak, ) -> Result<&'ext mut OpDef, ExtensionDeclarationError> { // We currently only support explicit signatures. // @@ -88,7 +93,12 @@ impl OperationDeclaration { let signature_func: SignatureFunc = signature.make_signature(ext, ctx, ¶ms)?; - let op_def = ext.add_op(self.name.clone(), self.description.clone(), signature_func)?; + let op_def = ext.add_op( + self.name.clone(), + self.description.clone(), + signature_func, + extension_ref, + )?; for (k, v) in &self.misc { op_def.add_misc(k, v.clone()); diff --git a/hugr-core/src/extension/declarative/types.rs b/hugr-core/src/extension/declarative/types.rs index 10b6e41a0..e426c69f2 100644 --- a/hugr-core/src/extension/declarative/types.rs +++ b/hugr-core/src/extension/declarative/types.rs @@ -7,6 +7,8 @@ //! [specification]: https://github.com/CQCL/hugr/blob/main/specification/hugr.md#declarative-format //! [`ExtensionSetDeclaration`]: super::ExtensionSetDeclaration +use std::sync::Weak; + use crate::extension::{TypeDef, TypeDefBound}; use crate::types::type_param::TypeParam; use crate::types::{TypeBound, TypeName}; @@ -49,10 +51,14 @@ impl TypeDeclaration { /// /// Types in the definition will be resolved using the extensions in `scope` /// and the current extension. + /// + /// Requires a [`Weak`] reference to the extension defining the operation. + /// This method is intended to be used inside the closure passed to [`Extension::new_arc`]. pub fn register<'ext>( &self, ext: &'ext mut Extension, ctx: DeclarationContext<'_>, + extension_ref: &Weak, ) -> Result<&'ext TypeDef, ExtensionDeclarationError> { let params = self .params @@ -64,6 +70,7 @@ impl TypeDeclaration { params, self.description.clone(), self.bound.into(), + extension_ref, )?; Ok(type_def) } diff --git a/hugr-core/src/extension/op_def.rs b/hugr-core/src/extension/op_def.rs index 5e74b9e9c..6f33cf3ef 100644 --- a/hugr-core/src/extension/op_def.rs +++ b/hugr-core/src/extension/op_def.rs @@ -2,7 +2,7 @@ use std::cmp::min; use std::collections::btree_map::Entry; use std::collections::HashMap; use std::fmt::{Debug, Formatter}; -use std::sync::Arc; +use std::sync::{Arc, Weak}; use super::{ ConstFold, ConstFoldResult, Extension, ExtensionBuildError, ExtensionId, ExtensionRegistry, @@ -302,6 +302,9 @@ impl Debug for LowerFunc { pub struct OpDef { /// The unique Extension owning this OpDef (of which this OpDef is a member) extension: ExtensionId, + /// A weak reference to the extension defining this operation. + #[serde(skip)] + extension_ref: Weak, /// Unique identifier of the operation. Used to look up OpDefs in the registry /// when deserializing nodes (which store only the name). name: OpName, @@ -394,11 +397,16 @@ impl OpDef { &self.name } - /// Returns a reference to the extension of this [`OpDef`]. - pub fn extension(&self) -> &ExtensionId { + /// Returns a reference to the extension id of this [`OpDef`]. + pub fn extension_id(&self) -> &ExtensionId { &self.extension } + /// Returns a weak reference to the extension defining this operation. + pub fn extension(&self) -> Weak { + self.extension_ref.clone() + } + /// Returns a reference to the description of this [`OpDef`]. pub fn description(&self) -> &str { self.description.as_ref() @@ -467,15 +475,41 @@ impl Extension { /// Add an operation definition to the extension. Must be a type scheme /// (defined by a [`PolyFuncTypeRV`]), a type scheme along with binary /// validation for type arguments ([`CustomValidator`]), or a custom binary - /// function for computing the signature given type arguments (`impl [CustomSignatureFunc]`). + /// function for computing the signature given type arguments (implementing + /// `[CustomSignatureFunc]`). + /// + /// This method requires a [`Weak`] reference to the [`Arc`] containing the + /// extension being defined. The intended way to call this method is inside + /// the closure passed to [`Extension::new_arc`] when defining the extension. + /// + /// # Example + /// + /// ``` + /// # use hugr_core::types::Signature; + /// # use hugr_core::extension::{Extension, ExtensionId, Version}; + /// Extension::new_arc( + /// ExtensionId::new_unchecked("my.extension"), + /// Version::new(0, 1, 0), + /// |ext, extension_ref| { + /// ext.add_op( + /// "MyOp".into(), + /// "Some operation".into(), + /// Signature::new_endo(vec![]), + /// extension_ref, + /// ); + /// }, + /// ); + /// ``` pub fn add_op( &mut self, name: OpName, description: String, signature_func: impl Into, + extension_ref: &Weak, ) -> Result<&mut OpDef, ExtensionBuildError> { let op = OpDef { extension: self.name.clone(), + extension_ref: extension_ref.clone(), name, description, signature_func: signature_func.into(), @@ -544,6 +578,7 @@ pub(super) mod test { fn eq(&self, other: &Self) -> bool { let OpDef { extension, + extension_ref: _, name, description, misc, @@ -553,6 +588,7 @@ pub(super) mod test { } = &self.0; let OpDef { extension: other_extension, + extension_ref: _, name: other_name, description: other_description, misc: other_misc, @@ -601,25 +637,28 @@ pub(super) mod test { #[test] fn op_def_with_type_scheme() -> Result<(), Box> { let list_def = EXTENSION.get_type(&LIST_TYPENAME).unwrap(); - let mut e = Extension::new_test(EXT_ID); - const TP: TypeParam = TypeParam::Type { b: TypeBound::Any }; - let list_of_var = - Type::new_extension(list_def.instantiate(vec![TypeArg::new_var_use(0, TP)])?); const OP_NAME: OpName = OpName::new_inline("Reverse"); - let type_scheme = PolyFuncTypeRV::new(vec![TP], Signature::new_endo(vec![list_of_var])); - - let def = e.add_op(OP_NAME, "desc".into(), type_scheme)?; - def.add_lower_func(LowerFunc::FixedHugr { - extensions: ExtensionSet::new(), - hugr: crate::builder::test::simple_dfg_hugr(), // this is nonsense, but we are not testing the actual lowering here - }); - def.add_misc("key", Default::default()); - assert_eq!(def.description(), "desc"); - assert_eq!(def.lower_funcs.len(), 1); - assert_eq!(def.misc.len(), 1); - - let reg = - ExtensionRegistry::try_new([PRELUDE.clone(), EXTENSION.clone(), e.into()]).unwrap(); + + let ext = Extension::try_new_test_arc(EXT_ID, |ext, extension_ref| { + const TP: TypeParam = TypeParam::Type { b: TypeBound::Any }; + let list_of_var = + Type::new_extension(list_def.instantiate(vec![TypeArg::new_var_use(0, TP)])?); + let type_scheme = PolyFuncTypeRV::new(vec![TP], Signature::new_endo(vec![list_of_var])); + + let def = ext.add_op(OP_NAME, "desc".into(), type_scheme, extension_ref)?; + def.add_lower_func(LowerFunc::FixedHugr { + extensions: ExtensionSet::new(), + hugr: crate::builder::test::simple_dfg_hugr(), // this is nonsense, but we are not testing the actual lowering here + }); + def.add_misc("key", Default::default()); + assert_eq!(def.description(), "desc"); + assert_eq!(def.lower_funcs.len(), 1); + assert_eq!(def.misc.len(), 1); + + Ok(()) + })?; + + let reg = ExtensionRegistry::try_new([PRELUDE.clone(), EXTENSION.clone(), ext]).unwrap(); let e = reg.get(&EXT_ID).unwrap(); let list_usize = @@ -666,60 +705,63 @@ pub(super) mod test { MAX_NAT } } - let mut e = Extension::new_test(EXT_ID); - let def: &mut crate::extension::OpDef = - e.add_op("MyOp".into(), "".to_string(), SigFun())?; - - // Base case, no type variables: - let args = [TypeArg::BoundedNat { n: 3 }, USIZE_T.into()]; - assert_eq!( - def.compute_signature(&args, &PRELUDE_REGISTRY), - Ok( - Signature::new(vec![USIZE_T; 3], vec![Type::new_tuple(vec![USIZE_T; 3])]) - .with_extension_delta(EXT_ID) - ) - ); - assert_eq!(def.validate_args(&args, &PRELUDE_REGISTRY, &[]), Ok(())); - - // Second arg may be a variable (substitutable) - let tyvar = Type::new_var_use(0, TypeBound::Copyable); - let tyvars: Vec = vec![tyvar.clone(); 3]; - let args = [TypeArg::BoundedNat { n: 3 }, tyvar.clone().into()]; - assert_eq!( - def.compute_signature(&args, &PRELUDE_REGISTRY), - Ok( - Signature::new(tyvars.clone(), vec![Type::new_tuple(tyvars)]) - .with_extension_delta(EXT_ID) - ) - ); - def.validate_args(&args, &PRELUDE_REGISTRY, &[TypeBound::Copyable.into()]) - .unwrap(); - - // quick sanity check that we are validating the args - note changed bound: - assert_eq!( - def.validate_args(&args, &PRELUDE_REGISTRY, &[TypeBound::Any.into()]), - Err(SignatureError::TypeVarDoesNotMatchDeclaration { - actual: TypeBound::Any.into(), - cached: TypeBound::Copyable.into() - }) - ); - - // First arg must be concrete, not a variable - let kind = TypeParam::bounded_nat(NonZeroU64::new(5).unwrap()); - let args = [TypeArg::new_var_use(0, kind.clone()), USIZE_T.into()]; - // We can't prevent this from getting into our compute_signature implementation: - assert_eq!( - def.compute_signature(&args, &PRELUDE_REGISTRY), - Err(SignatureError::InvalidTypeArgs) - ); - // But validation rules it out, even when the variable is declared: - assert_eq!( - def.validate_args(&args, &PRELUDE_REGISTRY, &[kind]), - Err(SignatureError::FreeTypeVar { - idx: 0, - num_decls: 0 - }) - ); + let _ext = Extension::try_new_test_arc(EXT_ID, |ext, extension_ref| { + let def: &mut crate::extension::OpDef = + ext.add_op("MyOp".into(), "".to_string(), SigFun(), extension_ref)?; + + // Base case, no type variables: + let args = [TypeArg::BoundedNat { n: 3 }, USIZE_T.into()]; + assert_eq!( + def.compute_signature(&args, &PRELUDE_REGISTRY), + Ok( + Signature::new(vec![USIZE_T; 3], vec![Type::new_tuple(vec![USIZE_T; 3])]) + .with_extension_delta(EXT_ID) + ) + ); + assert_eq!(def.validate_args(&args, &PRELUDE_REGISTRY, &[]), Ok(())); + + // Second arg may be a variable (substitutable) + let tyvar = Type::new_var_use(0, TypeBound::Copyable); + let tyvars: Vec = vec![tyvar.clone(); 3]; + let args = [TypeArg::BoundedNat { n: 3 }, tyvar.clone().into()]; + assert_eq!( + def.compute_signature(&args, &PRELUDE_REGISTRY), + Ok( + Signature::new(tyvars.clone(), vec![Type::new_tuple(tyvars)]) + .with_extension_delta(EXT_ID) + ) + ); + def.validate_args(&args, &PRELUDE_REGISTRY, &[TypeBound::Copyable.into()]) + .unwrap(); + + // quick sanity check that we are validating the args - note changed bound: + assert_eq!( + def.validate_args(&args, &PRELUDE_REGISTRY, &[TypeBound::Any.into()]), + Err(SignatureError::TypeVarDoesNotMatchDeclaration { + actual: TypeBound::Any.into(), + cached: TypeBound::Copyable.into() + }) + ); + + // First arg must be concrete, not a variable + let kind = TypeParam::bounded_nat(NonZeroU64::new(5).unwrap()); + let args = [TypeArg::new_var_use(0, kind.clone()), USIZE_T.into()]; + // We can't prevent this from getting into our compute_signature implementation: + assert_eq!( + def.compute_signature(&args, &PRELUDE_REGISTRY), + Err(SignatureError::InvalidTypeArgs) + ); + // But validation rules it out, even when the variable is declared: + assert_eq!( + def.validate_args(&args, &PRELUDE_REGISTRY, &[kind]), + Err(SignatureError::FreeTypeVar { + idx: 0, + num_decls: 0 + }) + ); + + Ok(()) + })?; Ok(()) } @@ -728,34 +770,37 @@ pub(super) mod test { fn type_scheme_instantiate_var() -> Result<(), Box> { // Check that we can instantiate a PolyFuncTypeRV-scheme with an (external) // type variable - let mut e = Extension::new_test(EXT_ID); - let def = e.add_op( - "SimpleOp".into(), - "".into(), - PolyFuncTypeRV::new( - vec![TypeBound::Any.into()], - Signature::new_endo(vec![Type::new_var_use(0, TypeBound::Any)]), - ), - )?; - let tv = Type::new_var_use(1, TypeBound::Copyable); - let args = [TypeArg::Type { ty: tv.clone() }]; - let decls = [TypeParam::Extensions, TypeBound::Copyable.into()]; - def.validate_args(&args, &EMPTY_REG, &decls).unwrap(); - assert_eq!( - def.compute_signature(&args, &EMPTY_REG), - Ok(Signature::new_endo(tv).with_extension_delta(EXT_ID)) - ); - // But not with an external row variable - let arg: TypeArg = TypeRV::new_row_var_use(0, TypeBound::Copyable).into(); - assert_eq!( - def.compute_signature(&[arg.clone()], &EMPTY_REG), - Err(SignatureError::TypeArgMismatch( - TypeArgError::TypeMismatch { - param: TypeBound::Any.into(), - arg - } - )) - ); + let _ext = Extension::try_new_test_arc(EXT_ID, |ext, extension_ref| { + let def = ext.add_op( + "SimpleOp".into(), + "".into(), + PolyFuncTypeRV::new( + vec![TypeBound::Any.into()], + Signature::new_endo(vec![Type::new_var_use(0, TypeBound::Any)]), + ), + extension_ref, + )?; + let tv = Type::new_var_use(1, TypeBound::Copyable); + let args = [TypeArg::Type { ty: tv.clone() }]; + let decls = [TypeParam::Extensions, TypeBound::Copyable.into()]; + def.validate_args(&args, &EMPTY_REG, &decls).unwrap(); + assert_eq!( + def.compute_signature(&args, &EMPTY_REG), + Ok(Signature::new_endo(tv).with_extension_delta(EXT_ID)) + ); + // But not with an external row variable + let arg: TypeArg = TypeRV::new_row_var_use(0, TypeBound::Copyable).into(); + assert_eq!( + def.compute_signature(&[arg.clone()], &EMPTY_REG), + Err(SignatureError::TypeArgMismatch( + TypeArgError::TypeMismatch { + param: TypeBound::Any.into(), + arg + } + )) + ); + Ok(()) + })?; Ok(()) } @@ -763,33 +808,39 @@ pub(super) mod test { fn instantiate_extension_delta() -> Result<(), Box> { use crate::extension::prelude::{BOOL_T, PRELUDE_REGISTRY}; - let mut e = Extension::new_test(EXT_ID); - - let params: Vec = vec![TypeParam::Extensions]; - let db_set = ExtensionSet::type_var(0); - let fun_ty = Signature::new_endo(BOOL_T).with_extension_delta(db_set); + let _ext = Extension::try_new_test_arc(EXT_ID, |ext, extension_ref| { + let params: Vec = vec![TypeParam::Extensions]; + let db_set = ExtensionSet::type_var(0); + let fun_ty = Signature::new_endo(BOOL_T).with_extension_delta(db_set); + + let def = ext.add_op( + "SimpleOp".into(), + "".into(), + PolyFuncTypeRV::new(params.clone(), fun_ty), + extension_ref, + )?; + + // Concrete extension set + let es = ExtensionSet::singleton(&EXT_ID); + let exp_fun_ty = Signature::new_endo(BOOL_T).with_extension_delta(es.clone()); + let args = [TypeArg::Extensions { es }]; + + def.validate_args(&args, &PRELUDE_REGISTRY, ¶ms) + .unwrap(); + assert_eq!( + def.compute_signature(&args, &PRELUDE_REGISTRY), + Ok(exp_fun_ty) + ); + + Ok(()) + })?; - let def = e.add_op( - "SimpleOp".into(), - "".into(), - PolyFuncTypeRV::new(params.clone(), fun_ty), - )?; - - // Concrete extension set - let es = ExtensionSet::singleton(&EXT_ID); - let exp_fun_ty = Signature::new_endo(BOOL_T).with_extension_delta(es.clone()); - let args = [TypeArg::Extensions { es }]; - - def.validate_args(&args, &PRELUDE_REGISTRY, ¶ms) - .unwrap(); - assert_eq!( - def.compute_signature(&args, &PRELUDE_REGISTRY), - Ok(exp_fun_ty) - ); Ok(()) } mod proptest { + use std::sync::Weak; + use super::SimpleOpDef; use ::proptest::prelude::*; @@ -846,6 +897,8 @@ pub(super) mod test { |(extension, name, description, misc, signature_func, lower_funcs)| { Self::new(OpDef { extension, + // Use a dead weak reference. Trying to access the extension will always return None. + extension_ref: Weak::default(), name, description, misc, diff --git a/hugr-core/src/extension/prelude.rs b/hugr-core/src/extension/prelude.rs index e7192f6ee..691a9cfb6 100644 --- a/hugr-core/src/extension/prelude.rs +++ b/hugr-core/src/extension/prelude.rs @@ -41,75 +41,80 @@ pub const PRELUDE_ID: ExtensionId = ExtensionId::new_unchecked("prelude"); pub const VERSION: semver::Version = semver::Version::new(0, 1, 0); lazy_static! { static ref PRELUDE_DEF: Arc = { - let mut prelude = Extension::new(PRELUDE_ID, VERSION); - prelude - .add_type( - TypeName::new_inline("usize"), - vec![], - "usize".into(), - TypeDefBound::copyable(), - ) - .unwrap(); - prelude.add_type( - STRING_TYPE_NAME, - vec![], - "string".into(), - TypeDefBound::copyable(), - ) - .unwrap(); - prelude.add_op( - PRINT_OP_ID, - "Print the string to standard output".to_string(), - Signature::new(type_row![STRING_TYPE], type_row![]), - ) - .unwrap(); - prelude.add_type( - TypeName::new_inline(ARRAY_TYPE_NAME), - vec![ TypeParam::max_nat(), TypeBound::Any.into()], - "array".into(), - TypeDefBound::from_params(vec![1] ), - ) - .unwrap(); - - prelude - .add_type( - TypeName::new_inline("qubit"), - vec![], - "qubit".into(), - TypeDefBound::any(), - ) - .unwrap(); - prelude - .add_type( - ERROR_TYPE_NAME, - vec![], - "Simple opaque error type.".into(), - TypeDefBound::copyable(), - ) - .unwrap(); - - - prelude - .add_op( - PANIC_OP_ID, - "Panic with input error".to_string(), - PolyFuncTypeRV::new( - [TypeParam::new_list(TypeBound::Any), TypeParam::new_list(TypeBound::Any)], - FuncValueType::new( - vec![TypeRV::new_extension(ERROR_CUSTOM_TYPE), TypeRV::new_row_var_use(0, TypeBound::Any)], - vec![TypeRV::new_row_var_use(1, TypeBound::Any)], - ), - ), - ) - .unwrap(); - - TupleOpDef::load_all_ops(&mut prelude).unwrap(); - NoopDef.add_to_extension(&mut prelude).unwrap(); - LiftDef.add_to_extension(&mut prelude).unwrap(); - array::ArrayOpDef::load_all_ops(&mut prelude).unwrap(); - array::ArrayScanDef.add_to_extension(&mut prelude).unwrap(); - - Arc::new(prelude) + Extension::new_arc(PRELUDE_ID, VERSION, |prelude, extension_ref| { + prelude + .add_type( + TypeName::new_inline("usize"), + vec![], + "usize".into(), + TypeDefBound::copyable(), + extension_ref, + ) + .unwrap(); + prelude.add_type( + STRING_TYPE_NAME, + vec![], + "string".into(), + TypeDefBound::copyable(), + extension_ref, + ) + .unwrap(); + prelude.add_op( + PRINT_OP_ID, + "Print the string to standard output".to_string(), + Signature::new(type_row![STRING_TYPE], type_row![]), + extension_ref, + ) + .unwrap(); + prelude.add_type( + TypeName::new_inline(ARRAY_TYPE_NAME), + vec![ TypeParam::max_nat(), TypeBound::Any.into()], + "array".into(), + TypeDefBound::from_params(vec![1] ), + extension_ref, + ) + .unwrap(); + + prelude + .add_type( + TypeName::new_inline("qubit"), + vec![], + "qubit".into(), + TypeDefBound::any(), + extension_ref, + ) + .unwrap(); + prelude + .add_type( + ERROR_TYPE_NAME, + vec![], + "Simple opaque error type.".into(), + TypeDefBound::copyable(), + extension_ref, + ) + .unwrap(); + + prelude + .add_op( + PANIC_OP_ID, + "Panic with input error".to_string(), + PolyFuncTypeRV::new( + [TypeParam::new_list(TypeBound::Any), TypeParam::new_list(TypeBound::Any)], + FuncValueType::new( + vec![TypeRV::new_extension(ERROR_CUSTOM_TYPE), TypeRV::new_row_var_use(0, TypeBound::Any)], + vec![TypeRV::new_row_var_use(1, TypeBound::Any)], + ), + ), + extension_ref, + ) + .unwrap(); + + TupleOpDef::load_all_ops(prelude, extension_ref).unwrap(); + NoopDef.add_to_extension(prelude, extension_ref).unwrap(); + LiftDef.add_to_extension(prelude, extension_ref).unwrap(); + array::ArrayOpDef::load_all_ops(prelude, extension_ref).unwrap(); + array::ArrayScanDef.add_to_extension(prelude, extension_ref).unwrap(); + }) }; /// An extension registry containing only the prelude pub static ref PRELUDE_REGISTRY: ExtensionRegistry = @@ -528,7 +533,7 @@ impl MakeOpDef for TupleOpDef { } fn from_def(op_def: &OpDef) -> Result { - try_from_name(op_def.name(), op_def.extension()) + try_from_name(op_def.name(), op_def.extension_id()) } fn extension(&self) -> ExtensionId { @@ -695,7 +700,7 @@ impl MakeOpDef for NoopDef { } fn from_def(op_def: &OpDef) -> Result { - try_from_name(op_def.name(), op_def.extension()) + try_from_name(op_def.name(), op_def.extension_id()) } fn extension(&self) -> ExtensionId { @@ -805,7 +810,7 @@ impl MakeOpDef for LiftDef { } fn from_def(op_def: &OpDef) -> Result { - try_from_name(op_def.name(), op_def.extension()) + try_from_name(op_def.name(), op_def.extension_id()) } fn extension(&self) -> ExtensionId { diff --git a/hugr-core/src/extension/prelude/array.rs b/hugr-core/src/extension/prelude/array.rs index 6013039d4..c419a67c7 100644 --- a/hugr-core/src/extension/prelude/array.rs +++ b/hugr-core/src/extension/prelude/array.rs @@ -1,4 +1,5 @@ use std::str::FromStr; +use std::sync::Weak; use itertools::Itertools; use strum_macros::EnumIter; @@ -180,7 +181,7 @@ impl MakeOpDef for ArrayOpDef { where Self: Sized, { - crate::extension::simple_op::try_from_name(op_def.name(), op_def.extension()) + crate::extension::simple_op::try_from_name(op_def.name(), op_def.extension_id()) } fn signature(&self) -> SignatureFunc { @@ -216,9 +217,10 @@ impl MakeOpDef for ArrayOpDef { fn add_to_extension( &self, extension: &mut Extension, + extension_ref: &Weak, ) -> Result<(), crate::extension::ExtensionBuildError> { let sig = self.signature_from_def(extension.get_type(ARRAY_TYPE_NAME).unwrap()); - let def = extension.add_op(self.name(), self.description(), sig)?; + let def = extension.add_op(self.name(), self.description(), sig, extension_ref)?; self.post_opdef(def); @@ -394,7 +396,7 @@ impl MakeOpDef for ArrayScanDef { where Self: Sized, { - crate::extension::simple_op::try_from_name(op_def.name(), op_def.extension()) + crate::extension::simple_op::try_from_name(op_def.name(), op_def.extension_id()) } fn signature(&self) -> SignatureFunc { @@ -421,9 +423,10 @@ impl MakeOpDef for ArrayScanDef { fn add_to_extension( &self, extension: &mut Extension, + extension_ref: &Weak, ) -> Result<(), crate::extension::ExtensionBuildError> { let sig = self.signature_from_def(extension.get_type(ARRAY_TYPE_NAME).unwrap()); - let def = extension.add_op(self.name(), self.description(), sig)?; + let def = extension.add_op(self.name(), self.description(), sig, extension_ref)?; self.post_opdef(def); diff --git a/hugr-core/src/extension/simple_op.rs b/hugr-core/src/extension/simple_op.rs index c338a693d..6d1c678c5 100644 --- a/hugr-core/src/extension/simple_op.rs +++ b/hugr-core/src/extension/simple_op.rs @@ -1,5 +1,7 @@ //! A trait that enum for op definitions that gathers up some shared functionality. +use std::sync::Weak; + use strum::IntoEnumIterator; use crate::ops::{ExtensionOp, OpName, OpNameRef}; @@ -67,8 +69,20 @@ pub trait MakeOpDef: NamedOp { /// Add an operation implemented as an [MakeOpDef], which can provide the data /// required to define an [OpDef], to an extension. - fn add_to_extension(&self, extension: &mut Extension) -> Result<(), ExtensionBuildError> { - let def = extension.add_op(self.name(), self.description(), self.signature())?; + /// + /// Requires a [`Weak`] reference to the extension defining the operation. + /// This method is intended to be used inside the closure passed to [`Extension::new_arc`]. + fn add_to_extension( + &self, + extension: &mut Extension, + extension_ref: &Weak, + ) -> Result<(), ExtensionBuildError> { + let def = extension.add_op( + self.name(), + self.description(), + self.signature(), + extension_ref, + )?; self.post_opdef(def); @@ -77,12 +91,18 @@ pub trait MakeOpDef: NamedOp { /// Load all variants of an enum of op definitions in to an extension as op defs. /// See [strum::IntoEnumIterator]. - fn load_all_ops(extension: &mut Extension) -> Result<(), ExtensionBuildError> + /// + /// Requires a [`Weak`] reference to the extension defining the operation. + /// This method is intended to be used inside the closure passed to [`Extension::new_arc`]. + fn load_all_ops( + extension: &mut Extension, + extension_ref: &Weak, + ) -> Result<(), ExtensionBuildError> where Self: IntoEnumIterator, { for op in Self::iter() { - op.add_to_extension(extension)?; + op.add_to_extension(extension, extension_ref)?; } Ok(()) } @@ -316,9 +336,11 @@ mod test { lazy_static! { static ref EXT: Arc = { - let mut e = Extension::new_test(EXT_ID.clone()); - DummyEnum::Dumb.add_to_extension(&mut e).unwrap(); - Arc::new(e) + Extension::new_test_arc(EXT_ID.clone(), |ext, extension_ref| { + DummyEnum::Dumb + .add_to_extension(ext, extension_ref) + .unwrap(); + }) }; static ref DUMMY_REG: ExtensionRegistry = ExtensionRegistry::try_new([EXT.clone()]).unwrap(); diff --git a/hugr-core/src/extension/type_def.rs b/hugr-core/src/extension/type_def.rs index 1affe68f0..7f0daa3ca 100644 --- a/hugr-core/src/extension/type_def.rs +++ b/hugr-core/src/extension/type_def.rs @@ -1,4 +1,5 @@ use std::collections::btree_map::Entry; +use std::sync::Weak; use super::{CustomConcrete, ExtensionBuildError}; use super::{Extension, ExtensionId, SignatureError}; @@ -56,6 +57,9 @@ impl TypeDefBound { pub struct TypeDef { /// The unique Extension owning this TypeDef (of which this TypeDef is a member) extension: ExtensionId, + /// A weak reference to the extension defining this operation. + #[serde(skip)] + extension_ref: Weak, /// The unique name of the type name: TypeName, /// Declaration of type parameters. The TypeDef must be instantiated @@ -82,9 +86,9 @@ impl TypeDef { /// This function will return an error if the type of the instance does not /// match the definition. pub fn check_custom(&self, custom: &CustomType) -> Result<(), SignatureError> { - if self.extension() != custom.parent_extension() { + if self.extension_id() != custom.parent_extension() { return Err(SignatureError::ExtensionMismatch( - self.extension().clone(), + self.extension_id().clone(), custom.parent_extension().clone(), )); } @@ -121,7 +125,7 @@ impl TypeDef { Ok(CustomType::new( self.name().clone(), args, - self.extension().clone(), + self.extension_id().clone(), bound, )) } @@ -156,22 +160,55 @@ impl TypeDef { &self.name } - fn extension(&self) -> &ExtensionId { + /// Returns a reference to the extension id of this [`TypeDef`]. + pub fn extension_id(&self) -> &ExtensionId { &self.extension } + + /// Returns a weak reference to the extension defining this type. + pub fn extension(&self) -> Weak { + self.extension_ref.clone() + } } impl Extension { /// Add an exported type to the extension. + /// + /// This method requires a [`Weak`] reference to the [`std::sync::Arc`] containing the + /// extension being defined. The intended way to call this method is inside + /// the closure passed to [`Extension::new_arc`] when defining the extension. + /// + /// # Example + /// + /// ``` + /// # use hugr_core::types::Signature; + /// # use hugr_core::extension::{Extension, ExtensionId, Version}; + /// # use hugr_core::extension::{TypeDefBound}; + /// Extension::new_arc( + /// ExtensionId::new_unchecked("my.extension"), + /// Version::new(0, 1, 0), + /// |ext, extension_ref| { + /// ext.add_type( + /// "MyType".into(), + /// vec![], // No type parameters + /// "Some type".into(), + /// TypeDefBound::any(), + /// extension_ref, + /// ); + /// }, + /// ); + /// ``` pub fn add_type( &mut self, name: TypeName, params: Vec, description: String, bound: TypeDefBound, + extension_ref: &Weak, ) -> Result<&TypeDef, ExtensionBuildError> { let ty = TypeDef { extension: self.name.clone(), + extension_ref: extension_ref.clone(), name, params, description, @@ -202,6 +239,8 @@ mod test { b: TypeBound::Copyable, }], extension: "MyRsrc".try_into().unwrap(), + // Dummy extension. Will return `None` when trying to upgrade it into an `Arc`. + extension_ref: Default::default(), description: "Some parametrised type".into(), bound: TypeDefBound::FromParams { indices: vec![0] }, }; diff --git a/hugr-core/src/hugr/rewrite/replace.rs b/hugr-core/src/hugr/rewrite/replace.rs index 8967df9a5..e23fab7c5 100644 --- a/hugr-core/src/hugr/rewrite/replace.rs +++ b/hugr-core/src/hugr/rewrite/replace.rs @@ -463,7 +463,7 @@ mod test { use crate::std_extensions::collections::{self, list_type, ListOp}; use crate::types::{Signature, Type, TypeRow}; use crate::utils::depth; - use crate::{type_row, Direction, Hugr, HugrView, OutgoingPort}; + use crate::{type_row, Direction, Extension, Hugr, HugrView, OutgoingPort}; use super::{NewEdgeKind, NewEdgeSpec, ReplaceError, Replacement}; @@ -638,10 +638,26 @@ mod test { #[test] fn test_invalid() { - let mut new_ext = crate::Extension::new_test("new_ext".try_into().unwrap()); - let ext_name = new_ext.name().clone(); let utou = Signature::new_endo(vec![USIZE_T]); - let mut mk_op = |s| new_ext.simple_ext_op(s, utou.clone()); + let ext = Extension::new_test_arc("new_ext".try_into().unwrap(), |ext, extension_ref| { + ext.add_op("foo".into(), "".to_string(), utou.clone(), extension_ref) + .unwrap(); + ext.add_op("bar".into(), "".to_string(), utou.clone(), extension_ref) + .unwrap(); + ext.add_op("baz".into(), "".to_string(), utou.clone(), extension_ref) + .unwrap(); + }); + let ext_name = ext.name().clone(); + let foo = ext + .instantiate_extension_op("foo", [], &PRELUDE_REGISTRY) + .unwrap(); + let bar = ext + .instantiate_extension_op("bar", [], &PRELUDE_REGISTRY) + .unwrap(); + let baz = ext + .instantiate_extension_op("baz", [], &PRELUDE_REGISTRY) + .unwrap(); + let mut h = DFGBuilder::new( Signature::new(type_row![USIZE_T, BOOL_T], type_row![USIZE_T]) .with_extension_delta(ext_name.clone()), @@ -657,23 +673,17 @@ mod test { ) .unwrap(); let mut case1 = cond.case_builder(0).unwrap(); - let foo = case1 - .add_dataflow_op(mk_op("foo"), case1.input_wires()) - .unwrap(); + let foo = case1.add_dataflow_op(foo, case1.input_wires()).unwrap(); let case1 = case1.finish_with_outputs(foo.outputs()).unwrap().node(); let mut case2 = cond.case_builder(1).unwrap(); - let bar = case2 - .add_dataflow_op(mk_op("bar"), case2.input_wires()) - .unwrap(); + let bar = case2.add_dataflow_op(bar, case2.input_wires()).unwrap(); let mut baz_dfg = case2 .dfg_builder( utou.clone().with_extension_delta(ext_name.clone()), bar.outputs(), ) .unwrap(); - let baz = baz_dfg - .add_dataflow_op(mk_op("baz"), baz_dfg.input_wires()) - .unwrap(); + let baz = baz_dfg.add_dataflow_op(baz, baz_dfg.input_wires()).unwrap(); let baz_dfg = baz_dfg.finish_with_outputs(baz.outputs()).unwrap(); let case2 = case2.finish_with_outputs(baz_dfg.outputs()).unwrap().node(); let cond = cond.finish_sub_container().unwrap(); diff --git a/hugr-core/src/hugr/validate/test.rs b/hugr-core/src/hugr/validate/test.rs index cf934e18b..97ffc3d53 100644 --- a/hugr-core/src/hugr/validate/test.rs +++ b/hugr-core/src/hugr/validate/test.rs @@ -1,5 +1,6 @@ use std::fs::File; use std::io::BufReader; +use std::sync::Arc; use cool_asserts::assert_matches; @@ -378,15 +379,17 @@ const_extension_ids! { } #[test] fn invalid_types() { - let mut e = Extension::new_test(EXT_ID); - e.add_type( - "MyContainer".into(), - vec![TypeBound::Copyable.into()], - "".into(), - TypeDefBound::any(), - ) - .unwrap(); - let reg = ExtensionRegistry::try_new([e.into(), PRELUDE.clone()]).unwrap(); + let ext = Extension::new_test_arc(EXT_ID, |ext, extension_ref| { + ext.add_type( + "MyContainer".into(), + vec![TypeBound::Copyable.into()], + "".into(), + TypeDefBound::any(), + extension_ref, + ) + .unwrap(); + }); + let reg = ExtensionRegistry::try_new([ext, PRELUDE.clone()]).unwrap(); let validate_to_sig_error = |t: CustomType| { let (h, def) = identity_hugr_with_type(Type::new_extension(t)); @@ -587,33 +590,33 @@ fn no_polymorphic_consts() -> Result<(), Box> { Ok(()) } -pub(crate) fn extension_with_eval_parallel() -> Extension { +pub(crate) fn extension_with_eval_parallel() -> Arc { let rowp = TypeParam::new_list(TypeBound::Any); - let mut e = Extension::new_test(EXT_ID); - - let inputs = TypeRV::new_row_var_use(0, TypeBound::Any); - let outputs = TypeRV::new_row_var_use(1, TypeBound::Any); - let evaled_fn = TypeRV::new_function(FuncValueType::new(inputs.clone(), outputs.clone())); - let pf = PolyFuncTypeRV::new( - [rowp.clone(), rowp.clone()], - FuncValueType::new(vec![evaled_fn, inputs], outputs), - ); - e.add_op("eval".into(), "".into(), pf).unwrap(); - - let rv = |idx| TypeRV::new_row_var_use(idx, TypeBound::Any); - let pf = PolyFuncTypeRV::new( - [rowp.clone(), rowp.clone(), rowp.clone(), rowp.clone()], - Signature::new( - vec![ - Type::new_function(FuncValueType::new(rv(0), rv(2))), - Type::new_function(FuncValueType::new(rv(1), rv(3))), - ], - Type::new_function(FuncValueType::new(vec![rv(0), rv(1)], vec![rv(2), rv(3)])), - ), - ); - e.add_op("parallel".into(), "".into(), pf).unwrap(); + Extension::new_test_arc(EXT_ID, |ext, extension_ref| { + let inputs = TypeRV::new_row_var_use(0, TypeBound::Any); + let outputs = TypeRV::new_row_var_use(1, TypeBound::Any); + let evaled_fn = TypeRV::new_function(FuncValueType::new(inputs.clone(), outputs.clone())); + let pf = PolyFuncTypeRV::new( + [rowp.clone(), rowp.clone()], + FuncValueType::new(vec![evaled_fn, inputs], outputs), + ); + ext.add_op("eval".into(), "".into(), pf, extension_ref) + .unwrap(); - e + let rv = |idx| TypeRV::new_row_var_use(idx, TypeBound::Any); + let pf = PolyFuncTypeRV::new( + [rowp.clone(), rowp.clone(), rowp.clone(), rowp.clone()], + Signature::new( + vec![ + Type::new_function(FuncValueType::new(rv(0), rv(2))), + Type::new_function(FuncValueType::new(rv(1), rv(3))), + ], + Type::new_function(FuncValueType::new(vec![rv(0), rv(1)], vec![rv(2), rv(3)])), + ), + ); + ext.add_op("parallel".into(), "".into(), pf, extension_ref) + .unwrap(); + }) } #[test] @@ -643,7 +646,7 @@ fn instantiate_row_variables() -> Result<(), Box> { let eval2 = dfb.add_dataflow_op(eval2, [par_func, a, b])?; dfb.finish_hugr_with_outputs( eval2.outputs(), - &ExtensionRegistry::try_new([PRELUDE.clone(), e.into()]).unwrap(), + &ExtensionRegistry::try_new([PRELUDE.clone(), e]).unwrap(), )?; Ok(()) } @@ -683,41 +686,44 @@ fn row_variables() -> Result<(), Box> { let par_func = fb.add_dataflow_op(par, [func_arg, id_usz])?; fb.finish_hugr_with_outputs( par_func.outputs(), - &ExtensionRegistry::try_new([PRELUDE.clone(), e.into()]).unwrap(), + &ExtensionRegistry::try_new([PRELUDE.clone(), e]).unwrap(), )?; Ok(()) } #[test] fn test_polymorphic_call() -> Result<(), Box> { - let mut e = Extension::new_test(EXT_ID); - - let params: Vec = vec![ - TypeBound::Any.into(), - TypeParam::Extensions, - TypeBound::Any.into(), - ]; - let evaled_fn = Type::new_function( - Signature::new( - Type::new_var_use(0, TypeBound::Any), - Type::new_var_use(2, TypeBound::Any), - ) - .with_extension_delta(ExtensionSet::type_var(1)), - ); - // Single-input/output version of the higher-order "eval" operation, with extension param. - // Note the extension-delta of the eval node includes that of the input function. - e.add_op( - "eval".into(), - "".into(), - PolyFuncTypeRV::new( - params.clone(), + let e = Extension::try_new_test_arc(EXT_ID, |ext, extension_ref| { + let params: Vec = vec![ + TypeBound::Any.into(), + TypeParam::Extensions, + TypeBound::Any.into(), + ]; + let evaled_fn = Type::new_function( Signature::new( - vec![evaled_fn, Type::new_var_use(0, TypeBound::Any)], + Type::new_var_use(0, TypeBound::Any), Type::new_var_use(2, TypeBound::Any), ) .with_extension_delta(ExtensionSet::type_var(1)), - ), - )?; + ); + // Single-input/output version of the higher-order "eval" operation, with extension param. + // Note the extension-delta of the eval node includes that of the input function. + ext.add_op( + "eval".into(), + "".into(), + PolyFuncTypeRV::new( + params.clone(), + Signature::new( + vec![evaled_fn, Type::new_var_use(0, TypeBound::Any)], + Type::new_var_use(2, TypeBound::Any), + ) + .with_extension_delta(ExtensionSet::type_var(1)), + ), + extension_ref, + )?; + + Ok(()) + })?; fn utou(e: impl Into) -> Type { Type::new_function(Signature::new_endo(USIZE_T).with_extension_delta(e.into())) @@ -763,7 +769,7 @@ fn test_polymorphic_call() -> Result<(), Box> { f.finish_with_outputs([tup])? }; - let reg = ExtensionRegistry::try_new([e.into(), PRELUDE.clone()])?; + let reg = ExtensionRegistry::try_new([e, PRELUDE.clone()])?; let [func, tup] = d.input_wires_arr(); let call = d.call( f.handle(), diff --git a/hugr-core/src/ops/custom.rs b/hugr-core/src/ops/custom.rs index eec5f4d34..d7f1c2c57 100644 --- a/hugr-core/src/ops/custom.rs +++ b/hugr-core/src/ops/custom.rs @@ -104,7 +104,7 @@ impl ExtensionOp { /// For a non-cloning version of this operation, use [`OpaqueOp::from`]. pub fn make_opaque(&self) -> OpaqueOp { OpaqueOp { - extension: self.def.extension().clone(), + extension: self.def.extension_id().clone(), name: self.def.name().clone(), description: self.def.description().into(), args: self.args.clone(), @@ -121,7 +121,7 @@ impl From for OpaqueOp { signature, } = op; OpaqueOp { - extension: def.extension().clone(), + extension: def.extension_id().clone(), name: def.name().clone(), description: def.description().into(), args, @@ -141,7 +141,7 @@ impl Eq for ExtensionOp {} impl NamedOp for ExtensionOp { /// The name of the operation. fn name(&self) -> OpName { - qualify_name(self.def.extension(), self.def.name()) + qualify_name(self.def.extension_id(), self.def.name()) } } @@ -402,26 +402,30 @@ mod test { #[test] fn resolve_missing() { - let mut ext = Extension::new_test("ext".try_into().unwrap()); - let ext_id = ext.name().clone(); let val_name = "missing_val"; let comp_name = "missing_comp"; - let endo_sig = Signature::new_endo(BOOL_T); - ext.add_op( - val_name.into(), - "".to_string(), - SignatureFunc::MissingValidateFunc(FuncValueType::from(endo_sig.clone()).into()), - ) - .unwrap(); - ext.add_op( - comp_name.into(), - "".to_string(), - SignatureFunc::MissingComputeFunc, - ) - .unwrap(); - let registry = ExtensionRegistry::try_new([ext.into()]).unwrap(); + let ext = Extension::new_test_arc("ext".try_into().unwrap(), |ext, extension_ref| { + ext.add_op( + val_name.into(), + "".to_string(), + SignatureFunc::MissingValidateFunc(FuncValueType::from(endo_sig.clone()).into()), + extension_ref, + ) + .unwrap(); + + ext.add_op( + comp_name.into(), + "".to_string(), + SignatureFunc::MissingComputeFunc, + extension_ref, + ) + .unwrap(); + }); + let ext_id = ext.name().clone(); + + let registry = ExtensionRegistry::try_new([ext]).unwrap(); let opaque_val = OpaqueOp::new( ext_id.clone(), val_name, diff --git a/hugr-core/src/package.rs b/hugr-core/src/package.rs index cec1b2c85..aa32f24d5 100644 --- a/hugr-core/src/package.rs +++ b/hugr-core/src/package.rs @@ -99,7 +99,7 @@ impl Package { reg: &mut ExtensionRegistry, ) -> Result<(), PackageValidationError> { for ext in &self.extensions { - reg.register_updated_ref(ext)?; + reg.register_updated_ref(ext); } for hugr in self.modules.iter_mut() { hugr.update_validate(reg)?; diff --git a/hugr-core/src/std_extensions/arithmetic/conversions.rs b/hugr-core/src/std_extensions/arithmetic/conversions.rs index deb93f8c2..4d3263f09 100644 --- a/hugr-core/src/std_extensions/arithmetic/conversions.rs +++ b/hugr-core/src/std_extensions/arithmetic/conversions.rs @@ -50,7 +50,7 @@ pub enum ConvertOpDef { impl MakeOpDef for ConvertOpDef { fn from_def(op_def: &OpDef) -> Result { - crate::extension::simple_op::try_from_name(op_def.name(), op_def.extension()) + crate::extension::simple_op::try_from_name(op_def.name(), op_def.extension_id()) } fn extension(&self) -> ExtensionId { @@ -158,18 +158,15 @@ impl MakeExtensionOp for ConvertOpType { lazy_static! { /// Extension for conversions between integers and floats. pub static ref EXTENSION: Arc = { - let mut extension = Extension::new( - EXTENSION_ID, - VERSION).with_reqs( - ExtensionSet::from_iter(vec![ - super::int_types::EXTENSION_ID, - super::float_types::EXTENSION_ID, - ]), - ); - - ConvertOpDef::load_all_ops(&mut extension).unwrap(); - - Arc::new(extension) + Extension::new_arc(EXTENSION_ID, VERSION, |extension, extension_ref| { + extension.add_requirements( + ExtensionSet::from_iter(vec![ + super::int_types::EXTENSION_ID, + super::float_types::EXTENSION_ID, + ])); + + ConvertOpDef::load_all_ops(extension, extension_ref).unwrap(); + }) }; /// Registry of extensions required to validate integer operations. diff --git a/hugr-core/src/std_extensions/arithmetic/float_ops.rs b/hugr-core/src/std_extensions/arithmetic/float_ops.rs index 7d353e71a..9be2c6786 100644 --- a/hugr-core/src/std_extensions/arithmetic/float_ops.rs +++ b/hugr-core/src/std_extensions/arithmetic/float_ops.rs @@ -50,7 +50,7 @@ pub enum FloatOps { impl MakeOpDef for FloatOps { fn from_def(op_def: &OpDef) -> Result { - crate::extension::simple_op::try_from_name(op_def.name(), op_def.extension()) + crate::extension::simple_op::try_from_name(op_def.name(), op_def.extension_id()) } fn extension(&self) -> ExtensionId { @@ -107,15 +107,10 @@ impl MakeOpDef for FloatOps { lazy_static! { /// Extension for basic float operations. pub static ref EXTENSION: Arc = { - let mut extension = Extension::new( - EXTENSION_ID, - VERSION).with_reqs( - ExtensionSet::singleton(&super::int_types::EXTENSION_ID), - ); - - FloatOps::load_all_ops(&mut extension).unwrap(); - - Arc::new(extension) + Extension::new_arc(EXTENSION_ID, VERSION, |extension, extension_ref| { + extension.add_requirements(ExtensionSet::singleton(&super::int_types::EXTENSION_ID)); + FloatOps::load_all_ops(extension, extension_ref).unwrap(); + }) }; /// Registry of extensions required to validate float operations. diff --git a/hugr-core/src/std_extensions/arithmetic/float_types.rs b/hugr-core/src/std_extensions/arithmetic/float_types.rs index ec145008f..0af5f8728 100644 --- a/hugr-core/src/std_extensions/arithmetic/float_types.rs +++ b/hugr-core/src/std_extensions/arithmetic/float_types.rs @@ -82,18 +82,17 @@ impl CustomConst for ConstF64 { lazy_static! { /// Extension defining the float type. pub static ref EXTENSION: Arc = { - let mut extension = Extension::new(EXTENSION_ID, VERSION); - - extension - .add_type( - FLOAT_TYPE_ID, - vec![], - "64-bit IEEE 754-2019 floating-point value".to_owned(), - TypeBound::Copyable.into(), - ) - .unwrap(); - - Arc::new(extension) + Extension::new_arc(EXTENSION_ID, VERSION, |extension, extension_ref| { + extension + .add_type( + FLOAT_TYPE_ID, + vec![], + "64-bit IEEE 754-2019 floating-point value".to_owned(), + TypeBound::Copyable.into(), + extension_ref, + ) + .unwrap(); + }) }; } #[cfg(test)] diff --git a/hugr-core/src/std_extensions/arithmetic/int_ops.rs b/hugr-core/src/std_extensions/arithmetic/int_ops.rs index 97bb247a2..132a01f35 100644 --- a/hugr-core/src/std_extensions/arithmetic/int_ops.rs +++ b/hugr-core/src/std_extensions/arithmetic/int_ops.rs @@ -104,7 +104,7 @@ pub enum IntOpDef { impl MakeOpDef for IntOpDef { fn from_def(op_def: &OpDef) -> Result { - crate::extension::simple_op::try_from_name(op_def.name(), op_def.extension()) + crate::extension::simple_op::try_from_name(op_def.name(), op_def.extension_id()) } fn extension(&self) -> ExtensionId { @@ -250,15 +250,10 @@ fn iunop_sig() -> PolyFuncTypeRV { lazy_static! { /// Extension for basic integer operations. pub static ref EXTENSION: Arc = { - let mut extension = Extension::new( - EXTENSION_ID, - VERSION).with_reqs( - ExtensionSet::singleton(&super::int_types::EXTENSION_ID) - ); - - IntOpDef::load_all_ops(&mut extension).unwrap(); - - Arc::new(extension) + Extension::new_arc(EXTENSION_ID, VERSION, |extension, extension_ref| { + extension.add_requirements(ExtensionSet::singleton(&super::int_types::EXTENSION_ID)); + IntOpDef::load_all_ops(extension, extension_ref).unwrap(); + }) }; /// Registry of extensions required to validate integer operations. diff --git a/hugr-core/src/std_extensions/arithmetic/int_types.rs b/hugr-core/src/std_extensions/arithmetic/int_types.rs index 3d257b9d0..82f1c27ae 100644 --- a/hugr-core/src/std_extensions/arithmetic/int_types.rs +++ b/hugr-core/src/std_extensions/arithmetic/int_types.rs @@ -188,18 +188,17 @@ impl CustomConst for ConstInt { /// Extension for basic integer types. pub fn extension() -> Arc { - let mut extension = Extension::new(EXTENSION_ID, VERSION); - - extension - .add_type( - INT_TYPE_ID, - vec![LOG_WIDTH_TYPE_PARAM], - "integral value of a given bit width".to_owned(), - TypeBound::Copyable.into(), - ) - .unwrap(); - - Arc::new(extension) + Extension::new_arc(EXTENSION_ID, VERSION, |extension, extension_ref| { + extension + .add_type( + INT_TYPE_ID, + vec![LOG_WIDTH_TYPE_PARAM], + "integral value of a given bit width".to_owned(), + TypeBound::Copyable.into(), + extension_ref, + ) + .unwrap(); + }) } lazy_static! { diff --git a/hugr-core/src/std_extensions/collections.rs b/hugr-core/src/std_extensions/collections.rs index 17a1b0d03..2f416c92b 100644 --- a/hugr-core/src/std_extensions/collections.rs +++ b/hugr-core/src/std_extensions/collections.rs @@ -5,7 +5,7 @@ use std::hash::{Hash, Hasher}; mod list_fold; use std::str::FromStr; -use std::sync::Arc; +use std::sync::{Arc, Weak}; use itertools::Itertools; use lazy_static::lazy_static; @@ -204,7 +204,7 @@ impl ListOp { impl MakeOpDef for ListOp { fn from_def(op_def: &OpDef) -> Result { - crate::extension::simple_op::try_from_name(op_def.name(), op_def.extension()) + crate::extension::simple_op::try_from_name(op_def.name(), op_def.extension_id()) } fn extension(&self) -> ExtensionId { @@ -216,9 +216,13 @@ impl MakeOpDef for ListOp { // // This method is re-defined here since we need to pass the list type def while computing the signature, // to avoid recursive loops initializing the extension. - fn add_to_extension(&self, extension: &mut Extension) -> Result<(), ExtensionBuildError> { + fn add_to_extension( + &self, + extension: &mut Extension, + extension_ref: &Weak, + ) -> Result<(), ExtensionBuildError> { let sig = self.compute_signature(extension.get_type(&LIST_TYPENAME).unwrap()); - let def = extension.add_op(self.name(), self.description(), sig)?; + let def = extension.add_op(self.name(), self.description(), sig, extension_ref)?; self.post_opdef(def); @@ -251,20 +255,19 @@ impl MakeOpDef for ListOp { lazy_static! { /// Extension for list operations. pub static ref EXTENSION: Arc = { - let mut extension = Extension::new(EXTENSION_ID, VERSION); - - // The list type must be defined before the operations are added. - extension.add_type( - LIST_TYPENAME, - vec![ListOp::TP], - "Generic dynamically sized list of type T.".into(), - TypeDefBound::from_params(vec![0]), - ) - .unwrap(); - - ListOp::load_all_ops(&mut extension).unwrap(); + Extension::new_arc(EXTENSION_ID, VERSION, |extension, extension_ref| { + extension.add_type( + LIST_TYPENAME, + vec![ListOp::TP], + "Generic dynamically sized list of type T.".into(), + TypeDefBound::from_params(vec![0]), + extension_ref + ) + .unwrap(); - Arc::new(extension) + // The list type must be defined before the operations are added. + ListOp::load_all_ops(extension, extension_ref).unwrap(); + }) }; /// Registry of extensions required to validate list operations. @@ -392,7 +395,7 @@ mod test { assert_eq!(&ListOp::push.extension(), EXTENSION.name()); assert!(ListOp::pop.registry().contains(EXTENSION.name())); for (_, op_def) in EXTENSION.operations() { - assert_eq!(op_def.extension(), &EXTENSION_ID); + assert_eq!(op_def.extension_id(), &EXTENSION_ID); } } diff --git a/hugr-core/src/std_extensions/logic.rs b/hugr-core/src/std_extensions/logic.rs index 89f9dfa8b..4799e3f33 100644 --- a/hugr-core/src/std_extensions/logic.rs +++ b/hugr-core/src/std_extensions/logic.rs @@ -91,7 +91,7 @@ impl MakeOpDef for LogicOp { } fn from_def(op_def: &OpDef) -> Result { - try_from_name(op_def.name(), op_def.extension()) + try_from_name(op_def.name(), op_def.extension_id()) } fn extension(&self) -> ExtensionId { @@ -110,16 +110,16 @@ pub const VERSION: semver::Version = semver::Version::new(0, 1, 0); /// Extension for basic logical operations. fn extension() -> Arc { - let mut extension = Extension::new(EXTENSION_ID, VERSION); - LogicOp::load_all_ops(&mut extension).unwrap(); + Extension::new_arc(EXTENSION_ID, VERSION, |extension, extension_ref| { + LogicOp::load_all_ops(extension, extension_ref).unwrap(); - extension - .add_value(FALSE_NAME, ops::Value::false_val()) - .unwrap(); - extension - .add_value(TRUE_NAME, ops::Value::true_val()) - .unwrap(); - Arc::new(extension) + extension + .add_value(FALSE_NAME, ops::Value::false_val()) + .unwrap(); + extension + .add_value(TRUE_NAME, ops::Value::true_val()) + .unwrap(); + }) } lazy_static! { diff --git a/hugr-core/src/std_extensions/ptr.rs b/hugr-core/src/std_extensions/ptr.rs index e3023e4b5..1822967b7 100644 --- a/hugr-core/src/std_extensions/ptr.rs +++ b/hugr-core/src/std_extensions/ptr.rs @@ -47,7 +47,7 @@ impl MakeOpDef for PtrOpDef { where Self: Sized, { - crate::extension::simple_op::try_from_name(op_def.name(), op_def.extension()) + crate::extension::simple_op::try_from_name(op_def.name(), op_def.extension_id()) } fn signature(&self) -> SignatureFunc { @@ -87,17 +87,18 @@ pub const VERSION: semver::Version = semver::Version::new(0, 1, 0); /// Extension for pointer operations. fn extension() -> Arc { - let mut extension = Extension::new(EXTENSION_ID, VERSION); - extension - .add_type( - PTR_TYPE_ID, - TYPE_PARAMS.into(), - "Standard extension pointer type.".into(), - TypeDefBound::copyable(), - ) - .unwrap(); - PtrOpDef::load_all_ops(&mut extension).unwrap(); - Arc::new(extension) + Extension::new_arc(EXTENSION_ID, VERSION, |extension, extension_ref| { + extension + .add_type( + PTR_TYPE_ID, + TYPE_PARAMS.into(), + "Standard extension pointer type.".into(), + TypeDefBound::copyable(), + extension_ref, + ) + .unwrap(); + PtrOpDef::load_all_ops(extension, extension_ref).unwrap(); + }) } lazy_static! { diff --git a/hugr-core/src/types/poly_func.rs b/hugr-core/src/types/poly_func.rs index 7e3f4f664..77c4ab990 100644 --- a/hugr-core/src/types/poly_func.rs +++ b/hugr-core/src/types/poly_func.rs @@ -321,16 +321,18 @@ pub(crate) mod test { const EXT_ID: ExtensionId = ExtensionId::new_unchecked("my_ext"); const TYPE_NAME: TypeName = TypeName::new_inline("MyType"); - let mut e = Extension::new_test(EXT_ID); - e.add_type( - TYPE_NAME, - vec![bound.clone()], - "".into(), - TypeDefBound::any(), - ) - .unwrap(); + let ext = Extension::new_test_arc(EXT_ID, |ext, extension_ref| { + ext.add_type( + TYPE_NAME, + vec![bound.clone()], + "".into(), + TypeDefBound::any(), + extension_ref, + ) + .unwrap(); + }); - let reg = ExtensionRegistry::try_new([e.into()]).unwrap(); + let reg = ExtensionRegistry::try_new([ext]).unwrap(); let make_scheme = |tp: TypeParam| { PolyFuncTypeBase::new_validated( diff --git a/hugr-core/src/utils.rs b/hugr-core/src/utils.rs index f1f97cc5a..702ad8c19 100644 --- a/hugr-core/src/utils.rs +++ b/hugr-core/src/utils.rs @@ -131,48 +131,60 @@ pub(crate) mod test_quantum_extension { /// The extension identifier. pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("test.quantum"); fn extension() -> Arc { - let mut extension = Extension::new_test(EXTENSION_ID); - - extension - .add_op(OpName::new_inline("H"), "Hadamard".into(), one_qb_func()) - .unwrap(); - extension - .add_op( - OpName::new_inline("RzF64"), - "Rotation specified by float".into(), - Signature::new(type_row![QB_T, float_types::FLOAT64_TYPE], type_row![QB_T]), - ) - .unwrap(); - - extension - .add_op(OpName::new_inline("CX"), "CX".into(), two_qb_func()) - .unwrap(); - - extension - .add_op( - OpName::new_inline("Measure"), - "Measure a qubit, returning the qubit and the measurement result.".into(), - Signature::new(type_row![QB_T], type_row![QB_T, BOOL_T]), - ) - .unwrap(); - - extension - .add_op( - OpName::new_inline("QAlloc"), - "Allocate a new qubit.".into(), - Signature::new(type_row![], type_row![QB_T]), - ) - .unwrap(); - - extension - .add_op( - OpName::new_inline("QDiscard"), - "Discard a qubit.".into(), - Signature::new(type_row![QB_T], type_row![]), - ) - .unwrap(); - - Arc::new(extension) + Extension::new_test_arc(EXTENSION_ID, |extension, extension_ref| { + extension + .add_op( + OpName::new_inline("H"), + "Hadamard".into(), + one_qb_func(), + extension_ref, + ) + .unwrap(); + extension + .add_op( + OpName::new_inline("RzF64"), + "Rotation specified by float".into(), + Signature::new(type_row![QB_T, float_types::FLOAT64_TYPE], type_row![QB_T]), + extension_ref, + ) + .unwrap(); + + extension + .add_op( + OpName::new_inline("CX"), + "CX".into(), + two_qb_func(), + extension_ref, + ) + .unwrap(); + + extension + .add_op( + OpName::new_inline("Measure"), + "Measure a qubit, returning the qubit and the measurement result.".into(), + Signature::new(type_row![QB_T], type_row![QB_T, BOOL_T]), + extension_ref, + ) + .unwrap(); + + extension + .add_op( + OpName::new_inline("QAlloc"), + "Allocate a new qubit.".into(), + Signature::new(type_row![], type_row![QB_T]), + extension_ref, + ) + .unwrap(); + + extension + .add_op( + OpName::new_inline("QDiscard"), + "Discard a qubit.".into(), + Signature::new(type_row![QB_T], type_row![]), + extension_ref, + ) + .unwrap(); + }) } lazy_static! { diff --git a/hugr-llvm/src/custom/extension_op.rs b/hugr-llvm/src/custom/extension_op.rs index 08b392036..cd3c3b6e7 100644 --- a/hugr-llvm/src/custom/extension_op.rs +++ b/hugr-llvm/src/custom/extension_op.rs @@ -100,7 +100,7 @@ impl<'a, H: HugrView> ExtensionOpMap<'a, H> { args: EmitOpArgs<'c, '_, ExtensionOp, H>, ) -> Result<()> { let node = args.node(); - let key = (node.def().extension().clone(), node.def().name().clone()); + let key = (node.def().extension_id().clone(), node.def().name().clone()); let Some(handler) = self.0.get(&key) else { bail!("No extension could emit extension op: {key:?}") }; diff --git a/hugr-passes/src/merge_bbs.rs b/hugr-passes/src/merge_bbs.rs index a34ecc351..3ac3dacb7 100644 --- a/hugr-passes/src/merge_bbs.rs +++ b/hugr-passes/src/merge_bbs.rs @@ -157,6 +157,7 @@ fn mk_rep( #[cfg(test)] mod test { use std::collections::HashSet; + use std::sync::Arc; use hugr_core::extension::prelude::Lift; use itertools::Itertools; @@ -178,21 +179,26 @@ mod test { const EXT_ID: ExtensionId = "TestExt"; } - fn extension() -> Extension { - let mut e = Extension::new(EXT_ID, hugr_core::extension::Version::new(0, 1, 0)); - e.add_op( - "Test".into(), - String::new(), - Signature::new( - type_row![QB_T, USIZE_T], - TypeRow::from(vec![Type::new_sum(vec![ - type_row![QB_T], - type_row![USIZE_T], - ])]), - ), + fn extension() -> Arc { + Extension::new_arc( + EXT_ID, + hugr_core::extension::Version::new(0, 1, 0), + |ext, extension_ref| { + ext.add_op( + "Test".into(), + String::new(), + Signature::new( + type_row![QB_T, USIZE_T], + TypeRow::from(vec![Type::new_sum(vec![ + type_row![QB_T], + type_row![USIZE_T], + ])]), + ), + extension_ref, + ) + .unwrap(); + }, ) - .unwrap(); - e } fn lifted_unary_unit_sum + AsRef, T>(b: &mut DFGWrapper) -> Wire { @@ -228,7 +234,7 @@ mod test { let exit_types = type_row![USIZE_T]; let e = extension(); let tst_op = e.instantiate_extension_op("Test", [], &PRELUDE_REGISTRY)?; - let reg = ExtensionRegistry::try_new([PRELUDE.clone(), e.into()])?; + let reg = ExtensionRegistry::try_new([PRELUDE.clone(), e])?; let mut h = CFGBuilder::new(inout_sig(loop_variants.clone(), exit_types.clone()))?; let mut no_b1 = h.simple_entry_builder_exts(loop_variants.clone(), 1, PRELUDE_ID)?; let n = no_b1.add_dataflow_op(Noop::new(QB_T), no_b1.input_wires())?; @@ -299,7 +305,7 @@ mod test { // And the Noop in the entry block is consumed by the custom Test op let tst = find_unique( h.nodes(), - |n| matches!(h.get_optype(*n), OpType::ExtensionOp(c) if c.def().extension() != &PRELUDE_ID), + |n| matches!(h.get_optype(*n), OpType::ExtensionOp(c) if c.def().extension_id() != &PRELUDE_ID), ); assert_eq!(h.get_parent(tst), Some(entry)); assert_eq!( @@ -355,7 +361,7 @@ mod test { h.branch(&bb2, 0, &bb3)?; h.branch(&bb3, 0, &h.exit_block())?; - let reg = ExtensionRegistry::try_new([e.into(), PRELUDE.clone()])?; + let reg = ExtensionRegistry::try_new([e, PRELUDE.clone()])?; let mut h = h.finish_hugr(®)?; let root = h.root(); merge_basic_blocks(&mut SiblingMut::try_new(&mut h, root)?); @@ -365,7 +371,7 @@ mod test { let [bb, _exit] = h.children(h.root()).collect::>().try_into().unwrap(); let tst = find_unique( h.nodes(), - |n| matches!(h.get_optype(*n), OpType::ExtensionOp(c) if c.def().extension() != &PRELUDE_ID), + |n| matches!(h.get_optype(*n), OpType::ExtensionOp(c) if c.def().extension_id() != &PRELUDE_ID), ); assert_eq!(h.get_parent(tst), Some(bb)); diff --git a/hugr/benches/benchmarks/hugr/examples.rs b/hugr/benches/benchmarks/hugr/examples.rs index 3abcee535..48a336d2d 100644 --- a/hugr/benches/benchmarks/hugr/examples.rs +++ b/hugr/benches/benchmarks/hugr/examples.rs @@ -1,5 +1,7 @@ //! Builders and utilities for benchmarks. +use std::sync::Arc; + use hugr::builder::{ BuildError, CFGBuilder, Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, HugrBuilder, ModuleBuilder, @@ -53,35 +55,35 @@ pub fn simple_cfg_hugr() -> Hugr { } lazy_static! { - static ref QUANTUM_EXT: Extension = { - let mut extension = Extension::new( + static ref QUANTUM_EXT: Arc = { + Extension::new_arc( "bench.quantum".try_into().unwrap(), hugr::extension::Version::new(0, 0, 0), - ); - - extension - .add_op( - OpName::new_inline("H"), - "".into(), - Signature::new_endo(QB_T), - ) - .unwrap(); - extension - .add_op( - OpName::new_inline("Rz"), - "".into(), - Signature::new(type_row![QB_T, FLOAT64_TYPE], type_row![QB_T]), - ) - .unwrap(); - - extension - .add_op( - OpName::new_inline("CX"), - "".into(), - Signature::new_endo(type_row![QB_T, QB_T]), - ) - .unwrap(); - extension + |ext, extension_ref| { + ext.add_op( + OpName::new_inline("H"), + "".into(), + Signature::new_endo(QB_T), + extension_ref, + ) + .unwrap(); + ext.add_op( + OpName::new_inline("Rz"), + "".into(), + Signature::new(type_row![QB_T, FLOAT64_TYPE], type_row![QB_T]), + extension_ref, + ) + .unwrap(); + + ext.add_op( + OpName::new_inline("CX"), + "".into(), + Signature::new_endo(type_row![QB_T, QB_T]), + extension_ref, + ) + .unwrap(); + }, + ) }; } diff --git a/hugr/src/lib.rs b/hugr/src/lib.rs index 94dd141a3..d23063dfa 100644 --- a/hugr/src/lib.rs +++ b/hugr/src/lib.rs @@ -61,25 +61,21 @@ //! pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("mini.quantum"); //! pub const VERSION: Version = Version::new(0, 1, 0); //! fn extension() -> Arc { -//! let mut extension = Extension::new(EXTENSION_ID, VERSION); +//! Extension::new_arc(EXTENSION_ID, VERSION, |ext, extension_ref| { +//! ext.add_op(OpName::new_inline("H"), "Hadamard".into(), one_qb_func(), extension_ref) +//! .unwrap(); //! -//! extension -//! .add_op(OpName::new_inline("H"), "Hadamard".into(), one_qb_func()) -//! .unwrap(); -//! -//! extension -//! .add_op(OpName::new_inline("CX"), "CX".into(), two_qb_func()) -//! .unwrap(); +//! ext.add_op(OpName::new_inline("CX"), "CX".into(), two_qb_func(), extension_ref) +//! .unwrap(); //! -//! extension -//! .add_op( +//! ext.add_op( //! OpName::new_inline("Measure"), //! "Measure a qubit, returning the qubit and the measurement result.".into(), //! FuncValueType::new(type_row![QB_T], type_row![QB_T, BOOL_T]), +//! extension_ref, //! ) //! .unwrap(); -//! -//! Arc::new(extension) +//! }) //! } //! //! lazy_static! { From fc609a202353c1c01eec4f316f919b2d39ffcd85 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 27 Nov 2024 14:13:45 +0000 Subject: [PATCH 3/3] feat: add HugrView::first_child and HugrMut::remove_subtree (#1721) * Clarify doc of HugrMut::remove_node * Add HugrView::first_child, we have this anyway and it's useful because Rust's non-lexical lifetimes don't go far enough * Add HugrMut::remove_subtree, test closes #1663 --- hugr-core/src/hugr/hugrmut.rs | 46 ++++++++++++++++++++++++++++++++++- hugr-core/src/hugr/rewrite.rs | 5 +--- hugr-core/src/hugr/views.rs | 6 +++++ 3 files changed, 52 insertions(+), 5 deletions(-) diff --git a/hugr-core/src/hugr/hugrmut.rs b/hugr-core/src/hugr/hugrmut.rs index 3c538c357..3d9edc050 100644 --- a/hugr-core/src/hugr/hugrmut.rs +++ b/hugr-core/src/hugr/hugrmut.rs @@ -119,6 +119,8 @@ pub trait HugrMut: HugrMutInternals { } /// Remove a node from the graph and return the node weight. + /// Note that if the node has children, they are not removed; this leaves + /// the Hugr in an invalid state. See [Self::remove_subtree]. /// /// # Panics /// @@ -129,6 +131,19 @@ pub trait HugrMut: HugrMutInternals { self.hugr_mut().remove_node(node) } + /// Remove a node from the graph, along with all its descendants in the hierarchy. + /// + /// # Panics + /// + /// If the node is not in the graph, or is the root (this would leave an empty Hugr). + fn remove_subtree(&mut self, node: Node) { + panic_invalid_non_root(self, node); + while let Some(ch) = self.first_child(node) { + self.remove_subtree(ch) + } + self.hugr_mut().remove_node(node); + } + /// Connect two nodes at the given ports. /// /// # Panics @@ -524,7 +539,7 @@ mod test { PRELUDE_REGISTRY, }, macros::type_row, - ops::{self, dataflow::IOTrait}, + ops::{self, dataflow::IOTrait, FuncDefn, Input, Output}, types::{Signature, Type}, }; @@ -583,4 +598,33 @@ mod test { hugr.remove_metadata(root, "meta"); assert_eq!(hugr.get_metadata(root, "meta"), None); } + + #[test] + fn remove_subtree() { + let mut hugr = Hugr::default(); + let root = hugr.root(); + let [foo, bar] = ["foo", "bar"].map(|name| { + let fd = hugr.add_node_with_parent( + root, + FuncDefn { + name: name.to_string(), + signature: Signature::new_endo(NAT).into(), + }, + ); + let inp = hugr.add_node_with_parent(fd, Input::new(NAT)); + let out = hugr.add_node_with_parent(fd, Output::new(NAT)); + hugr.connect(inp, 0, out, 0); + fd + }); + hugr.validate(&PRELUDE_REGISTRY).unwrap(); + assert_eq!(hugr.node_count(), 7); + + hugr.remove_subtree(foo); + hugr.validate(&PRELUDE_REGISTRY).unwrap(); + assert_eq!(hugr.node_count(), 4); + + hugr.remove_subtree(bar); + hugr.validate(&PRELUDE_REGISTRY).unwrap(); + assert_eq!(hugr.node_count(), 1); + } } diff --git a/hugr-core/src/hugr/rewrite.rs b/hugr-core/src/hugr/rewrite.rs index dd26b1ac2..3354fc820 100644 --- a/hugr-core/src/hugr/rewrite.rs +++ b/hugr-core/src/hugr/rewrite.rs @@ -70,14 +70,11 @@ impl Rewrite for Transactional { let mut backup = Hugr::new(h.root_type().clone()); backup.insert_from_view(backup.root(), h); let r = self.underlying.apply(h); - fn first_child(h: &impl HugrView) -> Option { - h.children(h.root()).next() - } if r.is_err() { // Try to restore backup. h.replace_op(h.root(), backup.root_type().clone()) .expect("The root replacement should always match the old root type"); - while let Some(child) = first_child(h) { + while let Some(child) = h.first_child(h.root()) { h.remove_node(child); } h.insert_from_view(h.root(), &backup); diff --git a/hugr-core/src/hugr/views.rs b/hugr-core/src/hugr/views.rs index 7d744c150..442625e33 100644 --- a/hugr-core/src/hugr/views.rs +++ b/hugr-core/src/hugr/views.rs @@ -261,6 +261,12 @@ pub trait HugrView: HugrInternals { /// Return iterator over the direct children of node. fn children(&self, node: Node) -> impl DoubleEndedIterator + Clone; + /// Returns the first child of the specified node (if it is a parent). + /// Useful because `x.children().next()` leaves x borrowed. + fn first_child(&self, node: Node) -> Option { + self.children(node).next() + } + /// Iterates over neighbour nodes in the given direction. /// May contain duplicates if the graph has multiple links between nodes. fn neighbours(&self, node: Node, dir: Direction) -> impl Iterator + Clone;