diff --git a/hugr-core/src/extension/op_def.rs b/hugr-core/src/extension/op_def.rs index 276bd7025..a40415be1 100644 --- a/hugr-core/src/extension/op_def.rs +++ b/hugr-core/src/extension/op_def.rs @@ -539,7 +539,7 @@ pub(super) mod test { use crate::extension::{ExtensionRegistry, ExtensionSet, PRELUDE}; use crate::extension::{SignatureError, EMPTY_REG, PRELUDE_REGISTRY}; use crate::ops::OpName; - use crate::std_extensions::collections::{EXTENSION, LIST_TYPENAME}; + use crate::std_extensions::collections::list; use crate::types::type_param::{TypeArgError, TypeParam}; use crate::types::{PolyFuncTypeRV, Signature, Type, TypeArg, TypeBound, TypeRV}; use crate::{const_extension_ids, Extension}; @@ -636,7 +636,7 @@ pub(super) mod test { #[test] fn op_def_with_type_scheme() -> Result<(), Box> { - let list_def = EXTENSION.get_type(&LIST_TYPENAME).unwrap(); + let list_def = list::EXTENSION.get_type(&list::LIST_TYPENAME).unwrap(); const OP_NAME: OpName = OpName::new_inline("Reverse"); let ext = Extension::try_new_test_arc(EXT_ID, |ext, extension_ref| { @@ -658,7 +658,7 @@ pub(super) mod test { Ok(()) })?; - let reg = ExtensionRegistry::new([PRELUDE.clone(), EXTENSION.clone(), ext]); + let reg = ExtensionRegistry::new([PRELUDE.clone(), list::EXTENSION.clone(), ext]); reg.validate()?; let e = reg.get(&EXT_ID).unwrap(); diff --git a/hugr-core/src/hugr/rewrite/replace.rs b/hugr-core/src/hugr/rewrite/replace.rs index f652e2170..2f3a82c04 100644 --- a/hugr-core/src/hugr/rewrite/replace.rs +++ b/hugr-core/src/hugr/rewrite/replace.rs @@ -456,7 +456,7 @@ mod test { use crate::ops::dataflow::DataflowOpTrait; use crate::ops::handle::{BasicBlockID, ConstID, NodeHandle}; use crate::ops::{self, Case, DataflowBlock, OpTag, OpType, DFG}; - use crate::std_extensions::collections::{self, list_type, ListOp}; + use crate::std_extensions::collections::list; use crate::types::{Signature, Type, TypeRow}; use crate::utils::{depth, test_quantum_extension}; use crate::{type_row, Direction, Extension, Hugr, HugrView, OutgoingPort}; @@ -466,14 +466,14 @@ mod test { #[test] #[ignore] // FIXME: This needs a rewrite now that `pop` returns an optional value -.-' fn cfg() -> Result<(), Box> { - let reg = ExtensionRegistry::new([PRELUDE.to_owned(), collections::EXTENSION.to_owned()]); + let reg = ExtensionRegistry::new([PRELUDE.to_owned(), list::EXTENSION.to_owned()]); reg.validate()?; - let listy = list_type(usize_t()); - let pop: ExtensionOp = ListOp::pop + let listy = list::list_type(usize_t()); + let pop: ExtensionOp = list::ListOp::pop .with_type(usize_t()) .to_extension_op(®) .unwrap(); - let push: ExtensionOp = ListOp::push + let push: ExtensionOp = list::ListOp::push .with_type(usize_t()) .to_extension_op(®) .unwrap(); @@ -518,21 +518,21 @@ mod test { inputs: vec![listy.clone()].into(), sum_rows: vec![type_row![]], other_outputs: vec![listy.clone()].into(), - extension_delta: collections::EXTENSION_ID.into(), + extension_delta: list::EXTENSION_ID.into(), }, ); let r_df1 = replacement.add_node_with_parent( r_bb, DFG { signature: Signature::new(vec![listy.clone()], simple_unary_plus(intermed.clone())) - .with_extension_delta(collections::EXTENSION_ID), + .with_extension_delta(list::EXTENSION_ID), }, ); let r_df2 = replacement.add_node_with_parent( r_bb, DFG { signature: Signature::new(intermed, simple_unary_plus(just_list.clone())) - .with_extension_delta(collections::EXTENSION_ID), + .with_extension_delta(list::EXTENSION_ID), }, ); [0, 1] diff --git a/hugr-core/src/hugr/validate/test.rs b/hugr-core/src/hugr/validate/test.rs index dd4497aa7..fea6b336e 100644 --- a/hugr-core/src/hugr/validate/test.rs +++ b/hugr-core/src/hugr/validate/test.rs @@ -555,27 +555,27 @@ fn nested_typevars() -> Result<(), Box> { #[test] fn no_polymorphic_consts() -> Result<(), Box> { - use crate::std_extensions::collections; + use crate::std_extensions::collections::list; const BOUND: TypeParam = TypeParam::Type { b: TypeBound::Copyable, }; let list_of_var = Type::new_extension( - collections::EXTENSION - .get_type(&collections::LIST_TYPENAME) + list::EXTENSION + .get_type(&list::LIST_TYPENAME) .unwrap() .instantiate(vec![TypeArg::new_var_use(0, BOUND)])?, ); - let reg = ExtensionRegistry::new([collections::EXTENSION.to_owned()]); + let reg = ExtensionRegistry::new([list::EXTENSION.to_owned()]); reg.validate()?; let mut def = FunctionBuilder::new( "myfunc", PolyFuncType::new( [BOUND], Signature::new(vec![], vec![list_of_var.clone()]) - .with_extension_delta(collections::EXTENSION_ID), + .with_extension_delta(list::EXTENSION_ID), ), )?; - let empty_list = Value::extension(collections::ListValue::new_empty(Type::new_var_use( + let empty_list = Value::extension(list::ListValue::new_empty(Type::new_var_use( 0, TypeBound::Copyable, ))); diff --git a/hugr-core/src/ops/constant.rs b/hugr-core/src/ops/constant.rs index 844ffb159..fe145d6a7 100644 --- a/hugr-core/src/ops/constant.rs +++ b/hugr-core/src/ops/constant.rs @@ -813,7 +813,7 @@ mod test { use crate::{ ops::{constant::CustomSerialized, Value}, std_extensions::arithmetic::int_types::ConstInt, - std_extensions::collections::ListValue, + std_extensions::collections::list::ListValue, types::{SumType, Type}, }; use ::proptest::{collection::vec, prelude::*}; diff --git a/hugr-core/src/ops/constant/custom.rs b/hugr-core/src/ops/constant/custom.rs index f9543cfef..4a59ab9b6 100644 --- a/hugr-core/src/ops/constant/custom.rs +++ b/hugr-core/src/ops/constant/custom.rs @@ -61,7 +61,7 @@ pub trait CustomConst: /// The extension(s) defining the custom constant /// (a set to allow, say, a [List] of [USize]) /// - /// [List]: crate::std_extensions::collections::LIST_TYPENAME + /// [List]: crate::std_extensions::collections::list::LIST_TYPENAME /// [USize]: crate::extension::prelude::usize_t fn extension_reqs(&self) -> ExtensionSet; @@ -362,7 +362,7 @@ mod test { use crate::{ extension::prelude::{usize_t, ConstUsize}, ops::{constant::custom::serialize_custom_const, Value}, - std_extensions::collections::ListValue, + std_extensions::collections::list::ListValue, }; use super::{super::OpaqueValue, CustomConst, CustomConstBoxClone, CustomSerialized}; diff --git a/hugr-core/src/std_extensions.rs b/hugr-core/src/std_extensions.rs index a2bc40ed0..2437b999c 100644 --- a/hugr-core/src/std_extensions.rs +++ b/hugr-core/src/std_extensions.rs @@ -18,7 +18,7 @@ pub fn std_reg() -> ExtensionRegistry { arithmetic::conversions::EXTENSION.to_owned(), arithmetic::float_ops::EXTENSION.to_owned(), arithmetic::float_types::EXTENSION.to_owned(), - collections::EXTENSION.to_owned(), + collections::list::EXTENSION.to_owned(), logic::EXTENSION.to_owned(), ptr::EXTENSION.to_owned(), ]); diff --git a/hugr-core/src/std_extensions/collections.rs b/hugr-core/src/std_extensions/collections.rs index 384c38a50..894c3a217 100644 --- a/hugr-core/src/std_extensions/collections.rs +++ b/hugr-core/src/std_extensions/collections.rs @@ -1,544 +1,3 @@ //! List type and operations. -use std::hash::{Hash, Hasher}; - -mod list_fold; - -use std::str::FromStr; -use std::sync::{Arc, Weak}; - -use itertools::Itertools; -use lazy_static::lazy_static; -use serde::{Deserialize, Serialize}; -use strum_macros::{EnumIter, EnumString, IntoStaticStr}; - -use crate::extension::prelude::{either_type, option_type, usize_t}; -use crate::extension::simple_op::{MakeOpDef, MakeRegisteredOp}; -use crate::extension::{ExtensionBuildError, OpDef, SignatureFunc, PRELUDE}; -use crate::ops::constant::{maybe_hash_values, TryHash, ValueName}; -use crate::ops::{OpName, Value}; -use crate::types::{TypeName, TypeRowRV}; -use crate::{ - extension::{ - simple_op::{MakeExtensionOp, OpLoadError}, - ExtensionId, ExtensionRegistry, ExtensionSet, SignatureError, TypeDef, TypeDefBound, - }, - ops::constant::CustomConst, - ops::{custom::ExtensionOp, NamedOp}, - types::{ - type_param::{TypeArg, TypeParam}, - CustomCheckFailure, CustomType, FuncValueType, PolyFuncTypeRV, Type, TypeBound, - }, - Extension, -}; - -/// Reported unique name of the list type. -pub const LIST_TYPENAME: TypeName = TypeName::new_inline("List"); -/// Reported unique name of the extension -pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("collections"); -/// Extension version. -pub const VERSION: semver::Version = semver::Version::new(0, 1, 0); - -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] -/// Dynamically sized list of values, all of the same type. -pub struct ListValue(Vec, Type); - -impl ListValue { - /// Create a new [CustomConst] for a list of values of type `typ`. - /// That all values ore of type `typ` is not checked here. - pub fn new(typ: Type, contents: impl IntoIterator) -> Self { - Self(contents.into_iter().collect_vec(), typ) - } - - /// Create a new [CustomConst] for an empty list of values of type `typ`. - pub fn new_empty(typ: Type) -> Self { - Self(vec![], typ) - } - - /// Returns the type of the `[ListValue]` as a `[CustomType]`.` - pub fn custom_type(&self) -> CustomType { - list_custom_type(self.1.clone()) - } - - /// Returns the type of values inside the `[ListValue]`. - pub fn get_element_type(&self) -> &Type { - &self.1 - } - - /// Returns the values contained inside the `[ListValue]`. - pub fn get_contents(&self) -> &[Value] { - &self.0 - } -} - -impl TryHash for ListValue { - fn try_hash(&self, mut st: &mut dyn Hasher) -> bool { - maybe_hash_values(&self.0, &mut st) && { - self.1.hash(&mut st); - true - } - } -} - -#[typetag::serde] -impl CustomConst for ListValue { - fn name(&self) -> ValueName { - ValueName::new_inline("list") - } - - fn get_type(&self) -> Type { - self.custom_type().into() - } - - fn validate(&self) -> Result<(), CustomCheckFailure> { - let typ = self.custom_type(); - let error = || { - // TODO more bespoke errors - CustomCheckFailure::Message("List type check fail.".to_string()) - }; - - EXTENSION - .get_type(&LIST_TYPENAME) - .unwrap() - .check_custom(&typ) - .map_err(|_| error())?; - - // constant can only hold classic type. - let [TypeArg::Type { ty }] = typ.args() else { - return Err(error()); - }; - - // check all values are instances of the element type - for v in &self.0 { - if v.get_type() != *ty { - return Err(error()); - } - } - - Ok(()) - } - - fn equal_consts(&self, other: &dyn CustomConst) -> bool { - crate::ops::constant::downcast_equal_consts(self, other) - } - - fn extension_reqs(&self) -> ExtensionSet { - ExtensionSet::union_over(self.0.iter().map(Value::extension_reqs)) - .union(EXTENSION_ID.into()) - } -} - -/// A list operation -#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, EnumIter, IntoStaticStr, EnumString)] -#[allow(non_camel_case_types)] -#[non_exhaustive] -pub enum ListOp { - /// Pop from the end of list. Return an optional value. - pop, - /// Push to end of list. Return the new list. - push, - /// Lookup an element in a list by index. - get, - /// Replace the element at index `i` with value `v`, and return the old value. - /// - /// If the index is out of bounds, returns the input value as an error. - set, - /// Insert an element at index `i`. - /// - /// Elements at higher indices are shifted one position to the right. - /// Returns an Err with the element if the index is out of bounds. - insert, - /// Get the length of a list. - length, -} - -impl ListOp { - /// Type parameter used in the list types. - const TP: TypeParam = TypeParam::Type { b: TypeBound::Any }; - - /// Instantiate a list operation with an `element_type`. - pub fn with_type(self, element_type: Type) -> ListOpInst { - ListOpInst { - elem_type: element_type, - op: self, - } - } - - /// Compute the signature of the operation, given the list type definition. - fn compute_signature(self, list_type_def: &TypeDef) -> SignatureFunc { - use ListOp::*; - let e = Type::new_var_use(0, TypeBound::Any); - let l = self.list_type(list_type_def, 0); - match self { - pop => self - .list_polytype(vec![l.clone()], vec![l, Type::from(option_type(e))]) - .into(), - push => self.list_polytype(vec![l.clone(), e], vec![l]).into(), - get => self - .list_polytype(vec![l, usize_t()], vec![Type::from(option_type(e))]) - .into(), - set => self - .list_polytype( - vec![l.clone(), usize_t(), e.clone()], - vec![l, Type::from(either_type(e.clone(), e))], - ) - .into(), - insert => self - .list_polytype( - vec![l.clone(), usize_t(), e.clone()], - vec![l, either_type(e, Type::UNIT).into()], - ) - .into(), - length => self - .list_polytype(vec![l.clone()], vec![l, usize_t()]) - .into(), - } - } - - /// Compute a polymorphic function type for a list operation. - fn list_polytype( - self, - input: impl Into, - output: impl Into, - ) -> PolyFuncTypeRV { - PolyFuncTypeRV::new(vec![Self::TP], FuncValueType::new(input, output)) - } - - /// Returns the type of a generic list, associated with the element type parameter at index `idx`. - fn list_type(self, list_type_def: &TypeDef, idx: usize) -> Type { - Type::new_extension( - list_type_def - .instantiate(vec![TypeArg::new_var_use(idx, Self::TP)]) - .unwrap(), - ) - } -} - -impl MakeOpDef for ListOp { - fn from_def(op_def: &OpDef) -> Result { - crate::extension::simple_op::try_from_name(op_def.name(), op_def.extension_id()) - } - - fn extension(&self) -> ExtensionId { - EXTENSION_ID.to_owned() - } - - fn extension_ref(&self) -> Weak { - Arc::downgrade(&EXTENSION) - } - - /// Add an operation implemented as an [MakeOpDef], which can provide the data - /// required to define an [OpDef], to an extension. - // - // This method is re-defined here since we need to pass the list type def while computing the signature, - // to avoid recursive loops initializing the extension. - fn add_to_extension( - &self, - extension: &mut Extension, - extension_ref: &Weak, - ) -> Result<(), ExtensionBuildError> { - let sig = self.compute_signature(extension.get_type(&LIST_TYPENAME).unwrap()); - let def = extension.add_op(self.name(), self.description(), sig, extension_ref)?; - - self.post_opdef(def); - - Ok(()) - } - - fn init_signature(&self, _extension_ref: &Weak) -> SignatureFunc { - self.compute_signature(list_type_def()) - } - - fn description(&self) -> String { - use ListOp::*; - - match self { - pop => "Pop from the back of list. Returns an optional value.", - push => "Push to the back of list", - get => "Lookup an element in a list by index. Panics if the index is out of bounds.", - set => "Replace the element at index `i` with value `v`.", - insert => "Insert an element at index `i`. Elements at higher indices are shifted one position to the right. Panics if the index is out of bounds.", - length => "Get the length of a list", - } - .into() - } - - fn post_opdef(&self, def: &mut OpDef) { - list_fold::set_fold(self, def) - } -} - -lazy_static! { - /// Extension for list operations. - pub static ref EXTENSION: Arc = { - Extension::new_arc(EXTENSION_ID, VERSION, |extension, extension_ref| { - extension.add_type( - LIST_TYPENAME, - vec![ListOp::TP], - "Generic dynamically sized list of type T.".into(), - TypeDefBound::from_params(vec![0]), - extension_ref - ) - .unwrap(); - - // The list type must be defined before the operations are added. - ListOp::load_all_ops(extension, extension_ref).unwrap(); - }) - }; - - /// Registry of extensions required to validate list operations. - pub static ref COLLECTIONS_REGISTRY: ExtensionRegistry = ExtensionRegistry::new([ - PRELUDE.clone(), - EXTENSION.clone(), - ]); -} - -impl MakeRegisteredOp for ListOp { - fn extension_id(&self) -> ExtensionId { - EXTENSION_ID.to_owned() - } - - fn registry<'s, 'r: 's>(&'s self) -> &'r ExtensionRegistry { - &COLLECTIONS_REGISTRY - } -} - -/// Get the type of a list of `elem_type` as a `CustomType`. -pub fn list_type_def() -> &'static TypeDef { - // This must not be called while the extension is being built. - EXTENSION.get_type(&LIST_TYPENAME).unwrap() -} - -/// Get the type of a list of `elem_type` as a `CustomType`. -pub fn list_custom_type(elem_type: Type) -> CustomType { - list_type_def() - .instantiate(vec![TypeArg::Type { ty: elem_type }]) - .unwrap() -} - -/// Get the `Type` of a list of `elem_type`. -pub fn list_type(elem_type: Type) -> Type { - list_custom_type(elem_type).into() -} - -/// A list operation with a concrete element type. -/// -/// See [ListOp] for the parametric version. -#[derive(Debug, Clone, PartialEq)] -pub struct ListOpInst { - op: ListOp, - elem_type: Type, -} - -impl NamedOp for ListOpInst { - fn name(&self) -> OpName { - let name: &str = self.op.into(); - name.into() - } -} - -impl MakeExtensionOp for ListOpInst { - fn from_extension_op( - ext_op: &ExtensionOp, - ) -> Result { - let [TypeArg::Type { ty }] = ext_op.args() else { - return Err(SignatureError::InvalidTypeArgs.into()); - }; - let name = ext_op.def().name(); - let Ok(op) = ListOp::from_str(name) else { - return Err(OpLoadError::NotMember(name.to_string())); - }; - - Ok(Self { - elem_type: ty.clone(), - op, - }) - } - - fn type_args(&self) -> Vec { - vec![TypeArg::Type { - ty: self.elem_type.clone(), - }] - } -} - -impl ListOpInst { - /// Convert this list operation to an [`ExtensionOp`] by providing a - /// registry to validate the element type against. - pub fn to_extension_op(self, elem_type_registry: &ExtensionRegistry) -> Option { - let registry = ExtensionRegistry::new( - elem_type_registry - .clone() - .into_iter() - // ignore self if already in registry - .filter(|ext| ext.name() != EXTENSION.name()) - .chain(std::iter::once(EXTENSION.to_owned())), - ); - ExtensionOp::new( - registry.get(&EXTENSION_ID)?.get_op(&self.name())?.clone(), - self.type_args(), - ®istry, - ) - .ok() - } -} - -#[cfg(test)] -mod test { - use rstest::rstest; - - use crate::extension::prelude::{ - const_fail_tuple, const_none, const_ok_tuple, const_some_tuple, - }; - use crate::ops::OpTrait; - use crate::PortIndex; - use crate::{ - extension::{ - prelude::{qb_t, usize_t, ConstUsize}, - PRELUDE, - }, - std_extensions::arithmetic::float_types::{self, float64_type, ConstF64}, - types::TypeRow, - }; - - use super::*; - - #[test] - fn test_extension() { - assert_eq!(&ListOp::push.extension_id(), EXTENSION.name()); - assert_eq!(&ListOp::push.extension(), EXTENSION.name()); - assert!(ListOp::pop.registry().contains(EXTENSION.name())); - for (_, op_def) in EXTENSION.operations() { - assert_eq!(op_def.extension_id(), &EXTENSION_ID); - } - } - - #[test] - fn test_list() { - let list_def = list_type_def(); - - let list_type = list_def - .instantiate([TypeArg::Type { ty: usize_t() }]) - .unwrap(); - - assert!(list_def - .instantiate([TypeArg::BoundedNat { n: 3 }]) - .is_err()); - - list_def.check_custom(&list_type).unwrap(); - let list_value = ListValue(vec![ConstUsize::new(3).into()], usize_t()); - - list_value.validate().unwrap(); - - let wrong_list_value = ListValue(vec![ConstF64::new(1.2).into()], usize_t()); - assert!(wrong_list_value.validate().is_err()); - } - - #[test] - fn test_list_ops() { - let reg = ExtensionRegistry::new([PRELUDE.to_owned(), float_types::EXTENSION.to_owned()]); - let pop_op = ListOp::pop.with_type(qb_t()); - let pop_ext = pop_op.clone().to_extension_op(®).unwrap(); - assert_eq!(ListOpInst::from_extension_op(&pop_ext).unwrap(), pop_op); - let pop_sig = pop_ext.dataflow_signature().unwrap(); - - let list_t = list_type(qb_t()); - - let both_row: TypeRow = vec![list_t.clone(), option_type(qb_t()).into()].into(); - let just_list_row: TypeRow = vec![list_t].into(); - assert_eq!(pop_sig.input(), &just_list_row); - assert_eq!(pop_sig.output(), &both_row); - - let push_op = ListOp::push.with_type(float64_type()); - let push_ext = push_op.clone().to_extension_op(®).unwrap(); - assert_eq!(ListOpInst::from_extension_op(&push_ext).unwrap(), push_op); - let push_sig = push_ext.dataflow_signature().unwrap(); - - let list_t = list_type(float64_type()); - - let both_row: TypeRow = vec![list_t.clone(), float64_type()].into(); - let just_list_row: TypeRow = vec![list_t].into(); - - assert_eq!(push_sig.input(), &both_row); - assert_eq!(push_sig.output(), &just_list_row); - } - - /// Values used in the `list_fold` test cases. - #[derive(Debug, Clone, PartialEq, Eq)] - enum TestVal { - Idx(usize), - List(Vec), - Elem(usize), - Some(Vec), - None(TypeRow), - Ok(Vec, TypeRow), - Err(TypeRow, Vec), - } - - impl TestVal { - fn to_value(&self) -> Value { - match self { - TestVal::Idx(i) => Value::extension(ConstUsize::new(*i as u64)), - TestVal::Elem(e) => Value::extension(ConstUsize::new(*e as u64)), - TestVal::List(l) => { - let elems = l - .iter() - .map(|&i| Value::extension(ConstUsize::new(i as u64))) - .collect(); - Value::extension(ListValue(elems, usize_t())) - } - TestVal::Some(l) => { - let elems = l.iter().map(TestVal::to_value); - const_some_tuple(elems) - } - TestVal::None(tr) => const_none(tr.clone()), - TestVal::Ok(l, tr) => { - let elems = l.iter().map(TestVal::to_value); - const_ok_tuple(elems, tr.clone()) - } - TestVal::Err(tr, l) => { - let elems = l.iter().map(TestVal::to_value); - const_fail_tuple(elems, tr.clone()) - } - } - } - } - - #[rstest] - #[case::pop(ListOp::pop, &[TestVal::List(vec![77,88, 42])], &[TestVal::List(vec![77,88]), TestVal::Some(vec![TestVal::Elem(42)])])] - #[case::pop_empty(ListOp::pop, &[TestVal::List(vec![])], &[TestVal::List(vec![]), TestVal::None(vec![usize_t()].into())])] - #[case::push(ListOp::push, &[TestVal::List(vec![77,88]), TestVal::Elem(42)], &[TestVal::List(vec![77,88,42])])] - #[case::set(ListOp::set, &[TestVal::List(vec![77,88,42]), TestVal::Idx(1), TestVal::Elem(99)], &[TestVal::List(vec![77,99,42]), TestVal::Ok(vec![TestVal::Elem(88)], vec![usize_t()].into())])] - #[case::set_invalid(ListOp::set, &[TestVal::List(vec![77,88,42]), TestVal::Idx(123), TestVal::Elem(99)], &[TestVal::List(vec![77,88,42]), TestVal::Err(vec![usize_t()].into(), vec![TestVal::Elem(99)])])] - #[case::get(ListOp::get, &[TestVal::List(vec![77,88,42]), TestVal::Idx(1)], &[TestVal::Some(vec![TestVal::Elem(88)])])] - #[case::get_invalid(ListOp::get, &[TestVal::List(vec![77,88,42]), TestVal::Idx(99)], &[TestVal::None(vec![usize_t()].into())])] - #[case::insert(ListOp::insert, &[TestVal::List(vec![77,88,42]), TestVal::Idx(1), TestVal::Elem(99)], &[TestVal::List(vec![77,99,88,42]), TestVal::Ok(vec![], vec![usize_t()].into())])] - #[case::insert_invalid(ListOp::insert, &[TestVal::List(vec![77,88,42]), TestVal::Idx(52), TestVal::Elem(99)], &[TestVal::List(vec![77,88,42]), TestVal::Err(Type::UNIT.into(), vec![TestVal::Elem(99)])])] - #[case::length(ListOp::length, &[TestVal::List(vec![77,88,42])], &[TestVal::Elem(3)])] - fn list_fold(#[case] op: ListOp, #[case] inputs: &[TestVal], #[case] outputs: &[TestVal]) { - let consts: Vec<_> = inputs - .iter() - .enumerate() - .map(|(i, x)| (i.into(), x.to_value())) - .collect(); - - let res = op - .with_type(usize_t()) - .to_extension_op(&COLLECTIONS_REGISTRY) - .unwrap() - .constant_fold(&consts) - .unwrap(); - - for (i, expected) in outputs.iter().enumerate() { - let expected = expected.to_value(); - let res_val = res - .iter() - .find(|(port, _)| port.index() == i) - .unwrap() - .1 - .clone(); - - assert_eq!(res_val, expected); - } - } -} +pub mod list; diff --git a/hugr-core/src/std_extensions/collections/list.rs b/hugr-core/src/std_extensions/collections/list.rs new file mode 100644 index 000000000..e72c6f6a6 --- /dev/null +++ b/hugr-core/src/std_extensions/collections/list.rs @@ -0,0 +1,544 @@ +//! List type and operations. + +mod list_fold; + +use std::hash::{Hash, Hasher}; + +use std::str::FromStr; +use std::sync::{Arc, Weak}; + +use itertools::Itertools; +use lazy_static::lazy_static; +use serde::{Deserialize, Serialize}; +use strum_macros::{EnumIter, EnumString, IntoStaticStr}; + +use crate::extension::prelude::{either_type, option_type, usize_t}; +use crate::extension::simple_op::{MakeOpDef, MakeRegisteredOp}; +use crate::extension::{ExtensionBuildError, OpDef, SignatureFunc, PRELUDE}; +use crate::ops::constant::{maybe_hash_values, TryHash, ValueName}; +use crate::ops::{OpName, Value}; +use crate::types::{TypeName, TypeRowRV}; +use crate::{ + extension::{ + simple_op::{MakeExtensionOp, OpLoadError}, + ExtensionId, ExtensionRegistry, ExtensionSet, SignatureError, TypeDef, TypeDefBound, + }, + ops::constant::CustomConst, + ops::{custom::ExtensionOp, NamedOp}, + types::{ + type_param::{TypeArg, TypeParam}, + CustomCheckFailure, CustomType, FuncValueType, PolyFuncTypeRV, Type, TypeBound, + }, + Extension, +}; + +/// Reported unique name of the list type. +pub const LIST_TYPENAME: TypeName = TypeName::new_inline("List"); +/// Reported unique name of the extension +pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("collections.list"); +/// Extension version. +pub const VERSION: semver::Version = semver::Version::new(0, 1, 0); + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +/// Dynamically sized list of values, all of the same type. +pub struct ListValue(Vec, Type); + +impl ListValue { + /// Create a new [CustomConst] for a list of values of type `typ`. + /// That all values ore of type `typ` is not checked here. + pub fn new(typ: Type, contents: impl IntoIterator) -> Self { + Self(contents.into_iter().collect_vec(), typ) + } + + /// Create a new [CustomConst] for an empty list of values of type `typ`. + pub fn new_empty(typ: Type) -> Self { + Self(vec![], typ) + } + + /// Returns the type of the `[ListValue]` as a `[CustomType]`.` + pub fn custom_type(&self) -> CustomType { + list_custom_type(self.1.clone()) + } + + /// Returns the type of values inside the `[ListValue]`. + pub fn get_element_type(&self) -> &Type { + &self.1 + } + + /// Returns the values contained inside the `[ListValue]`. + pub fn get_contents(&self) -> &[Value] { + &self.0 + } +} + +impl TryHash for ListValue { + fn try_hash(&self, mut st: &mut dyn Hasher) -> bool { + maybe_hash_values(&self.0, &mut st) && { + self.1.hash(&mut st); + true + } + } +} + +#[typetag::serde] +impl CustomConst for ListValue { + fn name(&self) -> ValueName { + ValueName::new_inline("list") + } + + fn get_type(&self) -> Type { + self.custom_type().into() + } + + fn validate(&self) -> Result<(), CustomCheckFailure> { + let typ = self.custom_type(); + let error = || { + // TODO more bespoke errors + CustomCheckFailure::Message("List type check fail.".to_string()) + }; + + EXTENSION + .get_type(&LIST_TYPENAME) + .unwrap() + .check_custom(&typ) + .map_err(|_| error())?; + + // constant can only hold classic type. + let [TypeArg::Type { ty }] = typ.args() else { + return Err(error()); + }; + + // check all values are instances of the element type + for v in &self.0 { + if v.get_type() != *ty { + return Err(error()); + } + } + + Ok(()) + } + + fn equal_consts(&self, other: &dyn CustomConst) -> bool { + crate::ops::constant::downcast_equal_consts(self, other) + } + + fn extension_reqs(&self) -> ExtensionSet { + ExtensionSet::union_over(self.0.iter().map(Value::extension_reqs)) + .union(EXTENSION_ID.into()) + } +} + +/// A list operation +#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, EnumIter, IntoStaticStr, EnumString)] +#[allow(non_camel_case_types)] +#[non_exhaustive] +pub enum ListOp { + /// Pop from the end of list. Return an optional value. + pop, + /// Push to end of list. Return the new list. + push, + /// Lookup an element in a list by index. + get, + /// Replace the element at index `i` with value `v`, and return the old value. + /// + /// If the index is out of bounds, returns the input value as an error. + set, + /// Insert an element at index `i`. + /// + /// Elements at higher indices are shifted one position to the right. + /// Returns an Err with the element if the index is out of bounds. + insert, + /// Get the length of a list. + length, +} + +impl ListOp { + /// Type parameter used in the list types. + const TP: TypeParam = TypeParam::Type { b: TypeBound::Any }; + + /// Instantiate a list operation with an `element_type`. + pub fn with_type(self, element_type: Type) -> ListOpInst { + ListOpInst { + elem_type: element_type, + op: self, + } + } + + /// Compute the signature of the operation, given the list type definition. + fn compute_signature(self, list_type_def: &TypeDef) -> SignatureFunc { + use ListOp::*; + let e = Type::new_var_use(0, TypeBound::Any); + let l = self.list_type(list_type_def, 0); + match self { + pop => self + .list_polytype(vec![l.clone()], vec![l, Type::from(option_type(e))]) + .into(), + push => self.list_polytype(vec![l.clone(), e], vec![l]).into(), + get => self + .list_polytype(vec![l, usize_t()], vec![Type::from(option_type(e))]) + .into(), + set => self + .list_polytype( + vec![l.clone(), usize_t(), e.clone()], + vec![l, Type::from(either_type(e.clone(), e))], + ) + .into(), + insert => self + .list_polytype( + vec![l.clone(), usize_t(), e.clone()], + vec![l, either_type(e, Type::UNIT).into()], + ) + .into(), + length => self + .list_polytype(vec![l.clone()], vec![l, usize_t()]) + .into(), + } + } + + /// Compute a polymorphic function type for a list operation. + fn list_polytype( + self, + input: impl Into, + output: impl Into, + ) -> PolyFuncTypeRV { + PolyFuncTypeRV::new(vec![Self::TP], FuncValueType::new(input, output)) + } + + /// Returns the type of a generic list, associated with the element type parameter at index `idx`. + fn list_type(self, list_type_def: &TypeDef, idx: usize) -> Type { + Type::new_extension( + list_type_def + .instantiate(vec![TypeArg::new_var_use(idx, Self::TP)]) + .unwrap(), + ) + } +} + +impl MakeOpDef for ListOp { + fn from_def(op_def: &OpDef) -> Result { + crate::extension::simple_op::try_from_name(op_def.name(), op_def.extension_id()) + } + + fn extension(&self) -> ExtensionId { + EXTENSION_ID.to_owned() + } + + fn extension_ref(&self) -> Weak { + Arc::downgrade(&EXTENSION) + } + + /// Add an operation implemented as an [MakeOpDef], which can provide the data + /// required to define an [OpDef], to an extension. + // + // This method is re-defined here since we need to pass the list type def while computing the signature, + // to avoid recursive loops initializing the extension. + fn add_to_extension( + &self, + extension: &mut Extension, + extension_ref: &Weak, + ) -> Result<(), ExtensionBuildError> { + let sig = self.compute_signature(extension.get_type(&LIST_TYPENAME).unwrap()); + let def = extension.add_op(self.name(), self.description(), sig, extension_ref)?; + + self.post_opdef(def); + + Ok(()) + } + + fn init_signature(&self, _extension_ref: &Weak) -> SignatureFunc { + self.compute_signature(list_type_def()) + } + + fn description(&self) -> String { + use ListOp::*; + + match self { + pop => "Pop from the back of list. Returns an optional value.", + push => "Push to the back of list", + get => "Lookup an element in a list by index. Panics if the index is out of bounds.", + set => "Replace the element at index `i` with value `v`.", + insert => "Insert an element at index `i`. Elements at higher indices are shifted one position to the right. Panics if the index is out of bounds.", + length => "Get the length of a list", + } + .into() + } + + fn post_opdef(&self, def: &mut OpDef) { + list_fold::set_fold(self, def) + } +} + +lazy_static! { + /// Extension for list operations. + pub static ref EXTENSION: Arc = { + Extension::new_arc(EXTENSION_ID, VERSION, |extension, extension_ref| { + extension.add_type( + LIST_TYPENAME, + vec![ListOp::TP], + "Generic dynamically sized list of type T.".into(), + TypeDefBound::from_params(vec![0]), + extension_ref + ) + .unwrap(); + + // The list type must be defined before the operations are added. + ListOp::load_all_ops(extension, extension_ref).unwrap(); + }) + }; + + /// Registry of extensions required to validate list operations. + pub static ref LIST_REGISTRY: ExtensionRegistry = ExtensionRegistry::new([ + PRELUDE.clone(), + EXTENSION.clone(), + ]); +} + +impl MakeRegisteredOp for ListOp { + fn extension_id(&self) -> ExtensionId { + EXTENSION_ID.to_owned() + } + + fn registry<'s, 'r: 's>(&'s self) -> &'r ExtensionRegistry { + &LIST_REGISTRY + } +} + +/// Get the type of a list of `elem_type` as a `CustomType`. +pub fn list_type_def() -> &'static TypeDef { + // This must not be called while the extension is being built. + EXTENSION.get_type(&LIST_TYPENAME).unwrap() +} + +/// Get the type of a list of `elem_type` as a `CustomType`. +pub fn list_custom_type(elem_type: Type) -> CustomType { + list_type_def() + .instantiate(vec![TypeArg::Type { ty: elem_type }]) + .unwrap() +} + +/// Get the `Type` of a list of `elem_type`. +pub fn list_type(elem_type: Type) -> Type { + list_custom_type(elem_type).into() +} + +/// A list operation with a concrete element type. +/// +/// See [ListOp] for the parametric version. +#[derive(Debug, Clone, PartialEq)] +pub struct ListOpInst { + op: ListOp, + elem_type: Type, +} + +impl NamedOp for ListOpInst { + fn name(&self) -> OpName { + let name: &str = self.op.into(); + name.into() + } +} + +impl MakeExtensionOp for ListOpInst { + fn from_extension_op( + ext_op: &ExtensionOp, + ) -> Result { + let [TypeArg::Type { ty }] = ext_op.args() else { + return Err(SignatureError::InvalidTypeArgs.into()); + }; + let name = ext_op.def().name(); + let Ok(op) = ListOp::from_str(name) else { + return Err(OpLoadError::NotMember(name.to_string())); + }; + + Ok(Self { + elem_type: ty.clone(), + op, + }) + } + + fn type_args(&self) -> Vec { + vec![TypeArg::Type { + ty: self.elem_type.clone(), + }] + } +} + +impl ListOpInst { + /// Convert this list operation to an [`ExtensionOp`] by providing a + /// registry to validate the element type against. + pub fn to_extension_op(self, elem_type_registry: &ExtensionRegistry) -> Option { + let registry = ExtensionRegistry::new( + elem_type_registry + .clone() + .into_iter() + // ignore self if already in registry + .filter(|ext| ext.name() != EXTENSION.name()) + .chain(std::iter::once(EXTENSION.to_owned())), + ); + ExtensionOp::new( + registry.get(&EXTENSION_ID)?.get_op(&self.name())?.clone(), + self.type_args(), + ®istry, + ) + .ok() + } +} + +#[cfg(test)] +mod test { + use rstest::rstest; + + use crate::extension::prelude::{ + const_fail_tuple, const_none, const_ok_tuple, const_some_tuple, + }; + use crate::ops::OpTrait; + use crate::PortIndex; + use crate::{ + extension::{ + prelude::{qb_t, usize_t, ConstUsize}, + PRELUDE, + }, + std_extensions::arithmetic::float_types::{self, float64_type, ConstF64}, + types::TypeRow, + }; + + use super::*; + + #[test] + fn test_extension() { + assert_eq!(&ListOp::push.extension_id(), EXTENSION.name()); + assert_eq!(&ListOp::push.extension(), EXTENSION.name()); + assert!(ListOp::pop.registry().contains(EXTENSION.name())); + for (_, op_def) in EXTENSION.operations() { + assert_eq!(op_def.extension_id(), &EXTENSION_ID); + } + } + + #[test] + fn test_list() { + let list_def = list_type_def(); + + let list_type = list_def + .instantiate([TypeArg::Type { ty: usize_t() }]) + .unwrap(); + + assert!(list_def + .instantiate([TypeArg::BoundedNat { n: 3 }]) + .is_err()); + + list_def.check_custom(&list_type).unwrap(); + let list_value = ListValue(vec![ConstUsize::new(3).into()], usize_t()); + + list_value.validate().unwrap(); + + let wrong_list_value = ListValue(vec![ConstF64::new(1.2).into()], usize_t()); + assert!(wrong_list_value.validate().is_err()); + } + + #[test] + fn test_list_ops() { + let reg = ExtensionRegistry::new([PRELUDE.to_owned(), float_types::EXTENSION.to_owned()]); + let pop_op = ListOp::pop.with_type(qb_t()); + let pop_ext = pop_op.clone().to_extension_op(®).unwrap(); + assert_eq!(ListOpInst::from_extension_op(&pop_ext).unwrap(), pop_op); + let pop_sig = pop_ext.dataflow_signature().unwrap(); + + let list_t = list_type(qb_t()); + + let both_row: TypeRow = vec![list_t.clone(), option_type(qb_t()).into()].into(); + let just_list_row: TypeRow = vec![list_t].into(); + assert_eq!(pop_sig.input(), &just_list_row); + assert_eq!(pop_sig.output(), &both_row); + + let push_op = ListOp::push.with_type(float64_type()); + let push_ext = push_op.clone().to_extension_op(®).unwrap(); + assert_eq!(ListOpInst::from_extension_op(&push_ext).unwrap(), push_op); + let push_sig = push_ext.dataflow_signature().unwrap(); + + let list_t = list_type(float64_type()); + + let both_row: TypeRow = vec![list_t.clone(), float64_type()].into(); + let just_list_row: TypeRow = vec![list_t].into(); + + assert_eq!(push_sig.input(), &both_row); + assert_eq!(push_sig.output(), &just_list_row); + } + + /// Values used in the `list_fold` test cases. + #[derive(Debug, Clone, PartialEq, Eq)] + enum TestVal { + Idx(usize), + List(Vec), + Elem(usize), + Some(Vec), + None(TypeRow), + Ok(Vec, TypeRow), + Err(TypeRow, Vec), + } + + impl TestVal { + fn to_value(&self) -> Value { + match self { + TestVal::Idx(i) => Value::extension(ConstUsize::new(*i as u64)), + TestVal::Elem(e) => Value::extension(ConstUsize::new(*e as u64)), + TestVal::List(l) => { + let elems = l + .iter() + .map(|&i| Value::extension(ConstUsize::new(i as u64))) + .collect(); + Value::extension(ListValue(elems, usize_t())) + } + TestVal::Some(l) => { + let elems = l.iter().map(TestVal::to_value); + const_some_tuple(elems) + } + TestVal::None(tr) => const_none(tr.clone()), + TestVal::Ok(l, tr) => { + let elems = l.iter().map(TestVal::to_value); + const_ok_tuple(elems, tr.clone()) + } + TestVal::Err(tr, l) => { + let elems = l.iter().map(TestVal::to_value); + const_fail_tuple(elems, tr.clone()) + } + } + } + } + + #[rstest] + #[case::pop(ListOp::pop, &[TestVal::List(vec![77,88, 42])], &[TestVal::List(vec![77,88]), TestVal::Some(vec![TestVal::Elem(42)])])] + #[case::pop_empty(ListOp::pop, &[TestVal::List(vec![])], &[TestVal::List(vec![]), TestVal::None(vec![usize_t()].into())])] + #[case::push(ListOp::push, &[TestVal::List(vec![77,88]), TestVal::Elem(42)], &[TestVal::List(vec![77,88,42])])] + #[case::set(ListOp::set, &[TestVal::List(vec![77,88,42]), TestVal::Idx(1), TestVal::Elem(99)], &[TestVal::List(vec![77,99,42]), TestVal::Ok(vec![TestVal::Elem(88)], vec![usize_t()].into())])] + #[case::set_invalid(ListOp::set, &[TestVal::List(vec![77,88,42]), TestVal::Idx(123), TestVal::Elem(99)], &[TestVal::List(vec![77,88,42]), TestVal::Err(vec![usize_t()].into(), vec![TestVal::Elem(99)])])] + #[case::get(ListOp::get, &[TestVal::List(vec![77,88,42]), TestVal::Idx(1)], &[TestVal::Some(vec![TestVal::Elem(88)])])] + #[case::get_invalid(ListOp::get, &[TestVal::List(vec![77,88,42]), TestVal::Idx(99)], &[TestVal::None(vec![usize_t()].into())])] + #[case::insert(ListOp::insert, &[TestVal::List(vec![77,88,42]), TestVal::Idx(1), TestVal::Elem(99)], &[TestVal::List(vec![77,99,88,42]), TestVal::Ok(vec![], vec![usize_t()].into())])] + #[case::insert_invalid(ListOp::insert, &[TestVal::List(vec![77,88,42]), TestVal::Idx(52), TestVal::Elem(99)], &[TestVal::List(vec![77,88,42]), TestVal::Err(Type::UNIT.into(), vec![TestVal::Elem(99)])])] + #[case::length(ListOp::length, &[TestVal::List(vec![77,88,42])], &[TestVal::Elem(3)])] + fn list_fold(#[case] op: ListOp, #[case] inputs: &[TestVal], #[case] outputs: &[TestVal]) { + let consts: Vec<_> = inputs + .iter() + .enumerate() + .map(|(i, x)| (i.into(), x.to_value())) + .collect(); + + let res = op + .with_type(usize_t()) + .to_extension_op(&LIST_REGISTRY) + .unwrap() + .constant_fold(&consts) + .unwrap(); + + for (i, expected) in outputs.iter().enumerate() { + let expected = expected.to_value(); + let res_val = res + .iter() + .find(|(port, _)| port.index() == i) + .unwrap() + .1 + .clone(); + + assert_eq!(res_val, expected); + } + } +} diff --git a/hugr-core/src/std_extensions/collections/list_fold.rs b/hugr-core/src/std_extensions/collections/list/list_fold.rs similarity index 100% rename from hugr-core/src/std_extensions/collections/list_fold.rs rename to hugr-core/src/std_extensions/collections/list/list_fold.rs diff --git a/hugr-core/src/types/poly_func.rs b/hugr-core/src/types/poly_func.rs index 66dff8245..9122c1846 100644 --- a/hugr-core/src/types/poly_func.rs +++ b/hugr-core/src/types/poly_func.rs @@ -165,7 +165,7 @@ pub(crate) mod test { ExtensionId, ExtensionRegistry, SignatureError, TypeDefBound, EMPTY_REG, PRELUDE, PRELUDE_REGISTRY, }; - use crate::std_extensions::collections::{EXTENSION, LIST_TYPENAME}; + use crate::std_extensions::collections::list; use crate::types::signature::FuncTypeBase; use crate::types::type_param::{TypeArg, TypeArgError, TypeParam}; use crate::types::{ @@ -177,7 +177,7 @@ pub(crate) mod test { lazy_static! { static ref REGISTRY: ExtensionRegistry = - ExtensionRegistry::new([PRELUDE.to_owned(), EXTENSION.to_owned()]); + ExtensionRegistry::new([PRELUDE.to_owned(), list::EXTENSION.to_owned()]); } impl PolyFuncTypeBase { @@ -194,7 +194,7 @@ pub(crate) mod test { #[test] fn test_opaque() -> Result<(), SignatureError> { - let list_def = EXTENSION.get_type(&LIST_TYPENAME).unwrap(); + let list_def = list::EXTENSION.get_type(&list::LIST_TYPENAME).unwrap(); let tyvar = TypeArg::new_var_use(0, TypeBound::Any.into()); let list_of_var = Type::new_extension(list_def.instantiate([tyvar.clone()])?); let list_len = PolyFuncTypeBase::new_validated( @@ -291,7 +291,7 @@ pub(crate) mod test { fn test_misused_variables() -> Result<(), SignatureError> { // Variables in args have different bounds from variable declaration let tv = TypeArg::new_var_use(0, TypeBound::Copyable.into()); - let list_def = EXTENSION.get_type(&LIST_TYPENAME).unwrap(); + let list_def = list::EXTENSION.get_type(&list::LIST_TYPENAME).unwrap(); let body_type = Signature::new_endo(Type::new_extension(list_def.instantiate([tv])?)); for decl in [ TypeParam::Extensions, diff --git a/hugr-llvm/src/emit/test.rs b/hugr-llvm/src/emit/test.rs index 8bf8856c1..cdf3e024c 100644 --- a/hugr-llvm/src/emit/test.rs +++ b/hugr-llvm/src/emit/test.rs @@ -153,7 +153,7 @@ impl SimpleHugrConfig { float_ops::EXTENSION_ID, conversions::EXTENSION_ID, logic::EXTENSION_ID, - collections::EXTENSION_ID, + collections::list::EXTENSION_ID, ]), ), ) diff --git a/hugr-llvm/src/extension/collections.rs b/hugr-llvm/src/extension/collections.rs index 5a60c6f54..796c4ed4b 100644 --- a/hugr-llvm/src/extension/collections.rs +++ b/hugr-llvm/src/extension/collections.rs @@ -1,445 +1,3 @@ -use anyhow::{bail, Ok, Result}; -use hugr_core::{ - ops::{ExtensionOp, NamedOp}, - std_extensions::collections::{self, ListOp, ListValue}, - types::{SumType, Type, TypeArg}, - HugrView, -}; -use inkwell::values::FunctionValue; -use inkwell::{ - types::{BasicType, BasicTypeEnum, FunctionType}, - values::{BasicValueEnum, PointerValue}, - AddressSpace, -}; +//! Emission logic for collections. -use crate::emit::func::{build_ok_or_else, build_option}; -use crate::{ - custom::{CodegenExtension, CodegenExtsBuilder}, - emit::{emit_value, func::EmitFuncContext, EmitOpArgs}, - types::TypingSession, -}; - -/// Runtime functions that implement operations on lists. -#[derive(Clone, Copy, Debug, PartialEq, Hash)] -#[non_exhaustive] -pub enum CollectionsRtFunc { - New, - Push, - Pop, - Get, - Set, - Insert, - Length, -} - -impl CollectionsRtFunc { - /// The signature of a given [CollectionsRtFunc]. - /// - /// Requires a [CollectionsCodegen] to determine the type of lists. - pub fn signature<'c>( - self, - ts: TypingSession<'c, '_>, - ccg: &(impl CollectionsCodegen + 'c), - ) -> FunctionType<'c> { - let iwc = ts.iw_context(); - match self { - CollectionsRtFunc::New => ccg.list_type(ts).fn_type( - &[ - iwc.i64_type().into(), // Capacity - iwc.i64_type().into(), // Single element size in bytes - iwc.i64_type().into(), // Element alignment - // Pointer to element destructor - iwc.i8_type().ptr_type(AddressSpace::default()).into(), - ], - false, - ), - CollectionsRtFunc::Push => iwc.void_type().fn_type( - &[ - ccg.list_type(ts).into(), - iwc.i8_type().ptr_type(AddressSpace::default()).into(), - ], - false, - ), - CollectionsRtFunc::Pop => iwc.bool_type().fn_type( - &[ - ccg.list_type(ts).into(), - iwc.i8_type().ptr_type(AddressSpace::default()).into(), - ], - false, - ), - CollectionsRtFunc::Get | CollectionsRtFunc::Set | CollectionsRtFunc::Insert => { - iwc.bool_type().fn_type( - &[ - ccg.list_type(ts).into(), - iwc.i64_type().into(), - iwc.i8_type().ptr_type(AddressSpace::default()).into(), - ], - false, - ) - } - CollectionsRtFunc::Length => iwc.i64_type().fn_type(&[ccg.list_type(ts).into()], false), - } - } - - /// Returns the extern function corresponding to this [CollectionsRtFunc]. - /// - /// Requires a [CollectionsCodegen] to determine the function signature. - pub fn get_extern<'c, H: HugrView>( - self, - ctx: &EmitFuncContext<'c, '_, H>, - ccg: &(impl CollectionsCodegen + 'c), - ) -> Result> { - ctx.get_extern_func( - ccg.rt_func_name(self), - self.signature(ctx.typing_session(), ccg), - ) - } -} - -impl From for CollectionsRtFunc { - fn from(op: ListOp) -> Self { - match op { - ListOp::get => CollectionsRtFunc::Get, - ListOp::set => CollectionsRtFunc::Set, - ListOp::push => CollectionsRtFunc::Push, - ListOp::pop => CollectionsRtFunc::Pop, - ListOp::insert => CollectionsRtFunc::Insert, - ListOp::length => CollectionsRtFunc::Length, - _ => todo!(), - } - } -} - -/// A helper trait for customising the lowering of [hugr_core::std_extensions::collections] -/// types, [hugr_core::ops::constant::CustomConst]s, and ops. -pub trait CollectionsCodegen: Clone { - /// Return the llvm type of [hugr_core::std_extensions::collections::LIST_TYPENAME]. - fn list_type<'c>(&self, session: TypingSession<'c, '_>) -> BasicTypeEnum<'c> { - session - .iw_context() - .i8_type() - .ptr_type(AddressSpace::default()) - .into() - } - - /// Return the name of a given [CollectionsRtFunc]. - fn rt_func_name(&self, func: CollectionsRtFunc) -> String { - match func { - CollectionsRtFunc::New => "__rt__list__new", - CollectionsRtFunc::Push => "__rt__list__push", - CollectionsRtFunc::Pop => "__rt__list__pop", - CollectionsRtFunc::Get => "__rt__list__get", - CollectionsRtFunc::Set => "__rt__list__set", - CollectionsRtFunc::Insert => "__rt__list__insert", - CollectionsRtFunc::Length => "__rt__list__length", - } - .into() - } -} - -/// A trivial implementation of [CollectionsCodegen] which passes all methods -/// through to their default implementations. -#[derive(Default, Clone)] -pub struct DefaultCollectionsCodegen; - -impl CollectionsCodegen for DefaultCollectionsCodegen {} - -#[derive(Clone, Debug, Default)] -pub struct CollectionsCodegenExtension(CCG); - -impl CollectionsCodegenExtension { - pub fn new(ccg: CCG) -> Self { - Self(ccg) - } -} - -impl From for CollectionsCodegenExtension { - fn from(ccg: CCG) -> Self { - Self::new(ccg) - } -} - -impl CodegenExtension for CollectionsCodegenExtension { - fn add_extension<'a, H: HugrView + 'a>( - self, - builder: CodegenExtsBuilder<'a, H>, - ) -> CodegenExtsBuilder<'a, H> - where - Self: 'a, - { - builder - .custom_type((collections::EXTENSION_ID, collections::LIST_TYPENAME), { - let ccg = self.0.clone(); - move |ts, _hugr_type| Ok(ccg.list_type(ts).as_basic_type_enum()) - }) - .custom_const::({ - let ccg = self.0.clone(); - move |ctx, k| emit_list_value(ctx, &ccg, k) - }) - .simple_extension_op::(move |ctx, args, op| { - emit_list_op(ctx, &self.0, args, op) - }) - } -} - -impl<'a, H: HugrView + 'a> CodegenExtsBuilder<'a, H> { - /// Add a [CollectionsCodegenExtension] to the given [CodegenExtsBuilder] using `ccg` - /// as the implementation. - pub fn add_default_collections_extensions(self) -> Self { - self.add_collections_extensions(DefaultCollectionsCodegen) - } - - /// Add a [CollectionsCodegenExtension] to the given [CodegenExtsBuilder] using - /// [DefaultCollectionsCodegen] as the implementation. - pub fn add_collections_extensions(self, ccg: impl CollectionsCodegen + 'a) -> Self { - self.add_extension(CollectionsCodegenExtension::from(ccg)) - } -} - -fn emit_list_op<'c, H: HugrView>( - ctx: &mut EmitFuncContext<'c, '_, H>, - ccg: &(impl CollectionsCodegen + 'c), - args: EmitOpArgs<'c, '_, ExtensionOp, H>, - op: ListOp, -) -> Result<()> { - let hugr_elem_ty = match args.node().args() { - [TypeArg::Type { ty }] => ty.clone(), - _ => { - bail!("Collections: invalid type args for list op"); - } - }; - let elem_ty = ctx.llvm_type(&hugr_elem_ty)?; - let func = CollectionsRtFunc::get_extern(op.into(), ctx, ccg)?; - match op { - ListOp::push => { - let [list, elem] = args.inputs.try_into().unwrap(); - let elem_ptr = build_alloca_i8_ptr(ctx, elem_ty, Some(elem))?; - ctx.builder() - .build_call(func, &[list.into(), elem_ptr.into()], "")?; - args.outputs.finish(ctx.builder(), vec![list])?; - } - ListOp::pop => { - let [list] = args.inputs.try_into().unwrap(); - let out_ptr = build_alloca_i8_ptr(ctx, elem_ty, None)?; - let ok = ctx - .builder() - .build_call(func, &[list.into(), out_ptr.into()], "")? - .try_as_basic_value() - .unwrap_left() - .into_int_value(); - let elem = build_load_i8_ptr(ctx, out_ptr, elem_ty)?; - let elem_opt = build_option(ctx, ok, elem, hugr_elem_ty)?; - args.outputs.finish(ctx.builder(), vec![list, elem_opt])?; - } - ListOp::get => { - let [list, idx] = args.inputs.try_into().unwrap(); - let out_ptr = build_alloca_i8_ptr(ctx, elem_ty, None)?; - let ok = ctx - .builder() - .build_call(func, &[list.into(), idx.into(), out_ptr.into()], "")? - .try_as_basic_value() - .unwrap_left() - .into_int_value(); - let elem = build_load_i8_ptr(ctx, out_ptr, elem_ty)?; - let elem_opt = build_option(ctx, ok, elem, hugr_elem_ty)?; - args.outputs.finish(ctx.builder(), vec![elem_opt])?; - } - ListOp::set => { - let [list, idx, elem] = args.inputs.try_into().unwrap(); - let elem_ptr = build_alloca_i8_ptr(ctx, elem_ty, Some(elem))?; - let ok = ctx - .builder() - .build_call(func, &[list.into(), idx.into(), elem_ptr.into()], "")? - .try_as_basic_value() - .unwrap_left() - .into_int_value(); - let old_elem = build_load_i8_ptr(ctx, elem_ptr, elem.get_type())?; - let ok_or = - build_ok_or_else(ctx, ok, elem, hugr_elem_ty.clone(), old_elem, hugr_elem_ty)?; - args.outputs.finish(ctx.builder(), vec![list, ok_or])?; - } - ListOp::insert => { - let [list, idx, elem] = args.inputs.try_into().unwrap(); - let elem_ptr = build_alloca_i8_ptr(ctx, elem_ty, Some(elem))?; - let ok = ctx - .builder() - .build_call(func, &[list.into(), idx.into(), elem_ptr.into()], "")? - .try_as_basic_value() - .unwrap_left() - .into_int_value(); - let unit = - ctx.llvm_sum_type(SumType::new_unary(1))? - .build_tag(ctx.builder(), 0, vec![])?; - let ok_or = build_ok_or_else(ctx, ok, unit, Type::UNIT, elem, hugr_elem_ty)?; - args.outputs.finish(ctx.builder(), vec![list, ok_or])?; - } - ListOp::length => { - let [list] = args.inputs.try_into().unwrap(); - let length = ctx - .builder() - .build_call(func, &[list.into()], "")? - .try_as_basic_value() - .unwrap_left() - .into_int_value(); - args.outputs - .finish(ctx.builder(), vec![list, length.into()])?; - } - _ => bail!("Collections: unimplemented op: {}", op.name()), - } - Ok(()) -} - -fn emit_list_value<'c, H: HugrView>( - ctx: &mut EmitFuncContext<'c, '_, H>, - ccg: &(impl CollectionsCodegen + 'c), - val: &ListValue, -) -> Result> { - let elem_ty = ctx.llvm_type(val.get_element_type())?; - let iwc = ctx.typing_session().iw_context(); - let capacity = iwc - .i64_type() - .const_int(val.get_contents().len() as u64, false); - let elem_size = elem_ty.size_of().unwrap(); - let alignment = iwc.i64_type().const_int(8, false); - // TODO: Lookup destructor for elem_ty - let destructor = iwc.i8_type().ptr_type(AddressSpace::default()).const_null(); - let list = ctx - .builder() - .build_call( - CollectionsRtFunc::New.get_extern(ctx, ccg)?, - &[ - capacity.into(), - elem_size.into(), - alignment.into(), - destructor.into(), - ], - "", - )? - .try_as_basic_value() - .unwrap_left(); - // Push elements onto the list - let rt_push = CollectionsRtFunc::Push.get_extern(ctx, ccg)?; - for v in val.get_contents() { - let elem = emit_value(ctx, v)?; - let elem_ptr = build_alloca_i8_ptr(ctx, elem_ty, Some(elem))?; - ctx.builder() - .build_call(rt_push, &[list.into(), elem_ptr.into()], "")?; - } - Ok(list) -} - -/// Helper function to allocate space on the stack for a given type. -/// -/// Optionally also stores a value at that location. -/// -/// Returns an i8 pointer to the allocated memory. -fn build_alloca_i8_ptr<'c, H: HugrView>( - ctx: &mut EmitFuncContext<'c, '_, H>, - ty: BasicTypeEnum<'c>, - value: Option>, -) -> Result> { - let builder = ctx.builder(); - let ptr = builder.build_alloca(ty, "")?; - if let Some(val) = value { - builder.build_store(ptr, val)?; - } - let i8_ptr = builder.build_pointer_cast( - ptr, - ctx.iw_context().i8_type().ptr_type(AddressSpace::default()), - "", - )?; - Ok(i8_ptr) -} - -/// Helper function to load a value from an i8 pointer. -fn build_load_i8_ptr<'c, H: HugrView>( - ctx: &mut EmitFuncContext<'c, '_, H>, - i8_ptr: PointerValue<'c>, - ty: BasicTypeEnum<'c>, -) -> Result> { - let builder = ctx.builder(); - let ptr = builder.build_pointer_cast(i8_ptr, ty.ptr_type(AddressSpace::default()), "")?; - let val = builder.build_load(ptr, "")?; - Ok(val) -} - -#[cfg(test)] -mod test { - use hugr_core::{ - builder::{Dataflow, DataflowSubContainer}, - extension::{ - prelude::{self, qb_t, usize_t, ConstUsize}, - ExtensionRegistry, - }, - ops::{DataflowOpTrait, NamedOp, Value}, - std_extensions::collections::{self, list_type, ListOp, ListValue}, - }; - use rstest::rstest; - - use crate::{ - check_emission, - custom::CodegenExtsBuilder, - emit::test::SimpleHugrConfig, - test::{llvm_ctx, TestContext}, - }; - - #[rstest] - #[case::push(ListOp::push)] - #[case::pop(ListOp::pop)] - #[case::get(ListOp::get)] - #[case::set(ListOp::set)] - #[case::insert(ListOp::insert)] - #[case::length(ListOp::length)] - fn test_collections_emission(mut llvm_ctx: TestContext, #[case] op: ListOp) { - let ext_op = collections::EXTENSION - .instantiate_extension_op( - op.name().as_ref(), - [qb_t().into()], - &collections::COLLECTIONS_REGISTRY, - ) - .unwrap(); - let es = ExtensionRegistry::new([ - collections::EXTENSION.to_owned(), - prelude::PRELUDE.to_owned(), - ]); - es.validate().unwrap(); - let hugr = SimpleHugrConfig::new() - .with_ins(ext_op.signature().input().clone()) - .with_outs(ext_op.signature().output().clone()) - .with_extensions(es) - .finish(|mut hugr_builder| { - let outputs = hugr_builder - .add_dataflow_op(ext_op, hugr_builder.input_wires()) - .unwrap() - .outputs(); - hugr_builder.finish_with_outputs(outputs).unwrap() - }); - llvm_ctx.add_extensions(CodegenExtsBuilder::add_default_prelude_extensions); - llvm_ctx.add_extensions(CodegenExtsBuilder::add_default_collections_extensions); - check_emission!(op.name().as_str(), hugr, llvm_ctx); - } - - #[rstest] - fn test_const_list_emmission(mut llvm_ctx: TestContext) { - let elem_ty = usize_t(); - let contents = (1..4).map(|i| Value::extension(ConstUsize::new(i))); - let es = ExtensionRegistry::new([ - collections::EXTENSION.to_owned(), - prelude::PRELUDE.to_owned(), - ]); - es.validate().unwrap(); - - let hugr = SimpleHugrConfig::new() - .with_ins(vec![]) - .with_outs(vec![list_type(elem_ty.clone())]) - .with_extensions(es) - .finish(|mut hugr_builder| { - let list = hugr_builder.add_load_value(ListValue::new(elem_ty, contents)); - hugr_builder.finish_with_outputs(vec![list]).unwrap() - }); - - llvm_ctx.add_extensions(CodegenExtsBuilder::add_default_prelude_extensions); - llvm_ctx.add_extensions(CodegenExtsBuilder::add_default_collections_extensions); - check_emission!("const", hugr, llvm_ctx); - } -} +pub mod list; diff --git a/hugr-llvm/src/extension/collections/list.rs b/hugr-llvm/src/extension/collections/list.rs new file mode 100644 index 000000000..fb200ca97 --- /dev/null +++ b/hugr-llvm/src/extension/collections/list.rs @@ -0,0 +1,433 @@ +use anyhow::{bail, Ok, Result}; +use hugr_core::{ + ops::{ExtensionOp, NamedOp}, + std_extensions::collections::list::{self, ListOp, ListValue}, + types::{SumType, Type, TypeArg}, + HugrView, +}; +use inkwell::values::FunctionValue; +use inkwell::{ + types::{BasicType, BasicTypeEnum, FunctionType}, + values::{BasicValueEnum, PointerValue}, + AddressSpace, +}; + +use crate::emit::func::{build_ok_or_else, build_option}; +use crate::{ + custom::{CodegenExtension, CodegenExtsBuilder}, + emit::{emit_value, func::EmitFuncContext, EmitOpArgs}, + types::TypingSession, +}; + +/// Runtime functions that implement operations on lists. +#[derive(Clone, Copy, Debug, PartialEq, Hash)] +#[non_exhaustive] +pub enum ListRtFunc { + New, + Push, + Pop, + Get, + Set, + Insert, + Length, +} + +impl ListRtFunc { + /// The signature of a given [ListRtFunc]. + /// + /// Requires a [ListCodegen] to determine the type of lists. + pub fn signature<'c>( + self, + ts: TypingSession<'c, '_>, + ccg: &(impl ListCodegen + 'c), + ) -> FunctionType<'c> { + let iwc = ts.iw_context(); + match self { + ListRtFunc::New => ccg.list_type(ts).fn_type( + &[ + iwc.i64_type().into(), // Capacity + iwc.i64_type().into(), // Single element size in bytes + iwc.i64_type().into(), // Element alignment + // Pointer to element destructor + iwc.i8_type().ptr_type(AddressSpace::default()).into(), + ], + false, + ), + ListRtFunc::Push => iwc.void_type().fn_type( + &[ + ccg.list_type(ts).into(), + iwc.i8_type().ptr_type(AddressSpace::default()).into(), + ], + false, + ), + ListRtFunc::Pop => iwc.bool_type().fn_type( + &[ + ccg.list_type(ts).into(), + iwc.i8_type().ptr_type(AddressSpace::default()).into(), + ], + false, + ), + ListRtFunc::Get | ListRtFunc::Set | ListRtFunc::Insert => iwc.bool_type().fn_type( + &[ + ccg.list_type(ts).into(), + iwc.i64_type().into(), + iwc.i8_type().ptr_type(AddressSpace::default()).into(), + ], + false, + ), + ListRtFunc::Length => iwc.i64_type().fn_type(&[ccg.list_type(ts).into()], false), + } + } + + /// Returns the extern function corresponding to this [ListRtFunc]. + /// + /// Requires a [ListCodegen] to determine the function signature. + pub fn get_extern<'c, H: HugrView>( + self, + ctx: &EmitFuncContext<'c, '_, H>, + ccg: &(impl ListCodegen + 'c), + ) -> Result> { + ctx.get_extern_func( + ccg.rt_func_name(self), + self.signature(ctx.typing_session(), ccg), + ) + } +} + +impl From for ListRtFunc { + fn from(op: ListOp) -> Self { + match op { + ListOp::get => ListRtFunc::Get, + ListOp::set => ListRtFunc::Set, + ListOp::push => ListRtFunc::Push, + ListOp::pop => ListRtFunc::Pop, + ListOp::insert => ListRtFunc::Insert, + ListOp::length => ListRtFunc::Length, + _ => todo!(), + } + } +} + +/// A helper trait for customising the lowering of [hugr_core::std_extensions::collections::list] +/// types, [hugr_core::ops::constant::CustomConst]s, and ops. +pub trait ListCodegen: Clone { + /// Return the llvm type of [hugr_core::std_extensions::collections::list::LIST_TYPENAME]. + fn list_type<'c>(&self, session: TypingSession<'c, '_>) -> BasicTypeEnum<'c> { + session + .iw_context() + .i8_type() + .ptr_type(AddressSpace::default()) + .into() + } + + /// Return the name of a given [ListRtFunc]. + fn rt_func_name(&self, func: ListRtFunc) -> String { + match func { + ListRtFunc::New => "__rt__list__new", + ListRtFunc::Push => "__rt__list__push", + ListRtFunc::Pop => "__rt__list__pop", + ListRtFunc::Get => "__rt__list__get", + ListRtFunc::Set => "__rt__list__set", + ListRtFunc::Insert => "__rt__list__insert", + ListRtFunc::Length => "__rt__list__length", + } + .into() + } +} + +/// A trivial implementation of [ListCodegen] which passes all methods +/// through to their default implementations. +#[derive(Default, Clone)] +pub struct DefaultListCodegen; + +impl ListCodegen for DefaultListCodegen {} + +#[derive(Clone, Debug, Default)] +pub struct ListCodegenExtension(CCG); + +impl ListCodegenExtension { + pub fn new(ccg: CCG) -> Self { + Self(ccg) + } +} + +impl From for ListCodegenExtension { + fn from(ccg: CCG) -> Self { + Self::new(ccg) + } +} + +impl CodegenExtension for ListCodegenExtension { + fn add_extension<'a, H: HugrView + 'a>( + self, + builder: CodegenExtsBuilder<'a, H>, + ) -> CodegenExtsBuilder<'a, H> + where + Self: 'a, + { + builder + .custom_type((list::EXTENSION_ID, list::LIST_TYPENAME), { + let ccg = self.0.clone(); + move |ts, _hugr_type| Ok(ccg.list_type(ts).as_basic_type_enum()) + }) + .custom_const::({ + let ccg = self.0.clone(); + move |ctx, k| emit_list_value(ctx, &ccg, k) + }) + .simple_extension_op::(move |ctx, args, op| { + emit_list_op(ctx, &self.0, args, op) + }) + } +} + +impl<'a, H: HugrView + 'a> CodegenExtsBuilder<'a, H> { + /// Add a [ListCodegenExtension] to the given [CodegenExtsBuilder] using `ccg` + /// as the implementation. + pub fn add_default_list_extensions(self) -> Self { + self.add_list_extensions(DefaultListCodegen) + } + + /// Add a [ListCodegenExtension] to the given [CodegenExtsBuilder] using + /// [DefaultListCodegen] as the implementation. + pub fn add_list_extensions(self, ccg: impl ListCodegen + 'a) -> Self { + self.add_extension(ListCodegenExtension::from(ccg)) + } +} + +fn emit_list_op<'c, H: HugrView>( + ctx: &mut EmitFuncContext<'c, '_, H>, + ccg: &(impl ListCodegen + 'c), + args: EmitOpArgs<'c, '_, ExtensionOp, H>, + op: ListOp, +) -> Result<()> { + let hugr_elem_ty = match args.node().args() { + [TypeArg::Type { ty }] => ty.clone(), + _ => { + bail!("Collections: invalid type args for list op"); + } + }; + let elem_ty = ctx.llvm_type(&hugr_elem_ty)?; + let func = ListRtFunc::get_extern(op.into(), ctx, ccg)?; + match op { + ListOp::push => { + let [list, elem] = args.inputs.try_into().unwrap(); + let elem_ptr = build_alloca_i8_ptr(ctx, elem_ty, Some(elem))?; + ctx.builder() + .build_call(func, &[list.into(), elem_ptr.into()], "")?; + args.outputs.finish(ctx.builder(), vec![list])?; + } + ListOp::pop => { + let [list] = args.inputs.try_into().unwrap(); + let out_ptr = build_alloca_i8_ptr(ctx, elem_ty, None)?; + let ok = ctx + .builder() + .build_call(func, &[list.into(), out_ptr.into()], "")? + .try_as_basic_value() + .unwrap_left() + .into_int_value(); + let elem = build_load_i8_ptr(ctx, out_ptr, elem_ty)?; + let elem_opt = build_option(ctx, ok, elem, hugr_elem_ty)?; + args.outputs.finish(ctx.builder(), vec![list, elem_opt])?; + } + ListOp::get => { + let [list, idx] = args.inputs.try_into().unwrap(); + let out_ptr = build_alloca_i8_ptr(ctx, elem_ty, None)?; + let ok = ctx + .builder() + .build_call(func, &[list.into(), idx.into(), out_ptr.into()], "")? + .try_as_basic_value() + .unwrap_left() + .into_int_value(); + let elem = build_load_i8_ptr(ctx, out_ptr, elem_ty)?; + let elem_opt = build_option(ctx, ok, elem, hugr_elem_ty)?; + args.outputs.finish(ctx.builder(), vec![elem_opt])?; + } + ListOp::set => { + let [list, idx, elem] = args.inputs.try_into().unwrap(); + let elem_ptr = build_alloca_i8_ptr(ctx, elem_ty, Some(elem))?; + let ok = ctx + .builder() + .build_call(func, &[list.into(), idx.into(), elem_ptr.into()], "")? + .try_as_basic_value() + .unwrap_left() + .into_int_value(); + let old_elem = build_load_i8_ptr(ctx, elem_ptr, elem.get_type())?; + let ok_or = + build_ok_or_else(ctx, ok, elem, hugr_elem_ty.clone(), old_elem, hugr_elem_ty)?; + args.outputs.finish(ctx.builder(), vec![list, ok_or])?; + } + ListOp::insert => { + let [list, idx, elem] = args.inputs.try_into().unwrap(); + let elem_ptr = build_alloca_i8_ptr(ctx, elem_ty, Some(elem))?; + let ok = ctx + .builder() + .build_call(func, &[list.into(), idx.into(), elem_ptr.into()], "")? + .try_as_basic_value() + .unwrap_left() + .into_int_value(); + let unit = + ctx.llvm_sum_type(SumType::new_unary(1))? + .build_tag(ctx.builder(), 0, vec![])?; + let ok_or = build_ok_or_else(ctx, ok, unit, Type::UNIT, elem, hugr_elem_ty)?; + args.outputs.finish(ctx.builder(), vec![list, ok_or])?; + } + ListOp::length => { + let [list] = args.inputs.try_into().unwrap(); + let length = ctx + .builder() + .build_call(func, &[list.into()], "")? + .try_as_basic_value() + .unwrap_left() + .into_int_value(); + args.outputs + .finish(ctx.builder(), vec![list, length.into()])?; + } + _ => bail!("Collections: unimplemented op: {}", op.name()), + } + Ok(()) +} + +fn emit_list_value<'c, H: HugrView>( + ctx: &mut EmitFuncContext<'c, '_, H>, + ccg: &(impl ListCodegen + 'c), + val: &ListValue, +) -> Result> { + let elem_ty = ctx.llvm_type(val.get_element_type())?; + let iwc = ctx.typing_session().iw_context(); + let capacity = iwc + .i64_type() + .const_int(val.get_contents().len() as u64, false); + let elem_size = elem_ty.size_of().unwrap(); + let alignment = iwc.i64_type().const_int(8, false); + // TODO: Lookup destructor for elem_ty + let destructor = iwc.i8_type().ptr_type(AddressSpace::default()).const_null(); + let list = ctx + .builder() + .build_call( + ListRtFunc::New.get_extern(ctx, ccg)?, + &[ + capacity.into(), + elem_size.into(), + alignment.into(), + destructor.into(), + ], + "", + )? + .try_as_basic_value() + .unwrap_left(); + // Push elements onto the list + let rt_push = ListRtFunc::Push.get_extern(ctx, ccg)?; + for v in val.get_contents() { + let elem = emit_value(ctx, v)?; + let elem_ptr = build_alloca_i8_ptr(ctx, elem_ty, Some(elem))?; + ctx.builder() + .build_call(rt_push, &[list.into(), elem_ptr.into()], "")?; + } + Ok(list) +} + +/// Helper function to allocate space on the stack for a given type. +/// +/// Optionally also stores a value at that location. +/// +/// Returns an i8 pointer to the allocated memory. +fn build_alloca_i8_ptr<'c, H: HugrView>( + ctx: &mut EmitFuncContext<'c, '_, H>, + ty: BasicTypeEnum<'c>, + value: Option>, +) -> Result> { + let builder = ctx.builder(); + let ptr = builder.build_alloca(ty, "")?; + if let Some(val) = value { + builder.build_store(ptr, val)?; + } + let i8_ptr = builder.build_pointer_cast( + ptr, + ctx.iw_context().i8_type().ptr_type(AddressSpace::default()), + "", + )?; + Ok(i8_ptr) +} + +/// Helper function to load a value from an i8 pointer. +fn build_load_i8_ptr<'c, H: HugrView>( + ctx: &mut EmitFuncContext<'c, '_, H>, + i8_ptr: PointerValue<'c>, + ty: BasicTypeEnum<'c>, +) -> Result> { + let builder = ctx.builder(); + let ptr = builder.build_pointer_cast(i8_ptr, ty.ptr_type(AddressSpace::default()), "")?; + let val = builder.build_load(ptr, "")?; + Ok(val) +} + +#[cfg(test)] +mod test { + use hugr_core::{ + builder::{Dataflow, DataflowSubContainer}, + extension::{ + prelude::{self, qb_t, usize_t, ConstUsize}, + ExtensionRegistry, + }, + ops::{DataflowOpTrait, NamedOp, Value}, + std_extensions::collections::list::{self, list_type, ListOp, ListValue}, + }; + use rstest::rstest; + + use crate::{ + check_emission, + custom::CodegenExtsBuilder, + emit::test::SimpleHugrConfig, + test::{llvm_ctx, TestContext}, + }; + + #[rstest] + #[case::push(ListOp::push)] + #[case::pop(ListOp::pop)] + #[case::get(ListOp::get)] + #[case::set(ListOp::set)] + #[case::insert(ListOp::insert)] + #[case::length(ListOp::length)] + fn test_list_emission(mut llvm_ctx: TestContext, #[case] op: ListOp) { + let ext_op = list::EXTENSION + .instantiate_extension_op(op.name().as_ref(), [qb_t().into()], &list::LIST_REGISTRY) + .unwrap(); + let es = ExtensionRegistry::new([list::EXTENSION.to_owned(), prelude::PRELUDE.to_owned()]); + es.validate().unwrap(); + let hugr = SimpleHugrConfig::new() + .with_ins(ext_op.signature().input().clone()) + .with_outs(ext_op.signature().output().clone()) + .with_extensions(es) + .finish(|mut hugr_builder| { + let outputs = hugr_builder + .add_dataflow_op(ext_op, hugr_builder.input_wires()) + .unwrap() + .outputs(); + hugr_builder.finish_with_outputs(outputs).unwrap() + }); + llvm_ctx.add_extensions(CodegenExtsBuilder::add_default_prelude_extensions); + llvm_ctx.add_extensions(CodegenExtsBuilder::add_default_list_extensions); + check_emission!(op.name().as_str(), hugr, llvm_ctx); + } + + #[rstest] + fn test_const_list_emmission(mut llvm_ctx: TestContext) { + let elem_ty = usize_t(); + let contents = (1..4).map(|i| Value::extension(ConstUsize::new(i))); + let es = ExtensionRegistry::new([list::EXTENSION.to_owned(), prelude::PRELUDE.to_owned()]); + es.validate().unwrap(); + + let hugr = SimpleHugrConfig::new() + .with_ins(vec![]) + .with_outs(vec![list_type(elem_ty.clone())]) + .with_extensions(es) + .finish(|mut hugr_builder| { + let list = hugr_builder.add_load_value(ListValue::new(elem_ty, contents)); + hugr_builder.finish_with_outputs(vec![list]).unwrap() + }); + + llvm_ctx.add_extensions(CodegenExtsBuilder::add_default_prelude_extensions); + llvm_ctx.add_extensions(CodegenExtsBuilder::add_default_list_extensions); + check_emission!("const", hugr, llvm_ctx); + } +} diff --git a/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__collections__test__const@llvm14.snap b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__list__test__const@llvm14.snap similarity index 94% rename from hugr-llvm/src/extension/snapshots/hugr_llvm__extension__collections__test__const@llvm14.snap rename to hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__list__test__const@llvm14.snap index 8ad058cf3..fb89ee00b 100644 --- a/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__collections__test__const@llvm14.snap +++ b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__list__test__const@llvm14.snap @@ -1,5 +1,5 @@ --- -source: hugr-llvm/src/extension/collections.rs +source: hugr-llvm/src/extension/collections/list.rs expression: mod_str --- ; ModuleID = 'test_context' diff --git a/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__collections__test__const@pre-mem2reg@llvm14.snap b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__list__test__const@pre-mem2reg@llvm14.snap similarity index 95% rename from hugr-llvm/src/extension/snapshots/hugr_llvm__extension__collections__test__const@pre-mem2reg@llvm14.snap rename to hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__list__test__const@pre-mem2reg@llvm14.snap index 5522be9ad..dcba7479f 100644 --- a/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__collections__test__const@pre-mem2reg@llvm14.snap +++ b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__list__test__const@pre-mem2reg@llvm14.snap @@ -1,5 +1,5 @@ --- -source: hugr-llvm/src/extension/collections.rs +source: hugr-llvm/src/extension/collections/list.rs expression: mod_str --- ; ModuleID = 'test_context' diff --git a/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__collections__test__get@llvm14.snap b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__list__test__get@llvm14.snap similarity index 93% rename from hugr-llvm/src/extension/snapshots/hugr_llvm__extension__collections__test__get@llvm14.snap rename to hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__list__test__get@llvm14.snap index 5d7d0d381..6f6b0bdf0 100644 --- a/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__collections__test__get@llvm14.snap +++ b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__list__test__get@llvm14.snap @@ -1,5 +1,5 @@ --- -source: hugr-llvm/src/extension/collections.rs +source: hugr-llvm/src/extension/collections/list.rs expression: mod_str --- ; ModuleID = 'test_context' diff --git a/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__collections__test__get@pre-mem2reg@llvm14.snap b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__list__test__get@pre-mem2reg@llvm14.snap similarity index 96% rename from hugr-llvm/src/extension/snapshots/hugr_llvm__extension__collections__test__get@pre-mem2reg@llvm14.snap rename to hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__list__test__get@pre-mem2reg@llvm14.snap index a7eee4d03..4a7165e96 100644 --- a/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__collections__test__get@pre-mem2reg@llvm14.snap +++ b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__list__test__get@pre-mem2reg@llvm14.snap @@ -1,5 +1,5 @@ --- -source: hugr-llvm/src/extension/collections.rs +source: hugr-llvm/src/extension/collections/list.rs expression: mod_str --- ; ModuleID = 'test_context' diff --git a/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__collections__test__insert@llvm14.snap b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__list__test__insert@llvm14.snap similarity index 95% rename from hugr-llvm/src/extension/snapshots/hugr_llvm__extension__collections__test__insert@llvm14.snap rename to hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__list__test__insert@llvm14.snap index deb84f1b5..1c2dcfb5c 100644 --- a/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__collections__test__insert@llvm14.snap +++ b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__list__test__insert@llvm14.snap @@ -1,5 +1,5 @@ --- -source: hugr-llvm/src/extension/collections.rs +source: hugr-llvm/src/extension/collections/list.rs expression: mod_str --- ; ModuleID = 'test_context' diff --git a/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__collections__test__insert@pre-mem2reg@llvm14.snap b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__list__test__insert@pre-mem2reg@llvm14.snap similarity index 97% rename from hugr-llvm/src/extension/snapshots/hugr_llvm__extension__collections__test__insert@pre-mem2reg@llvm14.snap rename to hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__list__test__insert@pre-mem2reg@llvm14.snap index 9f92cf9a6..1657af5f2 100644 --- a/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__collections__test__insert@pre-mem2reg@llvm14.snap +++ b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__list__test__insert@pre-mem2reg@llvm14.snap @@ -1,5 +1,5 @@ --- -source: hugr-llvm/src/extension/collections.rs +source: hugr-llvm/src/extension/collections/list.rs expression: mod_str --- ; ModuleID = 'test_context' diff --git a/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__collections__test__length@llvm14.snap b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__list__test__length@llvm14.snap similarity index 89% rename from hugr-llvm/src/extension/snapshots/hugr_llvm__extension__collections__test__length@llvm14.snap rename to hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__list__test__length@llvm14.snap index 61ddae3a3..de2fab69e 100644 --- a/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__collections__test__length@llvm14.snap +++ b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__list__test__length@llvm14.snap @@ -1,5 +1,5 @@ --- -source: hugr-llvm/src/extension/collections.rs +source: hugr-llvm/src/extension/collections/list.rs expression: mod_str --- ; ModuleID = 'test_context' diff --git a/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__collections__test__length@pre-mem2reg@llvm14.snap b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__list__test__length@pre-mem2reg@llvm14.snap similarity index 95% rename from hugr-llvm/src/extension/snapshots/hugr_llvm__extension__collections__test__length@pre-mem2reg@llvm14.snap rename to hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__list__test__length@pre-mem2reg@llvm14.snap index e993956bb..b0ced0b59 100644 --- a/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__collections__test__length@pre-mem2reg@llvm14.snap +++ b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__list__test__length@pre-mem2reg@llvm14.snap @@ -1,5 +1,5 @@ --- -source: hugr-llvm/src/extension/collections.rs +source: hugr-llvm/src/extension/collections/list.rs expression: mod_str --- ; ModuleID = 'test_context' diff --git a/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__collections__test__pop@llvm14.snap b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__list__test__pop@llvm14.snap similarity index 94% rename from hugr-llvm/src/extension/snapshots/hugr_llvm__extension__collections__test__pop@llvm14.snap rename to hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__list__test__pop@llvm14.snap index e011b3dfe..27e680a03 100644 --- a/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__collections__test__pop@llvm14.snap +++ b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__list__test__pop@llvm14.snap @@ -1,5 +1,5 @@ --- -source: hugr-llvm/src/extension/collections.rs +source: hugr-llvm/src/extension/collections/list.rs expression: mod_str --- ; ModuleID = 'test_context' diff --git a/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__collections__test__pop@pre-mem2reg@llvm14.snap b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__list__test__pop@pre-mem2reg@llvm14.snap similarity index 96% rename from hugr-llvm/src/extension/snapshots/hugr_llvm__extension__collections__test__pop@pre-mem2reg@llvm14.snap rename to hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__list__test__pop@pre-mem2reg@llvm14.snap index 4b677b1a8..fbc37e239 100644 --- a/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__collections__test__pop@pre-mem2reg@llvm14.snap +++ b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__list__test__pop@pre-mem2reg@llvm14.snap @@ -1,5 +1,5 @@ --- -source: hugr-llvm/src/extension/collections.rs +source: hugr-llvm/src/extension/collections/list.rs expression: mod_str --- ; ModuleID = 'test_context' diff --git a/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__collections__test__push@llvm14.snap b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__list__test__push@llvm14.snap similarity index 89% rename from hugr-llvm/src/extension/snapshots/hugr_llvm__extension__collections__test__push@llvm14.snap rename to hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__list__test__push@llvm14.snap index 6e9be48bc..6da39bbd7 100644 --- a/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__collections__test__push@llvm14.snap +++ b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__list__test__push@llvm14.snap @@ -1,5 +1,5 @@ --- -source: hugr-llvm/src/extension/collections.rs +source: hugr-llvm/src/extension/collections/list.rs expression: mod_str --- ; ModuleID = 'test_context' diff --git a/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__collections__test__push@pre-mem2reg@llvm14.snap b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__list__test__push@pre-mem2reg@llvm14.snap similarity index 94% rename from hugr-llvm/src/extension/snapshots/hugr_llvm__extension__collections__test__push@pre-mem2reg@llvm14.snap rename to hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__list__test__push@pre-mem2reg@llvm14.snap index 5e88ddf5a..6d410c012 100644 --- a/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__collections__test__push@pre-mem2reg@llvm14.snap +++ b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__list__test__push@pre-mem2reg@llvm14.snap @@ -1,5 +1,5 @@ --- -source: hugr-llvm/src/extension/collections.rs +source: hugr-llvm/src/extension/collections/list.rs expression: mod_str --- ; ModuleID = 'test_context' diff --git a/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__collections__test__set@llvm14.snap b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__list__test__set@llvm14.snap similarity index 95% rename from hugr-llvm/src/extension/snapshots/hugr_llvm__extension__collections__test__set@llvm14.snap rename to hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__list__test__set@llvm14.snap index f2b0ac21a..6c94a7ac7 100644 --- a/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__collections__test__set@llvm14.snap +++ b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__list__test__set@llvm14.snap @@ -1,5 +1,5 @@ --- -source: hugr-llvm/src/extension/collections.rs +source: hugr-llvm/src/extension/collections/list.rs expression: mod_str --- ; ModuleID = 'test_context' diff --git a/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__collections__test__set@pre-mem2reg@llvm14.snap b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__list__test__set@pre-mem2reg@llvm14.snap similarity index 97% rename from hugr-llvm/src/extension/snapshots/hugr_llvm__extension__collections__test__set@pre-mem2reg@llvm14.snap rename to hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__list__test__set@pre-mem2reg@llvm14.snap index ba89dc6cc..f191ed87e 100644 --- a/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__collections__test__set@pre-mem2reg@llvm14.snap +++ b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__list__test__set@pre-mem2reg@llvm14.snap @@ -1,5 +1,5 @@ --- -source: hugr-llvm/src/extension/collections.rs +source: hugr-llvm/src/extension/collections/list.rs expression: mod_str --- ; ModuleID = 'test_context' diff --git a/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__conversions__test__convert_s@llvm14.snap b/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__conversions__test__convert_s@llvm14.snap index 70cda0a31..4b76ebf8c 100644 --- a/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__conversions__test__convert_s@llvm14.snap +++ b/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__conversions__test__convert_s@llvm14.snap @@ -1,5 +1,5 @@ --- -source: src/extension/conversions.rs +source: hugr-llvm/src/extension/conversions.rs expression: mod_str --- ; ModuleID = 'test_context' diff --git a/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__conversions__test__convert_s@pre-mem2reg@llvm14.snap b/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__conversions__test__convert_s@pre-mem2reg@llvm14.snap index 6d1edf4a4..559d4a6b7 100644 --- a/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__conversions__test__convert_s@pre-mem2reg@llvm14.snap +++ b/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__conversions__test__convert_s@pre-mem2reg@llvm14.snap @@ -1,5 +1,5 @@ --- -source: src/extension/conversions.rs +source: hugr-llvm/src/extension/conversions.rs expression: mod_str --- ; ModuleID = 'test_context' diff --git a/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__conversions__test__convert_u@llvm14.snap b/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__conversions__test__convert_u@llvm14.snap index e98a4cc40..b1f76858d 100644 --- a/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__conversions__test__convert_u@llvm14.snap +++ b/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__conversions__test__convert_u@llvm14.snap @@ -1,5 +1,5 @@ --- -source: src/extension/conversions.rs +source: hugr-llvm/src/extension/conversions.rs expression: mod_str --- ; ModuleID = 'test_context' diff --git a/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__conversions__test__convert_u@pre-mem2reg@llvm14.snap b/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__conversions__test__convert_u@pre-mem2reg@llvm14.snap index 1825d3eaf..02dac5239 100644 --- a/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__conversions__test__convert_u@pre-mem2reg@llvm14.snap +++ b/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__conversions__test__convert_u@pre-mem2reg@llvm14.snap @@ -1,5 +1,5 @@ --- -source: src/extension/conversions.rs +source: hugr-llvm/src/extension/conversions.rs expression: mod_str --- ; ModuleID = 'test_context' diff --git a/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__conversions__test__ifrombool@llvm14.snap b/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__conversions__test__ifrombool@llvm14.snap index 2bb32d6b8..72fe903ff 100644 --- a/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__conversions__test__ifrombool@llvm14.snap +++ b/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__conversions__test__ifrombool@llvm14.snap @@ -1,5 +1,5 @@ --- -source: src/extension/conversions.rs +source: hugr-llvm/src/extension/conversions.rs expression: mod_str --- ; ModuleID = 'test_context' diff --git a/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__conversions__test__ifrombool@pre-mem2reg@llvm14.snap b/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__conversions__test__ifrombool@pre-mem2reg@llvm14.snap index 328eb15c8..ab30823d3 100644 --- a/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__conversions__test__ifrombool@pre-mem2reg@llvm14.snap +++ b/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__conversions__test__ifrombool@pre-mem2reg@llvm14.snap @@ -1,5 +1,5 @@ --- -source: src/extension/conversions.rs +source: hugr-llvm/src/extension/conversions.rs expression: mod_str --- ; ModuleID = 'test_context' diff --git a/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__conversions__test__itobool@llvm14.snap b/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__conversions__test__itobool@llvm14.snap index 70ffb0c10..6ccc2bae1 100644 --- a/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__conversions__test__itobool@llvm14.snap +++ b/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__conversions__test__itobool@llvm14.snap @@ -1,5 +1,5 @@ --- -source: src/extension/conversions.rs +source: hugr-llvm/src/extension/conversions.rs expression: mod_str --- ; ModuleID = 'test_context' diff --git a/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__conversions__test__itobool@pre-mem2reg@llvm14.snap b/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__conversions__test__itobool@pre-mem2reg@llvm14.snap index 52ea57ebd..d2b288d5a 100644 --- a/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__conversions__test__itobool@pre-mem2reg@llvm14.snap +++ b/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__conversions__test__itobool@pre-mem2reg@llvm14.snap @@ -1,5 +1,5 @@ --- -source: src/extension/conversions.rs +source: hugr-llvm/src/extension/conversions.rs expression: mod_str --- ; ModuleID = 'test_context' diff --git a/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__conversions__test__trunc_s@llvm14.snap b/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__conversions__test__trunc_s@llvm14.snap index 303eb6972..120329411 100644 --- a/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__conversions__test__trunc_s@llvm14.snap +++ b/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__conversions__test__trunc_s@llvm14.snap @@ -1,5 +1,5 @@ --- -source: src/extension/conversions.rs +source: hugr-llvm/src/extension/conversions.rs expression: mod_str --- ; ModuleID = 'test_context' diff --git a/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__conversions__test__trunc_s@pre-mem2reg@llvm14.snap b/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__conversions__test__trunc_s@pre-mem2reg@llvm14.snap index 6d2d61df0..1bc847c2d 100644 --- a/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__conversions__test__trunc_s@pre-mem2reg@llvm14.snap +++ b/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__conversions__test__trunc_s@pre-mem2reg@llvm14.snap @@ -1,5 +1,5 @@ --- -source: src/extension/conversions.rs +source: hugr-llvm/src/extension/conversions.rs expression: mod_str --- ; ModuleID = 'test_context' diff --git a/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__conversions__test__trunc_u@llvm14.snap b/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__conversions__test__trunc_u@llvm14.snap index f258e61c2..51dc2a4a1 100644 --- a/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__conversions__test__trunc_u@llvm14.snap +++ b/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__conversions__test__trunc_u@llvm14.snap @@ -1,5 +1,5 @@ --- -source: src/extension/conversions.rs +source: hugr-llvm/src/extension/conversions.rs expression: mod_str --- ; ModuleID = 'test_context' diff --git a/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__conversions__test__trunc_u@pre-mem2reg@llvm14.snap b/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__conversions__test__trunc_u@pre-mem2reg@llvm14.snap index 1f5d9aebc..ecd0899b6 100644 --- a/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__conversions__test__trunc_u@pre-mem2reg@llvm14.snap +++ b/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__conversions__test__trunc_u@pre-mem2reg@llvm14.snap @@ -1,5 +1,5 @@ --- -source: src/extension/conversions.rs +source: hugr-llvm/src/extension/conversions.rs expression: mod_str --- ; ModuleID = 'test_context' diff --git a/hugr-passes/src/const_fold/test.rs b/hugr-passes/src/const_fold/test.rs index 62e9cdb9e..51bd8442f 100644 --- a/hugr-passes/src/const_fold/test.rs +++ b/hugr-passes/src/const_fold/test.rs @@ -171,7 +171,7 @@ fn test_big() { #[ignore = "Waiting for `unwrap` operation"] // TODO: https://github.com/CQCL/hugr/issues/1486 fn test_list_ops() -> Result<(), Box> { - use hugr_core::std_extensions::collections::{ListOp, ListValue}; + use hugr_core::std_extensions::collections::list::{ListOp, ListValue}; let base_list: Value = ListValue::new(bool_t(), [Value::false_val()]).into(); let mut build = DFGBuilder::new(Signature::new( diff --git a/hugr-passes/src/const_fold/value_handle.rs b/hugr-passes/src/const_fold/value_handle.rs index 08e7ed0f0..9eafc15e3 100644 --- a/hugr-passes/src/const_fold/value_handle.rs +++ b/hugr-passes/src/const_fold/value_handle.rs @@ -176,7 +176,7 @@ mod test { float_types::{float64_type, ConstF64}, int_types::{ConstInt, INT_TYPES}, }, - collections::ListValue, + collections::list::ListValue, }, }; diff --git a/hugr-passes/src/lib.rs b/hugr-passes/src/lib.rs index 1ab6ff41b..56300e345 100644 --- a/hugr-passes/src/lib.rs +++ b/hugr-passes/src/lib.rs @@ -33,7 +33,7 @@ pub(crate) mod test { arithmetic::float_ops::EXTENSION.to_owned(), logic::EXTENSION.to_owned(), arithmetic::conversions::EXTENSION.to_owned(), - collections::EXTENSION.to_owned(), + collections::list::EXTENSION.to_owned(), ]); } } diff --git a/hugr-py/src/hugr/std/_json_defs/collections.json b/hugr-py/src/hugr/std/_json_defs/collections/list.json similarity index 91% rename from hugr-py/src/hugr/std/_json_defs/collections.json rename to hugr-py/src/hugr/std/_json_defs/collections/list.json index 2b46f5d44..b5f905add 100644 --- a/hugr-py/src/hugr/std/_json_defs/collections.json +++ b/hugr-py/src/hugr/std/_json_defs/collections/list.json @@ -1,10 +1,10 @@ { "version": "0.1.0", - "name": "collections", + "name": "collections.list", "extension_reqs": [], "types": { "List": { - "extension": "collections", + "extension": "collections.list", "name": "List", "params": [ { @@ -24,7 +24,7 @@ "values": {}, "operations": { "get": { - "extension": "collections", + "extension": "collections.list", "name": "get", "description": "Lookup an element in a list by index. Panics if the index is out of bounds.", "signature": { @@ -38,7 +38,7 @@ "input": [ { "t": "Opaque", - "extension": "collections", + "extension": "collections.list", "id": "List", "args": [ { @@ -78,7 +78,7 @@ "binary": false }, "insert": { - "extension": "collections", + "extension": "collections.list", "name": "insert", "description": "Insert an element at index `i`. Elements at higher indices are shifted one position to the right. Panics if the index is out of bounds.", "signature": { @@ -92,7 +92,7 @@ "input": [ { "t": "Opaque", - "extension": "collections", + "extension": "collections.list", "id": "List", "args": [ { @@ -118,7 +118,7 @@ "output": [ { "t": "Opaque", - "extension": "collections", + "extension": "collections.list", "id": "List", "args": [ { @@ -159,7 +159,7 @@ "binary": false }, "length": { - "extension": "collections", + "extension": "collections.list", "name": "length", "description": "Get the length of a list", "signature": { @@ -173,7 +173,7 @@ "input": [ { "t": "Opaque", - "extension": "collections", + "extension": "collections.list", "id": "List", "args": [ { @@ -191,7 +191,7 @@ "output": [ { "t": "Opaque", - "extension": "collections", + "extension": "collections.list", "id": "List", "args": [ { @@ -215,7 +215,7 @@ "binary": false }, "pop": { - "extension": "collections", + "extension": "collections.list", "name": "pop", "description": "Pop from the back of list. Returns an optional value.", "signature": { @@ -229,7 +229,7 @@ "input": [ { "t": "Opaque", - "extension": "collections", + "extension": "collections.list", "id": "List", "args": [ { @@ -247,7 +247,7 @@ "output": [ { "t": "Opaque", - "extension": "collections", + "extension": "collections.list", "id": "List", "args": [ { @@ -282,7 +282,7 @@ "binary": false }, "push": { - "extension": "collections", + "extension": "collections.list", "name": "push", "description": "Push to the back of list", "signature": { @@ -296,7 +296,7 @@ "input": [ { "t": "Opaque", - "extension": "collections", + "extension": "collections.list", "id": "List", "args": [ { @@ -319,7 +319,7 @@ "output": [ { "t": "Opaque", - "extension": "collections", + "extension": "collections.list", "id": "List", "args": [ { @@ -340,7 +340,7 @@ "binary": false }, "set": { - "extension": "collections", + "extension": "collections.list", "name": "set", "description": "Replace the element at index `i` with value `v`.", "signature": { @@ -354,7 +354,7 @@ "input": [ { "t": "Opaque", - "extension": "collections", + "extension": "collections.list", "id": "List", "args": [ { @@ -380,7 +380,7 @@ "output": [ { "t": "Opaque", - "extension": "collections", + "extension": "collections.list", "id": "List", "args": [ { diff --git a/hugr-py/src/hugr/std/collections/__init__.py b/hugr-py/src/hugr/std/collections/__init__.py new file mode 100644 index 000000000..2c48cb67a --- /dev/null +++ b/hugr-py/src/hugr/std/collections/__init__.py @@ -0,0 +1 @@ +"""Standard extensions for collection types and operations.""" diff --git a/hugr-py/src/hugr/std/collections.py b/hugr-py/src/hugr/std/collections/list.py similarity index 95% rename from hugr-py/src/hugr/std/collections.py rename to hugr-py/src/hugr/std/collections/list.py index ba820845c..0019e6509 100644 --- a/hugr-py/src/hugr/std/collections.py +++ b/hugr-py/src/hugr/std/collections/list.py @@ -9,7 +9,7 @@ from hugr.std import _load_extension from hugr.utils import comma_sep_str -EXTENSION = _load_extension("collections") +EXTENSION = _load_extension("collections.list") def list_type(ty: tys.Type) -> tys.ExtType: diff --git a/specification/std_extensions/collections.json b/specification/std_extensions/collections/list.json similarity index 91% rename from specification/std_extensions/collections.json rename to specification/std_extensions/collections/list.json index 2b46f5d44..b5f905add 100644 --- a/specification/std_extensions/collections.json +++ b/specification/std_extensions/collections/list.json @@ -1,10 +1,10 @@ { "version": "0.1.0", - "name": "collections", + "name": "collections.list", "extension_reqs": [], "types": { "List": { - "extension": "collections", + "extension": "collections.list", "name": "List", "params": [ { @@ -24,7 +24,7 @@ "values": {}, "operations": { "get": { - "extension": "collections", + "extension": "collections.list", "name": "get", "description": "Lookup an element in a list by index. Panics if the index is out of bounds.", "signature": { @@ -38,7 +38,7 @@ "input": [ { "t": "Opaque", - "extension": "collections", + "extension": "collections.list", "id": "List", "args": [ { @@ -78,7 +78,7 @@ "binary": false }, "insert": { - "extension": "collections", + "extension": "collections.list", "name": "insert", "description": "Insert an element at index `i`. Elements at higher indices are shifted one position to the right. Panics if the index is out of bounds.", "signature": { @@ -92,7 +92,7 @@ "input": [ { "t": "Opaque", - "extension": "collections", + "extension": "collections.list", "id": "List", "args": [ { @@ -118,7 +118,7 @@ "output": [ { "t": "Opaque", - "extension": "collections", + "extension": "collections.list", "id": "List", "args": [ { @@ -159,7 +159,7 @@ "binary": false }, "length": { - "extension": "collections", + "extension": "collections.list", "name": "length", "description": "Get the length of a list", "signature": { @@ -173,7 +173,7 @@ "input": [ { "t": "Opaque", - "extension": "collections", + "extension": "collections.list", "id": "List", "args": [ { @@ -191,7 +191,7 @@ "output": [ { "t": "Opaque", - "extension": "collections", + "extension": "collections.list", "id": "List", "args": [ { @@ -215,7 +215,7 @@ "binary": false }, "pop": { - "extension": "collections", + "extension": "collections.list", "name": "pop", "description": "Pop from the back of list. Returns an optional value.", "signature": { @@ -229,7 +229,7 @@ "input": [ { "t": "Opaque", - "extension": "collections", + "extension": "collections.list", "id": "List", "args": [ { @@ -247,7 +247,7 @@ "output": [ { "t": "Opaque", - "extension": "collections", + "extension": "collections.list", "id": "List", "args": [ { @@ -282,7 +282,7 @@ "binary": false }, "push": { - "extension": "collections", + "extension": "collections.list", "name": "push", "description": "Push to the back of list", "signature": { @@ -296,7 +296,7 @@ "input": [ { "t": "Opaque", - "extension": "collections", + "extension": "collections.list", "id": "List", "args": [ { @@ -319,7 +319,7 @@ "output": [ { "t": "Opaque", - "extension": "collections", + "extension": "collections.list", "id": "List", "args": [ { @@ -340,7 +340,7 @@ "binary": false }, "set": { - "extension": "collections", + "extension": "collections.list", "name": "set", "description": "Replace the element at index `i` with value `v`.", "signature": { @@ -354,7 +354,7 @@ "input": [ { "t": "Opaque", - "extension": "collections", + "extension": "collections.list", "id": "List", "args": [ { @@ -380,7 +380,7 @@ "output": [ { "t": "Opaque", - "extension": "collections", + "extension": "collections.list", "id": "List", "args": [ {