From 2565f732123c3b96a48629d7c3bdeb886365bbc7 Mon Sep 17 00:00:00 2001 From: Agustin Borgna <agustin.borgna@quantinuum.com> Date: Tue, 27 Aug 2024 12:34:01 +0100 Subject: [PATCH 1/4] Bring the collections ext in line with other extension defs --- hugr-core/src/hugr/rewrite/replace.rs | 6 +- hugr-core/src/hugr/validate/test.rs | 2 +- hugr-core/src/std_extensions/collections.rs | 270 ++++++++++-------- .../std_extensions/collections/list_fold.rs | 49 ++++ hugr-passes/src/const_fold/test.rs | 4 +- 5 files changed, 207 insertions(+), 124 deletions(-) create mode 100644 hugr-core/src/std_extensions/collections/list_fold.rs diff --git a/hugr-core/src/hugr/rewrite/replace.rs b/hugr-core/src/hugr/rewrite/replace.rs index 837127f7c..9c8ce6731 100644 --- a/hugr-core/src/hugr/rewrite/replace.rs +++ b/hugr-core/src/hugr/rewrite/replace.rs @@ -524,21 +524,21 @@ mod test { inputs: vec![listy.clone()].into(), sum_rows: vec![type_row![]], other_outputs: vec![listy.clone()].into(), - extension_delta: collections::EXTENSION_NAME.into(), + extension_delta: collections::EXTENSION_ID.into(), }, ); let r_df1 = replacement.add_node_with_parent( r_bb, DFG { signature: Signature::new(vec![listy.clone()], simple_unary_plus(intermed.clone())) - .with_extension_delta(collections::EXTENSION_NAME), + .with_extension_delta(collections::EXTENSION_ID), }, ); let r_df2 = replacement.add_node_with_parent( r_bb, DFG { signature: Signature::new(intermed, simple_unary_plus(just_list.clone())) - .with_extension_delta(collections::EXTENSION_NAME), + .with_extension_delta(collections::EXTENSION_ID), }, ); [0, 1] diff --git a/hugr-core/src/hugr/validate/test.rs b/hugr-core/src/hugr/validate/test.rs index a5e901a3d..cc30ec7fc 100644 --- a/hugr-core/src/hugr/validate/test.rs +++ b/hugr-core/src/hugr/validate/test.rs @@ -547,7 +547,7 @@ fn no_polymorphic_consts() -> Result<(), Box<dyn std::error::Error>> { PolyFuncType::new( [BOUND], Signature::new(vec![], vec![list_of_var.clone()]) - .with_extension_delta(collections::EXTENSION_NAME), + .with_extension_delta(collections::EXTENSION_ID), ), )?; let empty_list = Value::extension(collections::ListValue::new_empty(Type::new_var_use( diff --git a/hugr-core/src/std_extensions/collections.rs b/hugr-core/src/std_extensions/collections.rs index 480780496..44c82b347 100644 --- a/hugr-core/src/std_extensions/collections.rs +++ b/hugr-core/src/std_extensions/collections.rs @@ -1,36 +1,37 @@ //! List type and operations. +mod list_fold; + +use std::str::FromStr; + use itertools::Itertools; use lazy_static::lazy_static; use serde::{Deserialize, Serialize}; +use strum_macros::{EnumIter, EnumString, IntoStaticStr}; +use crate::extension::simple_op::{MakeOpDef, MakeRegisteredOp}; +use crate::extension::{ExtensionBuildError, OpDef, SignatureFunc, PRELUDE}; use crate::ops::constant::ValueName; use crate::ops::{OpName, Value}; -use crate::types::TypeName; +use crate::types::{TypeName, TypeRowRV}; use crate::{ extension::{ simple_op::{MakeExtensionOp, OpLoadError}, - ConstFold, ExtensionId, ExtensionRegistry, ExtensionSet, SignatureError, TypeDef, - TypeDefBound, + ExtensionId, ExtensionRegistry, ExtensionSet, SignatureError, TypeDef, TypeDefBound, }, ops::constant::CustomConst, - ops::{self, custom::ExtensionOp, NamedOp}, + ops::{custom::ExtensionOp, NamedOp}, types::{ type_param::{TypeArg, TypeParam}, CustomCheckFailure, CustomType, FuncValueType, PolyFuncTypeRV, Type, TypeBound, }, - utils::sorted_consts, Extension, }; /// Reported unique name of the list type. pub const LIST_TYPENAME: TypeName = TypeName::new_inline("List"); -/// Pop operation name. -pub const POP_NAME: OpName = OpName::new_inline("pop"); -/// Push operation name. -pub const PUSH_NAME: OpName = OpName::new_inline("push"); /// Reported unique name of the extension -pub const EXTENSION_NAME: ExtensionId = ExtensionId::new_unchecked("collections"); +pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("collections"); /// Extension version. pub const VERSION: semver::Version = semver::Version::new(0, 1, 0); @@ -100,92 +101,159 @@ impl CustomConst for ListValue { fn extension_reqs(&self) -> ExtensionSet { ExtensionSet::union_over(self.0.iter().map(Value::extension_reqs)) - .union(EXTENSION_NAME.into()) + .union(EXTENSION_ID.into()) } } -struct PopFold; +/// A list operation +#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, EnumIter, IntoStaticStr, EnumString)] +#[allow(non_camel_case_types)] +#[non_exhaustive] +pub enum ListOp { + /// Pop from end of list + pop, + /// Push to end of list + push, +} + +impl ListOp { + /// Type parameter used in the list types. + const TP: TypeParam = TypeParam::Type { b: TypeBound::Any }; + + /// Instantiate a list operation with an `element_type` + pub fn with_type(self, element_type: Type) -> ListOpInst { + ListOpInst { + elem_type: element_type, + op: self, + } + } + + /// Compute the signature of the operation, given the list type definition. + fn compute_signature(self, list_type_def: &TypeDef) -> SignatureFunc { + use ListOp::*; + let e = self.elem_type(); + let l = self.list_type(list_type_def); + match self { + pop => self + .list_polytype(vec![l.clone()], vec![l.clone(), e.clone()]) + .into(), + push => self.list_polytype(vec![l.clone(), e], vec![l]).into(), + } + } -impl ConstFold for PopFold { - fn fold( - &self, - _type_args: &[TypeArg], - consts: &[(crate::IncomingPort, ops::Value)], - ) -> crate::extension::ConstFoldResult { - let [list]: [&ops::Value; 1] = sorted_consts(consts).try_into().ok()?; - let list: &ListValue = list.get_custom_value().expect("Should be list value."); - let mut list = list.clone(); - let elem = list.0.pop()?; // empty list fails to evaluate "pop" + /// Compute a polymorphic function type for a list operation. + fn list_polytype( + self, + input: impl Into<TypeRowRV>, + output: impl Into<TypeRowRV>, + ) -> PolyFuncTypeRV { + PolyFuncTypeRV::new(vec![Self::TP], FuncValueType::new(input, output)) + } + + /// Returns a generic unbounded type for a list element. + fn elem_type(self) -> Type { + Type::new_var_use(0, TypeBound::Any) + } - Some(vec![(0.into(), list.into()), (1.into(), elem)]) + /// Returns the type of a generic list. + fn list_type(self, list_type_def: &TypeDef) -> Type { + Type::new_extension( + list_type_def + .instantiate(vec![TypeArg::new_var_use(0, Self::TP)]) + .unwrap(), + ) } } -struct PushFold; +impl MakeOpDef for ListOp { + fn from_def(op_def: &OpDef) -> Result<Self, crate::extension::simple_op::OpLoadError> { + crate::extension::simple_op::try_from_name(op_def.name(), op_def.extension()) + } + + fn extension(&self) -> ExtensionId { + EXTENSION_ID.to_owned() + } + + /// Add an operation implemented as an [MakeOpDef], which can provide the data + /// required to define an [OpDef], to an extension. + // + // This method is re-defined here since we need to pass the list type def while computing the signature, + // to avoid recursive loops initializing the extension. + fn add_to_extension(&self, extension: &mut Extension) -> Result<(), ExtensionBuildError> { + let sig = self.compute_signature(extension.get_type(&LIST_TYPENAME).unwrap()); + let def = extension.add_op(self.name(), self.description(), sig)?; -impl ConstFold for PushFold { - fn fold( - &self, - _type_args: &[TypeArg], - consts: &[(crate::IncomingPort, ops::Value)], - ) -> crate::extension::ConstFoldResult { - let [list, elem]: [&ops::Value; 2] = sorted_consts(consts).try_into().ok()?; - let list: &ListValue = list.get_custom_value().expect("Should be list value."); - let mut list = list.clone(); - list.0.push(elem.clone()); + self.post_opdef(def); + + Ok(()) + } + + fn signature(&self) -> SignatureFunc { + self.compute_signature(list_type_def()) + } + + fn description(&self) -> String { + use ListOp::*; + + match self { + pop => "Pop from back of list", + push => "Push to back of list", + } + .into() + } - Some(vec![(0.into(), list.into())]) + fn post_opdef(&self, def: &mut OpDef) { + list_fold::set_fold(self, def) } } -const TP: TypeParam = TypeParam::Type { b: TypeBound::Any }; -fn extension() -> Extension { - let mut extension = Extension::new(EXTENSION_NAME, VERSION); +lazy_static! { + /// Extension for basic float operations. + pub static ref EXTENSION: Extension = { + println!("creating collections extension"); + let mut extension = Extension::new(EXTENSION_ID, VERSION); - extension - .add_type( + // The list type must be defined before the operations are added. + extension.add_type( LIST_TYPENAME, - vec![TP], + vec![ListOp::TP], "Generic dynamically sized list of type T.".into(), TypeDefBound::from_params(vec![0]), ) .unwrap(); - let list_type_def = extension.get_type(&LIST_TYPENAME).unwrap(); - - let (l, e) = list_and_elem_type_vars(list_type_def); - extension - .add_op( - POP_NAME, - "Pop from back of list".into(), - PolyFuncTypeRV::new( - vec![TP], - FuncValueType::new(vec![l.clone()], vec![l.clone(), e.clone()]), - ), - ) - .unwrap() - .set_constant_folder(PopFold); - extension - .add_op( - PUSH_NAME, - "Push to back of list".into(), - PolyFuncTypeRV::new(vec![TP], FuncValueType::new(vec![l.clone(), e], vec![l])), - ) - .unwrap() - .set_constant_folder(PushFold); - extension + ListOp::load_all_ops(&mut extension).unwrap(); + + extension + }; + + /// Registry of extensions required to validate float operations. + pub static ref COLLECTIONS_REGISTRY: ExtensionRegistry = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + EXTENSION.to_owned(), + ]) + .unwrap(); } -lazy_static! { - /// Collections extension definition. - pub static ref EXTENSION: Extension = extension(); +impl MakeRegisteredOp for ListOp { + fn extension_id(&self) -> ExtensionId { + EXTENSION_ID.to_owned() + } + + fn registry<'s, 'r: 's>(&'s self) -> &'r ExtensionRegistry { + &COLLECTIONS_REGISTRY + } +} + +/// Get the type of a list of `elem_type` as a `CustomType`. +pub fn list_type_def() -> &'static TypeDef { + // This must not be called while the extension is being built. + EXTENSION.get_type(&LIST_TYPENAME).unwrap() } /// Get the type of a list of `elem_type` as a `CustomType`. pub fn list_custom_type(elem_type: Type) -> CustomType { - EXTENSION - .get_type(&LIST_TYPENAME) - .unwrap() + list_type_def() .instantiate(vec![TypeArg::Type { ty: elem_type }]) .unwrap() } @@ -195,37 +263,9 @@ pub fn list_type(elem_type: Type) -> Type { list_custom_type(elem_type).into() } -fn list_and_elem_type_vars(list_type_def: &TypeDef) -> (Type, Type) { - let elem_type = Type::new_var_use(0, TypeBound::Any); - let list_type = Type::new_extension( - list_type_def - .instantiate(vec![TypeArg::new_var_use(0, TP)]) - .unwrap(), - ); - (list_type, elem_type) -} - -/// A list operation -#[derive(Debug, Clone, PartialEq)] -#[non_exhaustive] -pub enum ListOp { - /// Pop from end of list - Pop, - /// Push to end of list - Push, -} - -impl ListOp { - /// Instantiate a list operation with an `element_type` - pub fn with_type(self, element_type: Type) -> ListOpInst { - ListOpInst { - elem_type: element_type, - op: self, - } - } -} - /// A list operation with a concrete element type. +/// +/// See [ListOp] for the parametric version. #[derive(Debug, Clone, PartialEq)] pub struct ListOpInst { op: ListOp, @@ -234,10 +274,8 @@ pub struct ListOpInst { impl NamedOp for ListOpInst { fn name(&self) -> OpName { - match self.op { - ListOp::Pop => POP_NAME, - ListOp::Push => PUSH_NAME, - } + let name: &str = self.op.into(); + name.into() } } @@ -249,11 +287,8 @@ impl MakeExtensionOp for ListOpInst { return Err(SignatureError::InvalidTypeArgs.into()); }; let name = ext_op.def().name(); - let op = match name { - // can't use const SmolStr in pattern - _ if name == &POP_NAME => ListOp::Pop, - _ if name == &PUSH_NAME => ListOp::Push, - _ => return Err(OpLoadError::NotMember(name.to_string())), + let Ok(op) = ListOp::from_str(name) else { + return Err(OpLoadError::NotMember(name.to_string())); }; Ok(Self { @@ -283,7 +318,7 @@ impl ListOpInst { ) .unwrap(); ExtensionOp::new( - registry.get(&EXTENSION_NAME)?.get_op(&self.name())?.clone(), + registry.get(&EXTENSION_ID)?.get_op(&self.name())?.clone(), self.type_args(), ®istry, ) @@ -307,16 +342,15 @@ mod test { #[test] fn test_extension() { - let r: Extension = extension(); - assert_eq!(r.name(), &EXTENSION_NAME); - let ops = r.operations(); - assert_eq!(ops.count(), 2); + assert_eq!(EXTENSION.name(), &EXTENSION_ID); + for (_, op_def) in EXTENSION.operations() { + assert_eq!(op_def.extension(), &EXTENSION_ID); + } } #[test] fn test_list() { - let r: Extension = extension(); - let list_def = r.get_type(&LIST_TYPENAME).unwrap(); + let list_def = list_type_def(); let list_type = list_def .instantiate([TypeArg::Type { ty: USIZE_T }]) @@ -340,7 +374,7 @@ mod test { let reg = ExtensionRegistry::try_new([PRELUDE.to_owned(), float_types::EXTENSION.to_owned()]) .unwrap(); - let pop_op = ListOp::Pop.with_type(QB_T); + let pop_op = ListOp::pop.with_type(QB_T); let pop_ext = pop_op.clone().to_extension_op(®).unwrap(); assert_eq!(ListOpInst::from_extension_op(&pop_ext).unwrap(), pop_op); let pop_sig = pop_ext.dataflow_signature().unwrap(); @@ -352,7 +386,7 @@ mod test { assert_eq!(pop_sig.input(), &just_list_row); assert_eq!(pop_sig.output(), &both_row); - let push_op = ListOp::Push.with_type(FLOAT64_TYPE); + let push_op = ListOp::push.with_type(FLOAT64_TYPE); let push_ext = push_op.clone().to_extension_op(®).unwrap(); assert_eq!(ListOpInst::from_extension_op(&push_ext).unwrap(), push_op); let push_sig = push_ext.dataflow_signature().unwrap(); diff --git a/hugr-core/src/std_extensions/collections/list_fold.rs b/hugr-core/src/std_extensions/collections/list_fold.rs new file mode 100644 index 000000000..ca9b78c77 --- /dev/null +++ b/hugr-core/src/std_extensions/collections/list_fold.rs @@ -0,0 +1,49 @@ +//! Folding definitions for list operations. + +use crate::extension::{ConstFold, OpDef}; +use crate::ops; +use crate::types::type_param::TypeArg; +use crate::utils::sorted_consts; + +use super::{ListOp, ListValue}; + +pub(super) fn set_fold(op: &ListOp, def: &mut OpDef) { + match op { + ListOp::pop => def.set_constant_folder(PopFold), + ListOp::push => def.set_constant_folder(PushFold), + } +} + +pub struct PopFold; + +impl ConstFold for PopFold { + fn fold( + &self, + _type_args: &[TypeArg], + consts: &[(crate::IncomingPort, ops::Value)], + ) -> crate::extension::ConstFoldResult { + let [list]: [&ops::Value; 1] = sorted_consts(consts).try_into().ok()?; + let list: &ListValue = list.get_custom_value().expect("Should be list value."); + let mut list = list.clone(); + let elem = list.0.pop()?; // empty list fails to evaluate "pop" + + Some(vec![(0.into(), list.into()), (1.into(), elem)]) + } +} + +pub struct PushFold; + +impl ConstFold for PushFold { + fn fold( + &self, + _type_args: &[TypeArg], + consts: &[(crate::IncomingPort, ops::Value)], + ) -> crate::extension::ConstFoldResult { + let [list, elem]: [&ops::Value; 2] = sorted_consts(consts).try_into().ok()?; + let list: &ListValue = list.get_custom_value().expect("Should be list value."); + let mut list = list.clone(); + list.0.push(elem.clone()); + + Some(vec![(0.into(), list.into())]) + } +} diff --git a/hugr-passes/src/const_fold/test.rs b/hugr-passes/src/const_fold/test.rs index 28f43f31e..0898e93d5 100644 --- a/hugr-passes/src/const_fold/test.rs +++ b/hugr-passes/src/const_fold/test.rs @@ -143,12 +143,12 @@ fn test_list_ops() -> Result<(), Box<dyn std::error::Error>> { let list_wire = build.add_load_const(list.clone()); let pop = build.add_dataflow_op( - ListOp::Pop.with_type(BOOL_T).to_extension_op(®).unwrap(), + ListOp::pop.with_type(BOOL_T).to_extension_op(®).unwrap(), [list_wire], )?; let push = build.add_dataflow_op( - ListOp::Push + ListOp::push .with_type(BOOL_T) .to_extension_op(®) .unwrap(), From c6758450f9fccbecd4efa2f5238a287f3be40f53 Mon Sep 17 00:00:00 2001 From: Agustin Borgna <agustin.borgna@quantinuum.com> Date: Tue, 27 Aug 2024 14:30:31 +0100 Subject: [PATCH 2/4] fix docs --- hugr-core/src/std_extensions/collections.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/hugr-core/src/std_extensions/collections.rs b/hugr-core/src/std_extensions/collections.rs index 44c82b347..98f5741b0 100644 --- a/hugr-core/src/std_extensions/collections.rs +++ b/hugr-core/src/std_extensions/collections.rs @@ -208,7 +208,7 @@ impl MakeOpDef for ListOp { } lazy_static! { - /// Extension for basic float operations. + /// Extension for list operations. pub static ref EXTENSION: Extension = { println!("creating collections extension"); let mut extension = Extension::new(EXTENSION_ID, VERSION); @@ -227,7 +227,7 @@ lazy_static! { extension }; - /// Registry of extensions required to validate float operations. + /// Registry of extensions required to validate list operations. pub static ref COLLECTIONS_REGISTRY: ExtensionRegistry = ExtensionRegistry::try_new([ PRELUDE.to_owned(), EXTENSION.to_owned(), From e72e8dd0ecc10e14fd083164ac467f7abc65a2a6 Mon Sep 17 00:00:00 2001 From: Agustin Borgna <agustin.borgna@quantinuum.com> Date: Tue, 27 Aug 2024 15:49:44 +0100 Subject: [PATCH 3/4] Cleanup unneeded method --- hugr-core/src/std_extensions/collections.rs | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/hugr-core/src/std_extensions/collections.rs b/hugr-core/src/std_extensions/collections.rs index 98f5741b0..567ccc73b 100644 --- a/hugr-core/src/std_extensions/collections.rs +++ b/hugr-core/src/std_extensions/collections.rs @@ -131,8 +131,8 @@ impl ListOp { /// Compute the signature of the operation, given the list type definition. fn compute_signature(self, list_type_def: &TypeDef) -> SignatureFunc { use ListOp::*; - let e = self.elem_type(); - let l = self.list_type(list_type_def); + let e = Type::new_var_use(0, TypeBound::Any); + let l = self.list_type(list_type_def, 0); match self { pop => self .list_polytype(vec![l.clone()], vec![l.clone(), e.clone()]) @@ -150,16 +150,11 @@ impl ListOp { PolyFuncTypeRV::new(vec![Self::TP], FuncValueType::new(input, output)) } - /// Returns a generic unbounded type for a list element. - fn elem_type(self) -> Type { - Type::new_var_use(0, TypeBound::Any) - } - - /// Returns the type of a generic list. - fn list_type(self, list_type_def: &TypeDef) -> Type { + /// Returns the type of a generic list, associated with the element type parameter at index `idx`. + fn list_type(self, list_type_def: &TypeDef, idx: usize) -> Type { Type::new_extension( list_type_def - .instantiate(vec![TypeArg::new_var_use(0, Self::TP)]) + .instantiate(vec![TypeArg::new_var_use(idx, Self::TP)]) .unwrap(), ) } From 9c646b346b70cbf8694ed0af03dddfbac2fa49c1 Mon Sep 17 00:00:00 2001 From: Agustin Borgna <agustin.borgna@quantinuum.com> Date: Tue, 27 Aug 2024 16:42:22 +0100 Subject: [PATCH 4/4] Better test coverage --- hugr-core/src/std_extensions/collections.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/hugr-core/src/std_extensions/collections.rs b/hugr-core/src/std_extensions/collections.rs index 567ccc73b..67178c076 100644 --- a/hugr-core/src/std_extensions/collections.rs +++ b/hugr-core/src/std_extensions/collections.rs @@ -337,7 +337,9 @@ mod test { #[test] fn test_extension() { - assert_eq!(EXTENSION.name(), &EXTENSION_ID); + assert_eq!(&ListOp::push.extension_id(), EXTENSION.name()); + assert_eq!(&ListOp::push.extension(), EXTENSION.name()); + assert!(ListOp::pop.registry().contains(EXTENSION.name())); for (_, op_def) in EXTENSION.operations() { assert_eq!(op_def.extension(), &EXTENSION_ID); }