diff --git a/hugr-core/src/export.rs b/hugr-core/src/export.rs index 32460ce5d..093368b60 100644 --- a/hugr-core/src/export.rs +++ b/hugr-core/src/export.rs @@ -7,7 +7,7 @@ use crate::{ type_param::{TypeArgVariable, TypeParam}, type_row::TypeRowBase, CustomType, FuncTypeBase, MaybeRV, PolyFuncTypeBase, RowVariable, SumType, TypeArg, - TypeBase, TypeEnum, + TypeBase, TypeBound, TypeEnum, }, Direction, Hugr, HugrView, IncomingPort, Node, Port, }; @@ -44,8 +44,21 @@ struct Context<'a> { bump: &'a Bump, /// Stores the terms that we have already seen to avoid duplicates. term_map: FxHashMap, model::TermId>, + /// The current scope for local variables. + /// + /// This is set to the id of the smallest enclosing node that defines a polymorphic type. + /// We use this when exporting local variables in terms. local_scope: Option, + + /// Constraints to be added to the local scope. + /// + /// When exporting a node that defines a polymorphic type, we use this field + /// to collect the constraints that need to be added to that polymorphic + /// type. Currently this is used to record `nonlinear` constraints on uses + /// of `TypeParam::Type` with a `TypeBound::Copyable` bound. + local_constraints: Vec, + /// Mapping from extension operations to their declarations. decl_operations: FxHashMap<(ExtensionId, OpName), model::NodeId>, } @@ -63,6 +76,7 @@ impl<'a> Context<'a> { term_map: FxHashMap::default(), local_scope: None, decl_operations: FxHashMap::default(), + local_constraints: Vec::new(), } } @@ -173,9 +187,11 @@ impl<'a> Context<'a> { } fn with_local_scope(&mut self, node: model::NodeId, f: impl FnOnce(&mut Self) -> T) -> T { - let old_scope = self.local_scope.replace(node); + let prev_local_scope = self.local_scope.replace(node); + let prev_local_constraints = std::mem::take(&mut self.local_constraints); let result = f(self); - self.local_scope = old_scope; + self.local_scope = prev_local_scope; + self.local_constraints = prev_local_constraints; result } @@ -232,10 +248,11 @@ impl<'a> Context<'a> { OpType::FuncDefn(func) => self.with_local_scope(node_id, |this| { let name = this.get_func_name(node).unwrap(); - let (params, signature) = this.export_poly_func_type(&func.signature); + let (params, constraints, signature) = this.export_poly_func_type(&func.signature); let decl = this.bump.alloc(model::FuncDecl { name, params, + constraints, signature, }); let extensions = this.export_ext_set(&func.signature.body().extension_reqs); @@ -247,10 +264,11 @@ impl<'a> Context<'a> { OpType::FuncDecl(func) => self.with_local_scope(node_id, |this| { let name = this.get_func_name(node).unwrap(); - let (params, func) = this.export_poly_func_type(&func.signature); + let (params, constraints, func) = this.export_poly_func_type(&func.signature); let decl = this.bump.alloc(model::FuncDecl { name, params, + constraints, signature: func, }); model::Operation::DeclareFunc { decl } @@ -450,10 +468,11 @@ impl<'a> Context<'a> { let decl = self.with_local_scope(node, |this| { let name = this.make_qualified_name(opdef.extension(), opdef.name()); - let (params, r#type) = this.export_poly_func_type(poly_func_type); + let (params, constraints, r#type) = this.export_poly_func_type(poly_func_type); let decl = this.bump.alloc(model::OperationDecl { name, params, + constraints, r#type, }); decl @@ -671,22 +690,36 @@ impl<'a> Context<'a> { regions.into_bump_slice() } + /// Exports a polymorphic function type. + /// + /// The returned triple consists of: + /// - The static parameters of the polymorphic function type. + /// - The constraints of the polymorphic function type. + /// - The function type itself. pub fn export_poly_func_type( &mut self, t: &PolyFuncTypeBase, - ) -> (&'a [model::Param<'a>], model::TermId) { + ) -> (&'a [model::Param<'a>], &'a [model::TermId], model::TermId) { let mut params = BumpVec::with_capacity_in(t.params().len(), self.bump); + let scope = self + .local_scope + .expect("exporting poly func type outside of local scope"); for (i, param) in t.params().iter().enumerate() { let name = self.bump.alloc_str(&i.to_string()); - let r#type = self.export_type_param(param); - let param = model::Param::Implicit { name, r#type }; + let r#type = self.export_type_param(param, Some(model::LocalRef::Index(scope, i as _))); + let param = model::Param { + name, + r#type, + sort: model::ParamSort::Implicit, + }; params.push(param) } + let constraints = self.bump.alloc_slice_copy(&self.local_constraints); let body = self.export_func_type(t.body()); - (params.into_bump_slice(), body) + (params.into_bump_slice(), constraints, body) } pub fn export_type(&mut self, t: &TypeBase) -> model::TermId { @@ -703,7 +736,6 @@ impl<'a> Context<'a> { } TypeEnum::Function(func) => self.export_func_type(func), TypeEnum::Variable(index, _) => { - // This ignores the type bound for now let node = self.local_scope.expect("local variable out of scope"); self.make_term(model::Term::Var(model::LocalRef::Index(node, *index as _))) } @@ -794,20 +826,39 @@ impl<'a> Context<'a> { self.make_term(model::Term::List { items, tail: None }) } - pub fn export_type_param(&mut self, t: &TypeParam) -> model::TermId { + /// Exports a `TypeParam` to a term. + /// + /// The `var` argument is set when the type parameter being exported is the + /// type of a parameter to a polymorphic definition. In that case we can + /// generate a `nonlinear` constraint for the type of runtime types marked as + /// `TypeBound::Copyable`. + pub fn export_type_param( + &mut self, + t: &TypeParam, + var: Option>, + ) -> model::TermId { match t { - // This ignores the type bound for now. - TypeParam::Type { .. } => self.make_term(model::Term::Type), - // This ignores the type bound for now. + TypeParam::Type { b } => { + if let (Some(var), TypeBound::Copyable) = (var, b) { + let term = self.make_term(model::Term::Var(var)); + let non_linear = self.make_term(model::Term::NonLinearConstraint { term }); + self.local_constraints.push(non_linear); + } + + self.make_term(model::Term::Type) + } + // This ignores the bound on the natural for now. TypeParam::BoundedNat { .. } => self.make_term(model::Term::NatType), TypeParam::String => self.make_term(model::Term::StrType), TypeParam::List { param } => { - let item_type = self.export_type_param(param); + let item_type = self.export_type_param(param, None); self.make_term(model::Term::ListType { item_type }) } TypeParam::Tuple { params } => { let items = self.bump.alloc_slice_fill_iter( - params.iter().map(|param| self.export_type_param(param)), + params + .iter() + .map(|param| self.export_type_param(param, None)), ); let types = self.make_term(model::Term::List { items, tail: None }); self.make_term(model::Term::ApplyFull { diff --git a/hugr-core/src/extension/prelude.rs b/hugr-core/src/extension/prelude.rs index ca338eae3..8e0eb0f4b 100644 --- a/hugr-core/src/extension/prelude.rs +++ b/hugr-core/src/extension/prelude.rs @@ -101,6 +101,7 @@ lazy_static! { NoopDef.add_to_extension(&mut prelude).unwrap(); LiftDef.add_to_extension(&mut prelude).unwrap(); array::ArrayOpDef::load_all_ops(&mut prelude).unwrap(); + array::ArrayScanDef.add_to_extension(&mut prelude).unwrap(); prelude }; /// An extension registry containing only the prelude diff --git a/hugr-core/src/extension/prelude/array.rs b/hugr-core/src/extension/prelude/array.rs index a15bf23cc..6013039d4 100644 --- a/hugr-core/src/extension/prelude/array.rs +++ b/hugr-core/src/extension/prelude/array.rs @@ -1,3 +1,6 @@ +use std::str::FromStr; + +use itertools::Itertools; use strum_macros::EnumIter; use strum_macros::EnumString; use strum_macros::IntoStaticStr; @@ -17,8 +20,10 @@ use crate::ops::ExtensionOp; use crate::ops::NamedOp; use crate::ops::OpName; use crate::type_row; +use crate::types::FuncTypeBase; use crate::types::FuncValueType; +use crate::types::RowVariable; use crate::types::TypeBound; use crate::types::Type; @@ -28,6 +33,7 @@ use crate::extension::SignatureError; use crate::types::PolyFuncTypeRV; use crate::types::type_param::TypeArg; +use crate::types::TypeRV; use crate::Extension; use super::PRELUDE_ID; @@ -46,6 +52,7 @@ pub enum ArrayOpDef { pop_left, pop_right, discard_empty, + repeat, } /// Static parameters for array operations. Includes array size. Type is part of the type scheme. @@ -118,6 +125,14 @@ impl ArrayOpDef { let standard_params = vec![TypeParam::max_nat(), TypeBound::Any.into()]; match self { + repeat => { + let func = + Type::new_function(FuncValueType::new(type_row![], elem_ty_var.clone())); + PolyFuncTypeRV::new( + standard_params, + FuncValueType::new(vec![func], array_ty.clone()), + ) + } get => { let params = vec![TypeParam::max_nat(), TypeBound::Copyable.into()]; let copy_elem_ty = Type::new_var_use(1, TypeBound::Copyable); @@ -179,6 +194,10 @@ impl MakeOpDef for ArrayOpDef { fn description(&self) -> String { match self { ArrayOpDef::new_array => "Create a new array from elements", + ArrayOpDef::repeat => { + "Creates a new array whose elements are initialised by calling \ + the given function n times" + } ArrayOpDef::get => "Get an element from an array", ArrayOpDef::set => "Set an element in an array", ArrayOpDef::swap => "Swap two elements in an array", @@ -246,7 +265,7 @@ impl MakeExtensionOp for ArrayOp { ); vec![ty_arg] } - new_array | pop_left | pop_right | get | set | swap => { + new_array | repeat | pop_left | pop_right | get | set | swap => { vec![TypeArg::BoundedNat { n: self.size }, ty_arg] } } @@ -312,6 +331,192 @@ pub fn new_array_op(element_ty: Type, size: u64) -> ExtensionOp { op.to_extension_op().unwrap() } +/// Name of the operation for the combined map/fold operation +pub const ARRAY_SCAN_OP_ID: OpName = OpName::new_inline("scan"); + +/// Definition of the array scan op. +#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)] +pub struct ArrayScanDef; + +impl NamedOp for ArrayScanDef { + fn name(&self) -> OpName { + ARRAY_SCAN_OP_ID + } +} + +impl FromStr for ArrayScanDef { + type Err = (); + + fn from_str(s: &str) -> Result { + if s == ArrayScanDef.name() { + Ok(Self) + } else { + Err(()) + } + } +} + +impl ArrayScanDef { + /// To avoid recursion when defining the extension, take the type definition as an argument. + fn signature_from_def(&self, array_def: &TypeDef) -> SignatureFunc { + // array, (T1, *A -> T2, *A), -> array, *A + let params = vec![ + TypeParam::max_nat(), + TypeBound::Any.into(), + TypeBound::Any.into(), + TypeParam::new_list(TypeBound::Any), + ]; + let n = TypeArg::new_var_use(0, TypeParam::max_nat()); + let t1 = Type::new_var_use(1, TypeBound::Any); + let t2 = Type::new_var_use(2, TypeBound::Any); + let s = TypeRV::new_row_var_use(3, TypeBound::Any); + PolyFuncTypeRV::new( + params, + FuncTypeBase::::new( + vec![ + instantiate(array_def, n.clone(), t1.clone()).into(), + Type::new_function(FuncTypeBase::::new( + vec![t1.into(), s.clone()], + vec![t2.clone().into(), s.clone()], + )) + .into(), + s.clone(), + ], + vec![instantiate(array_def, n, t2).into(), s], + ), + ) + .into() + } +} + +impl MakeOpDef for ArrayScanDef { + fn from_def(op_def: &OpDef) -> Result + where + Self: Sized, + { + crate::extension::simple_op::try_from_name(op_def.name(), op_def.extension()) + } + + fn signature(&self) -> SignatureFunc { + self.signature_from_def(array_type_def()) + } + + fn extension(&self) -> ExtensionId { + PRELUDE_ID + } + + fn description(&self) -> String { + "A combination of map and foldl. Applies a function to each element \ + of the array with an accumulator that is passed through from start to \ + finish. Returns the resulting array and the final state of the \ + accumulator." + .into() + } + + /// Add an operation implemented as a [MakeOpDef], which can provide the data + /// required to define an [OpDef], to an extension. + // + // This method is re-defined here since we need to pass the array type def while + // computing the signature, to avoid recursive loops initializing the extension. + fn add_to_extension( + &self, + extension: &mut Extension, + ) -> Result<(), crate::extension::ExtensionBuildError> { + let sig = self.signature_from_def(extension.get_type(ARRAY_TYPE_NAME).unwrap()); + let def = extension.add_op(self.name(), self.description(), sig)?; + + self.post_opdef(def); + + Ok(()) + } +} + +/// Definition of the array scan op. +#[derive(Clone, Debug, PartialEq)] +pub struct ArrayScan { + /// The element type of the input array. + src_ty: Type, + /// The target element type of the output array. + tgt_ty: Type, + /// The accumulator types. + acc_tys: Vec, + /// Size of the array. + size: u64, +} + +impl ArrayScan { + fn new(src_ty: Type, tgt_ty: Type, acc_tys: Vec, size: u64) -> Self { + ArrayScan { + src_ty, + tgt_ty, + acc_tys, + size, + } + } +} + +impl NamedOp for ArrayScan { + fn name(&self) -> OpName { + ARRAY_SCAN_OP_ID + } +} + +impl MakeExtensionOp for ArrayScan { + fn from_extension_op(ext_op: &ExtensionOp) -> Result + where + Self: Sized, + { + let def = ArrayScanDef::from_def(ext_op.def())?; + def.instantiate(ext_op.args()) + } + + fn type_args(&self) -> Vec { + vec![ + TypeArg::BoundedNat { n: self.size }, + self.src_ty.clone().into(), + self.tgt_ty.clone().into(), + TypeArg::Sequence { + elems: self.acc_tys.clone().into_iter().map_into().collect(), + }, + ] + } +} + +impl MakeRegisteredOp for ArrayScan { + fn extension_id(&self) -> ExtensionId { + PRELUDE_ID + } + + fn registry<'s, 'r: 's>(&'s self) -> &'r crate::extension::ExtensionRegistry { + &PRELUDE_REGISTRY + } +} + +impl HasDef for ArrayScan { + type Def = ArrayScanDef; +} + +impl HasConcrete for ArrayScanDef { + type Concrete = ArrayScan; + + fn instantiate(&self, type_args: &[TypeArg]) -> Result { + match type_args { + [TypeArg::BoundedNat { n }, TypeArg::Type { ty: src_ty }, TypeArg::Type { ty: tgt_ty }, TypeArg::Sequence { elems: acc_tys }] => + { + let acc_tys: Result<_, OpLoadError> = acc_tys + .iter() + .map(|acc_ty| match acc_ty { + TypeArg::Type { ty } => Ok(ty.clone()), + _ => Err(SignatureError::InvalidTypeArgs.into()), + }) + .collect(); + Ok(ArrayScan::new(src_ty.clone(), tgt_ty.clone(), acc_tys?, *n)) + } + _ => Err(SignatureError::InvalidTypeArgs.into()), + } + } +} + #[cfg(test)] mod tests { use strum::IntoEnumIterator; @@ -320,6 +525,7 @@ mod tests { builder::{inout_sig, DFGBuilder, Dataflow, DataflowHugr}, extension::prelude::{BOOL_T, QB_T}, ops::{OpTrait, OpType}, + types::Signature, }; use super::*; @@ -459,4 +665,89 @@ mod tests { ) ); } + + #[test] + fn test_repeat() { + let size = 2; + let element_ty = QB_T; + let op = ArrayOpDef::repeat.to_concrete(element_ty.clone(), size); + + let optype: OpType = op.into(); + + let sig = optype.dataflow_signature().unwrap(); + + assert_eq!( + sig.io(), + ( + &vec![Type::new_function(Signature::new(vec![], vec![QB_T]))].into(), + &vec![array_type(size, element_ty.clone())].into(), + ) + ); + } + + #[test] + fn test_scan_def() { + let op = ArrayScan::new(BOOL_T, QB_T, vec![USIZE_T], 2); + let optype: OpType = op.clone().into(); + let new_op: ArrayScan = optype.cast().unwrap(); + assert_eq!(new_op, op); + } + + #[test] + fn test_scan_map() { + let size = 2; + let src_ty = QB_T; + let tgt_ty = BOOL_T; + + let op = ArrayScan::new(src_ty.clone(), tgt_ty.clone(), vec![], size); + let optype: OpType = op.into(); + let sig = optype.dataflow_signature().unwrap(); + + assert_eq!( + sig.io(), + ( + &vec![ + array_type(size, src_ty.clone()), + Type::new_function(Signature::new(vec![src_ty], vec![tgt_ty.clone()])) + ] + .into(), + &vec![array_type(size, tgt_ty)].into(), + ) + ); + } + + #[test] + fn test_scan_accs() { + let size = 2; + let src_ty = QB_T; + let tgt_ty = BOOL_T; + let acc_ty1 = USIZE_T; + let acc_ty2 = QB_T; + + let op = ArrayScan::new( + src_ty.clone(), + tgt_ty.clone(), + vec![acc_ty1.clone(), acc_ty2.clone()], + size, + ); + let optype: OpType = op.into(); + let sig = optype.dataflow_signature().unwrap(); + + assert_eq!( + sig.io(), + ( + &vec![ + array_type(size, src_ty.clone()), + Type::new_function(Signature::new( + vec![src_ty, acc_ty1.clone(), acc_ty2.clone()], + vec![tgt_ty.clone(), acc_ty1.clone(), acc_ty2.clone()] + )), + acc_ty1.clone(), + acc_ty2.clone() + ] + .into(), + &vec![array_type(size, tgt_ty), acc_ty1, acc_ty2].into(), + ) + ); + } } diff --git a/hugr-core/src/import.rs b/hugr-core/src/import.rs index e6a53cb2f..7619ad44a 100644 --- a/hugr-core/src/import.rs +++ b/hugr-core/src/import.rs @@ -115,8 +115,8 @@ struct Context<'a> { /// A map from `NodeId` to the imported `Node`. nodes: FxHashMap, - /// The types of the local variables that are currently in scope. - local_variables: FxIndexMap<&'a str, model::TermId>, + /// The local variables that are currently in scope. + local_variables: FxIndexMap<&'a str, LocalVar>, custom_name_cache: FxHashMap<&'a str, (ExtensionId, SmolStr)>, } @@ -155,20 +155,20 @@ impl<'a> Context<'a> { .ok_or_else(|| model::ModelError::RegionNotFound(region_id).into()) } - /// Looks up a [`LocalRef`] within the current scope and returns its index and type. + /// Looks up a [`LocalRef`] within the current scope. fn resolve_local_ref( &self, local_ref: &model::LocalRef, - ) -> Result<(usize, model::TermId), ImportError> { + ) -> Result<(usize, LocalVar), ImportError> { let term = match local_ref { model::LocalRef::Index(_, index) => self .local_variables .get_index(*index as usize) - .map(|(_, term)| (*index as usize, *term)), + .map(|(_, v)| (*index as usize, *v)), model::LocalRef::Named(name) => self .local_variables .get_full(name) - .map(|(index, _, term)| (index, *term)), + .map(|(index, _, v)| (index, *v)), }; term.ok_or_else(|| model::ModelError::InvalidLocal(local_ref.to_string()).into()) @@ -898,41 +898,49 @@ impl<'a> Context<'a> { self.with_local_socpe(|ctx| { let mut imported_params = Vec::with_capacity(decl.params.len()); - for param in decl.params { - // TODO: `PolyFuncType` should be able to handle constraints - // and distinguish between implicit and explicit parameters. - match param { - model::Param::Implicit { name, r#type } => { - imported_params.push(ctx.import_type_param(*r#type)?); - ctx.local_variables.insert(name, *r#type); - } - model::Param::Explicit { name, r#type } => { - imported_params.push(ctx.import_type_param(*r#type)?); - ctx.local_variables.insert(name, *r#type); - } - model::Param::Constraint { constraint: _ } => { - return Err(error_unsupported!("constraints")); + ctx.local_variables.extend( + decl.params + .iter() + .map(|param| (param.name, LocalVar::new(param.r#type))), + ); + + for constraint in decl.constraints { + match ctx.get_term(*constraint)? { + model::Term::NonLinearConstraint { term } => { + let model::Term::Var(var) = ctx.get_term(*term)? else { + return Err(error_unsupported!( + "constraint on term that is not a variable" + )); + }; + + let var = ctx.resolve_local_ref(var)?.0; + ctx.local_variables[var].bound = TypeBound::Copyable; } + _ => return Err(error_unsupported!("constraint other than copy or discard")), } } + for (index, param) in decl.params.iter().enumerate() { + // NOTE: `PolyFuncType` only has explicit type parameters at present. + let bound = ctx.local_variables[index].bound; + imported_params.push(ctx.import_type_param(param.r#type, bound)?); + } + let body = ctx.import_func_type::(decl.signature)?; in_scope(ctx, PolyFuncTypeBase::new(imported_params, body)) }) } /// Import a [`TypeParam`] from a term that represents a static type. - fn import_type_param(&mut self, term_id: model::TermId) -> Result { + fn import_type_param( + &mut self, + term_id: model::TermId, + bound: TypeBound, + ) -> Result { match self.get_term(term_id)? { model::Term::Wildcard => Err(error_uninferred!("wildcard")), - model::Term::Type => { - // As part of the migration from `TypeBound`s to constraints, we pretend that all - // `TypeBound`s are copyable. - Ok(TypeParam::Type { - b: TypeBound::Copyable, - }) - } + model::Term::Type => Ok(TypeParam::Type { b: bound }), model::Term::StaticType => Err(error_unsupported!("`type` as `TypeParam`")), model::Term::Constraint => Err(error_unsupported!("`constraint` as `TypeParam`")), @@ -944,7 +952,9 @@ impl<'a> Context<'a> { model::Term::FuncType { .. } => Err(error_unsupported!("`(fn ...)` as `TypeParam`")), model::Term::ListType { item_type } => { - let param = Box::new(self.import_type_param(*item_type)?); + // At present `hugr-model` has no way to express that the item + // type of a list must be copyable. Therefore we import it as `Any`. + let param = Box::new(self.import_type_param(*item_type, TypeBound::Any)?); Ok(TypeParam::List { param }) } @@ -958,7 +968,10 @@ impl<'a> Context<'a> { | model::Term::List { .. } | model::Term::ExtSet { .. } | model::Term::Adt { .. } - | model::Term::Control { .. } => Err(model::ModelError::TypeError(term_id).into()), + | model::Term::Control { .. } + | model::Term::NonLinearConstraint { .. } => { + Err(model::ModelError::TypeError(term_id).into()) + } model::Term::ControlType => { Err(error_unsupported!("type of control types as `TypeParam`")) @@ -966,7 +979,7 @@ impl<'a> Context<'a> { } } - /// Import a `TypeArg` froma term that represents a static type or value. + /// Import a `TypeArg` from a term that represents a static type or value. fn import_type_arg(&mut self, term_id: model::TermId) -> Result { match self.get_term(term_id)? { model::Term::Wildcard => Err(error_uninferred!("wildcard")), @@ -975,8 +988,8 @@ impl<'a> Context<'a> { } model::Term::Var(var) => { - let (index, var_type) = self.resolve_local_ref(var)?; - let decl = self.import_type_param(var_type)?; + let (index, var) = self.resolve_local_ref(var)?; + let decl = self.import_type_param(var.r#type, var.bound)?; Ok(TypeArg::new_var_use(index, decl)) } @@ -1014,7 +1027,10 @@ impl<'a> Context<'a> { model::Term::FuncType { .. } | model::Term::Adt { .. } - | model::Term::Control { .. } => Err(model::ModelError::TypeError(term_id).into()), + | model::Term::Control { .. } + | model::Term::NonLinearConstraint { .. } => { + Err(model::ModelError::TypeError(term_id).into()) + } } } @@ -1115,7 +1131,10 @@ impl<'a> Context<'a> { | model::Term::List { .. } | model::Term::Control { .. } | model::Term::ControlType - | model::Term::Nat(_) => Err(model::ModelError::TypeError(term_id).into()), + | model::Term::Nat(_) + | model::Term::NonLinearConstraint { .. } => { + Err(model::ModelError::TypeError(term_id).into()) + } } } @@ -1291,3 +1310,21 @@ impl<'a> Names<'a> { Ok(Self { items }) } } + +/// Information about a local variable. +#[derive(Debug, Clone, Copy)] +struct LocalVar { + /// The type of the variable. + r#type: model::TermId, + /// The type bound of the variable. + bound: TypeBound, +} + +impl LocalVar { + pub fn new(r#type: model::TermId) -> Self { + Self { + r#type, + bound: TypeBound::Any, + } + } +} diff --git a/hugr-core/tests/model.rs b/hugr-core/tests/model.rs index 611eda660..d9ef0d2c9 100644 --- a/hugr-core/tests/model.rs +++ b/hugr-core/tests/model.rs @@ -58,3 +58,10 @@ pub fn test_roundtrip_params() { "../../hugr-model/tests/fixtures/model-params.edn" ))); } + +#[test] +pub fn test_roundtrip_constraints() { + insta::assert_snapshot!(roundtrip(include_str!( + "../../hugr-model/tests/fixtures/model-constraints.edn" + ))); +} diff --git a/hugr-core/tests/snapshots/model__roundtrip_constraints.snap b/hugr-core/tests/snapshots/model__roundtrip_constraints.snap new file mode 100644 index 000000000..f085c4785 --- /dev/null +++ b/hugr-core/tests/snapshots/model__roundtrip_constraints.snap @@ -0,0 +1,16 @@ +--- +source: hugr-core/tests/model.rs +expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-constraints.edn\"))" +--- +(hugr 0) + +(declare-func array.replicate + (forall ?0 type) + (forall ?1 nat) + (where (nonlinear ?0)) + [?0] [(@ array.Array ?0 ?1)] (ext)) + +(declare-func array.copy + (forall ?0 type) + (where (nonlinear ?0)) + [(@ array.Array ?0)] [(@ array.Array ?0) (@ array.Array ?0)] (ext)) diff --git a/hugr-model/capnp/hugr-v0.capnp b/hugr-model/capnp/hugr-v0.capnp index 95db81205..94341beba 100644 --- a/hugr-model/capnp/hugr-v0.capnp +++ b/hugr-model/capnp/hugr-v0.capnp @@ -56,13 +56,15 @@ struct Operation { struct FuncDefn { name @0 :Text; params @1 :List(Param); - signature @2 :TermId; + constraints @2 :List(TermId); + signature @3 :TermId; } struct FuncDecl { name @0 :Text; params @1 :List(Param); - signature @2 :TermId; + constraints @2 :List(TermId); + signature @3 :TermId; } struct AliasDefn { @@ -81,13 +83,15 @@ struct Operation { struct ConstructorDecl { name @0 :Text; params @1 :List(Param); - type @2 :TermId; + constraints @2 :List(TermId); + type @3 :TermId; } struct OperationDecl { name @0 :Text; params @1 :List(Param); - type @2 :TermId; + constraints @2 :List(TermId); + type @3 :TermId; } } @@ -157,6 +161,7 @@ struct Term { funcType @17 :FuncType; control @18 :TermId; controlType @19 :Void; + nonLinearConstraint @20 :TermId; } struct Apply { @@ -187,19 +192,12 @@ struct Term { } struct Param { - union { - implicit @0 :Implicit; - explicit @1 :Explicit; - constraint @2 :TermId; - } - - struct Implicit { - name @0 :Text; - type @1 :TermId; - } + name @0 :Text; + type @1 :TermId; + sort @2 :ParamSort; +} - struct Explicit { - name @0 :Text; - type @1 :TermId; - } +enum ParamSort { + implicit @0; + explicit @1; } diff --git a/hugr-model/src/v0/binary/read.rs b/hugr-model/src/v0/binary/read.rs index 681bd4ea9..5381a7dc8 100644 --- a/hugr-model/src/v0/binary/read.rs +++ b/hugr-model/src/v0/binary/read.rs @@ -140,10 +140,12 @@ fn read_operation<'a>( let reader = reader?; let name = bump.alloc_str(reader.get_name()?.to_str()?); let params = read_list!(bump, reader, get_params, read_param); + let constraints = read_scalar_list!(bump, reader, get_constraints, model::TermId); let signature = model::TermId(reader.get_signature()); let decl = bump.alloc(model::FuncDecl { name, params, + constraints, signature, }); model::Operation::DefineFunc { decl } @@ -152,10 +154,12 @@ fn read_operation<'a>( let reader = reader?; let name = bump.alloc_str(reader.get_name()?.to_str()?); let params = read_list!(bump, reader, get_params, read_param); + let constraints = read_scalar_list!(bump, reader, get_constraints, model::TermId); let signature = model::TermId(reader.get_signature()); let decl = bump.alloc(model::FuncDecl { name, params, + constraints, signature, }); model::Operation::DeclareFunc { decl } @@ -189,10 +193,12 @@ fn read_operation<'a>( let reader = reader?; let name = bump.alloc_str(reader.get_name()?.to_str()?); let params = read_list!(bump, reader, get_params, read_param); + let constraints = read_scalar_list!(bump, reader, get_constraints, model::TermId); let r#type = model::TermId(reader.get_type()); let decl = bump.alloc(model::ConstructorDecl { name, params, + constraints, r#type, }); model::Operation::DeclareConstructor { decl } @@ -201,10 +207,12 @@ fn read_operation<'a>( let reader = reader?; let name = bump.alloc_str(reader.get_name()?.to_str()?); let params = read_list!(bump, reader, get_params, read_param); + let constraints = read_scalar_list!(bump, reader, get_constraints, model::TermId); let r#type = model::TermId(reader.get_type()); let decl = bump.alloc(model::OperationDecl { name, params, + constraints, r#type, }); model::Operation::DeclareOperation { decl } @@ -332,6 +340,10 @@ fn read_term<'a>(bump: &'a Bump, reader: hugr_capnp::term::Reader) -> ReadResult Which::Control(values) => model::Term::Control { values: model::TermId(values), }, + + Which::NonLinearConstraint(term) => model::Term::NonLinearConstraint { + term: model::TermId(term), + }, }) } @@ -348,23 +360,13 @@ fn read_param<'a>( bump: &'a Bump, reader: hugr_capnp::param::Reader, ) -> ReadResult> { - use hugr_capnp::param::Which; - Ok(match reader.which()? { - Which::Implicit(reader) => { - let reader = reader?; - let name = bump.alloc_str(reader.get_name()?.to_str()?); - let r#type = model::TermId(reader.get_type()); - model::Param::Implicit { name, r#type } - } - Which::Explicit(reader) => { - let reader = reader?; - let name = bump.alloc_str(reader.get_name()?.to_str()?); - let r#type = model::TermId(reader.get_type()); - model::Param::Explicit { name, r#type } - } - Which::Constraint(constraint) => { - let constraint = model::TermId(constraint); - model::Param::Constraint { constraint } - } - }) + let name = bump.alloc_str(reader.get_name()?.to_str()?); + let r#type = model::TermId(reader.get_type()); + + let sort = match reader.get_sort()? { + hugr_capnp::ParamSort::Implicit => model::ParamSort::Implicit, + hugr_capnp::ParamSort::Explicit => model::ParamSort::Explicit, + }; + + Ok(model::Param { name, r#type, sort }) } diff --git a/hugr-model/src/v0/binary/write.rs b/hugr-model/src/v0/binary/write.rs index a4b64d646..f3a0a14d2 100644 --- a/hugr-model/src/v0/binary/write.rs +++ b/hugr-model/src/v0/binary/write.rs @@ -60,12 +60,14 @@ fn write_operation(mut builder: hugr_capnp::operation::Builder, operation: &mode let mut builder = builder.init_func_defn(); builder.set_name(decl.name); write_list!(builder, init_params, write_param, decl.params); + let _ = builder.set_constraints(model::TermId::unwrap_slice(decl.constraints)); builder.set_signature(decl.signature.0); } model::Operation::DeclareFunc { decl } => { let mut builder = builder.init_func_decl(); builder.set_name(decl.name); write_list!(builder, init_params, write_param, decl.params); + let _ = builder.set_constraints(model::TermId::unwrap_slice(decl.constraints)); builder.set_signature(decl.signature.0); } @@ -87,12 +89,14 @@ fn write_operation(mut builder: hugr_capnp::operation::Builder, operation: &mode let mut builder = builder.init_constructor_decl(); builder.set_name(decl.name); write_list!(builder, init_params, write_param, decl.params); + let _ = builder.set_constraints(model::TermId::unwrap_slice(decl.constraints)); builder.set_type(decl.r#type.0); } model::Operation::DeclareOperation { decl } => { let mut builder = builder.init_operation_decl(); builder.set_name(decl.name); write_list!(builder, init_params, write_param, decl.params); + let _ = builder.set_constraints(model::TermId::unwrap_slice(decl.constraints)); builder.set_type(decl.r#type.0); } @@ -101,19 +105,12 @@ fn write_operation(mut builder: hugr_capnp::operation::Builder, operation: &mode } fn write_param(mut builder: hugr_capnp::param::Builder, param: &model::Param) { - match param { - model::Param::Implicit { name, r#type } => { - let mut builder = builder.init_implicit(); - builder.set_name(name); - builder.set_type(r#type.0); - } - model::Param::Explicit { name, r#type } => { - let mut builder = builder.init_explicit(); - builder.set_name(name); - builder.set_type(r#type.0); - } - model::Param::Constraint { constraint } => builder.set_constraint(constraint.0), - } + builder.set_name(param.name); + builder.set_type(param.r#type.0); + builder.set_sort(match param.sort { + model::ParamSort::Implicit => hugr_capnp::ParamSort::Implicit, + model::ParamSort::Explicit => hugr_capnp::ParamSort::Explicit, + }); } fn write_global_ref(mut builder: hugr_capnp::global_ref::Builder, global_ref: &model::GlobalRef) { @@ -212,5 +209,9 @@ fn write_term(mut builder: hugr_capnp::term::Builder, term: &model::Term) { builder.set_outputs(outputs.0); builder.set_extensions(extensions.0); } + + model::Term::NonLinearConstraint { term } => { + builder.set_non_linear_constraint(term.0); + } } } diff --git a/hugr-model/src/v0/mod.rs b/hugr-model/src/v0/mod.rs index cb8713b32..16c7cb6c6 100644 --- a/hugr-model/src/v0/mod.rs +++ b/hugr-model/src/v0/mod.rs @@ -397,6 +397,8 @@ pub struct FuncDecl<'a> { pub name: &'a str, /// The static parameters of the function. pub params: &'a [Param<'a>], + /// The constraints on the static parameters. + pub constraints: &'a [TermId], /// The signature of the function. pub signature: TermId, } @@ -419,6 +421,8 @@ pub struct ConstructorDecl<'a> { pub name: &'a str, /// The static parameters of the constructor. pub params: &'a [Param<'a>], + /// The constraints on the static parameters. + pub constraints: &'a [TermId], /// The type of the constructed term. pub r#type: TermId, } @@ -430,6 +434,8 @@ pub struct OperationDecl<'a> { pub name: &'a str, /// The static parameters of the operation. pub params: &'a [Param<'a>], + /// The constraints on the static parameters. + pub constraints: &'a [TermId], /// The type of the operation. This must be a function type. pub r#type: TermId, } @@ -662,6 +668,12 @@ pub enum Term<'a> { /// /// `ctrl : static` ControlType, + + /// Constraint that requires a runtime type to be copyable and discardable. + NonLinearConstraint { + /// The runtime type that must be copyable and discardable. + term: TermId, + }, } /// A parameter to a function or alias. @@ -669,33 +681,23 @@ pub enum Term<'a> { /// Parameter names must be unique within a parameter list. /// Implicit and explicit parameters share a namespace. #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub enum Param<'a> { - /// An implicit parameter that should be inferred, unless a full application form is used +pub struct Param<'a> { + /// The name of the parameter. + pub name: &'a str, + /// The type of the parameter. + pub r#type: TermId, + /// The sort of the parameter (implicit or explicit). + pub sort: ParamSort, +} + +/// The sort of a parameter (implicit or explicit). +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub enum ParamSort { + /// The parameter is implicit and should be inferred, unless a full application form is used /// (see [`Term::ApplyFull`] and [`Operation::CustomFull`]). - Implicit { - /// The name of the parameter. - name: &'a str, - /// The type of the parameter. - /// - /// This must be a term of type `static`. - r#type: TermId, - }, - /// An explicit parameter that should always be provided. - Explicit { - /// The name of the parameter. - name: &'a str, - /// The type of the parameter. - /// - /// This must be a term of type `static`. - r#type: TermId, - }, - /// A constraint that should be satisfied by other parameters in a parameter list. - Constraint { - /// The constraint to be satisfied. - /// - /// This must be a term of type `constraint`. - constraint: TermId, - }, + Implicit, + /// The parameter is explicit and should always be provided. + Explicit, } /// Errors that can occur when traversing and interpreting the model. diff --git a/hugr-model/src/v0/text/hugr.pest b/hugr-model/src/v0/text/hugr.pest index 132d78567..d05e3d774 100644 --- a/hugr-model/src/v0/text/hugr.pest +++ b/hugr-model/src/v0/text/hugr.pest @@ -56,16 +56,16 @@ node_tag = { "(" ~ "tag" ~ tag ~ port_lists? ~ signature? ~ meta* node_custom = { "(" ~ (term_apply | term_apply_full) ~ port_lists? ~ signature? ~ meta* ~ region* ~ ")" } signature = { "(" ~ "signature" ~ term ~ ")" } -func_header = { symbol ~ param* ~ term ~ term ~ term } +func_header = { symbol ~ param* ~ where_clause* ~ term ~ term ~ term } alias_header = { symbol ~ param* ~ term } -ctr_header = { symbol ~ param* ~ term } -operation_header = { symbol ~ param* ~ term } +ctr_header = { symbol ~ param* ~ where_clause* ~ term } +operation_header = { symbol ~ param* ~ where_clause* ~ term } -param = { param_implicit | param_explicit | param_constraint } +param = { param_implicit | param_explicit } -param_implicit = { "(" ~ "forall" ~ term_var ~ term ~ ")" } -param_explicit = { "(" ~ "param" ~ term_var ~ term ~ ")" } -param_constraint = { "(" ~ "where" ~ term ~ ")" } +param_implicit = { "(" ~ "forall" ~ term_var ~ term ~ ")" } +param_explicit = { "(" ~ "param" ~ term_var ~ term ~ ")" } +where_clause = { "(" ~ "where" ~ term ~ ")" } region = { region_dfg | region_cfg } region_dfg = { "(" ~ "dfg" ~ port_lists? ~ signature? ~ meta* ~ node* ~ ")" } @@ -92,6 +92,7 @@ term = { | term_ctrl_type | term_apply_full | term_apply + | term_non_linear } term_wildcard = { "_" } @@ -114,3 +115,4 @@ term_adt = { "(" ~ "adt" ~ term ~ ")" } term_func_type = { "(" ~ "fn" ~ term ~ term ~ term ~ ")" } term_ctrl = { "(" ~ "ctrl" ~ term ~ ")" } term_ctrl_type = { "ctrl" } +term_non_linear = { "(" ~ "nonlinear" ~ term ~ ")" } diff --git a/hugr-model/src/v0/text/parse.rs b/hugr-model/src/v0/text/parse.rs index fa486454b..370dbeac0 100644 --- a/hugr-model/src/v0/text/parse.rs +++ b/hugr-model/src/v0/text/parse.rs @@ -7,7 +7,7 @@ use thiserror::Error; use crate::v0::{ AliasDecl, ConstructorDecl, FuncDecl, GlobalRef, LinkRef, LocalRef, MetaItem, Module, Node, - NodeId, Operation, OperationDecl, Param, Region, RegionId, RegionKind, Term, TermId, + NodeId, Operation, OperationDecl, Param, ParamSort, Region, RegionId, RegionKind, Term, TermId, }; mod pest_parser { @@ -209,6 +209,11 @@ impl<'a> ParseContext<'a> { Term::Control { values } } + Rule::term_non_linear => { + let term = self.parse_term(inner.next().unwrap())?; + Term::NonLinearConstraint { term } + } + r => unreachable!("term: {:?}", r), }; @@ -544,6 +549,7 @@ impl<'a> ParseContext<'a> { let mut inner = pair.into_inner(); let name = self.parse_symbol(&mut inner)?; let params = self.parse_params(&mut inner)?; + let constraints = self.parse_constraints(&mut inner)?; let inputs = self.parse_term(inner.next().unwrap())?; let outputs = self.parse_term(inner.next().unwrap())?; @@ -559,6 +565,7 @@ impl<'a> ParseContext<'a> { Ok(self.bump.alloc(FuncDecl { name, params, + constraints, signature: func, })) } @@ -584,11 +591,13 @@ impl<'a> ParseContext<'a> { let mut inner = pair.into_inner(); let name = self.parse_symbol(&mut inner)?; let params = self.parse_params(&mut inner)?; + let constraints = self.parse_constraints(&mut inner)?; let r#type = self.parse_term(inner.next().unwrap())?; Ok(self.bump.alloc(ConstructorDecl { name, params, + constraints, r#type, })) } @@ -599,11 +608,13 @@ impl<'a> ParseContext<'a> { let mut inner = pair.into_inner(); let name = self.parse_symbol(&mut inner)?; let params = self.parse_params(&mut inner)?; + let constraints = self.parse_constraints(&mut inner)?; let r#type = self.parse_term(inner.next().unwrap())?; Ok(self.bump.alloc(OperationDecl { name, params, + constraints, r#type, })) } @@ -619,18 +630,21 @@ impl<'a> ParseContext<'a> { let mut inner = param.into_inner(); let name = &inner.next().unwrap().as_str()[1..]; let r#type = self.parse_term(inner.next().unwrap())?; - Param::Implicit { name, r#type } + Param { + name, + r#type, + sort: ParamSort::Implicit, + } } Rule::param_explicit => { let mut inner = param.into_inner(); let name = &inner.next().unwrap().as_str()[1..]; let r#type = self.parse_term(inner.next().unwrap())?; - Param::Explicit { name, r#type } - } - Rule::param_constraint => { - let mut inner = param.into_inner(); - let constraint = self.parse_term(inner.next().unwrap())?; - Param::Constraint { constraint } + Param { + name, + r#type, + sort: ParamSort::Explicit, + } } _ => unreachable!(), }; @@ -641,6 +655,17 @@ impl<'a> ParseContext<'a> { Ok(self.bump.alloc_slice_copy(¶ms)) } + fn parse_constraints(&mut self, pairs: &mut Pairs<'a, Rule>) -> ParseResult<&'a [TermId]> { + let mut constraints = Vec::new(); + + for pair in filter_rule(pairs, Rule::where_clause) { + let constraint = self.parse_term(pair.into_inner().next().unwrap())?; + constraints.push(constraint); + } + + Ok(self.bump.alloc_slice_copy(&constraints)) + } + fn parse_signature(&mut self, pairs: &mut Pairs<'a, Rule>) -> ParseResult> { let Some(Rule::signature) = pairs.peek().map(|p| p.as_rule()) else { return Ok(None); diff --git a/hugr-model/src/v0/text/print.rs b/hugr-model/src/v0/text/print.rs index 01b9d7195..512f6d1e4 100644 --- a/hugr-model/src/v0/text/print.rs +++ b/hugr-model/src/v0/text/print.rs @@ -2,8 +2,8 @@ use pretty::{Arena, DocAllocator, RefDoc}; use std::borrow::Cow; use crate::v0::{ - GlobalRef, LinkRef, LocalRef, MetaItem, ModelError, Module, NodeId, Operation, Param, RegionId, - RegionKind, Term, TermId, + GlobalRef, LinkRef, LocalRef, MetaItem, ModelError, Module, NodeId, Operation, Param, + ParamSort, RegionId, RegionKind, Term, TermId, }; type PrintError = ModelError; @@ -122,15 +122,7 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { f: impl FnOnce(&mut Self) -> PrintResult, ) -> PrintResult { let locals = std::mem::take(&mut self.locals); - - for param in params { - match param { - Param::Implicit { name, .. } => self.locals.push(name), - Param::Explicit { name, .. } => self.locals.push(name), - Param::Constraint { .. } => {} - } - } - + self.locals.extend(params.iter().map(|param| param.name)); let result = f(self); self.locals = locals; result @@ -178,9 +170,8 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { this.print_text(decl.name); }); - for param in decl.params { - this.print_param(*param)?; - } + this.print_params(decl.params)?; + this.print_constraints(decl.constraints)?; match self.module.get_term(decl.signature) { Some(Term::FuncType { @@ -208,9 +199,8 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { this.print_text(decl.name); }); - for param in decl.params { - this.print_param(*param)?; - } + this.print_params(decl.params)?; + this.print_constraints(decl.constraints)?; match self.module.get_term(decl.signature) { Some(Term::FuncType { @@ -303,9 +293,7 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { this.print_text(decl.name); }); - for param in decl.params { - this.print_param(*param)?; - } + this.print_params(decl.params)?; this.print_term(decl.r#type)?; this.print_term(*value)?; @@ -318,9 +306,7 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { this.print_text(decl.name); }); - for param in decl.params { - this.print_param(*param)?; - } + this.print_params(decl.params)?; this.print_term(decl.r#type)?; this.print_meta(node_data.meta)?; @@ -333,9 +319,8 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { this.print_text(decl.name); }); - for param in decl.params { - this.print_param(*param)?; - } + this.print_params(decl.params)?; + this.print_constraints(decl.constraints)?; this.print_term(decl.r#type)?; this.print_meta(node_data.meta)?; @@ -348,9 +333,8 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { this.print_text(decl.name); }); - for param in decl.params { - this.print_param(*param)?; - } + this.print_params(decl.params)?; + this.print_constraints(decl.constraints)?; this.print_term(decl.r#type)?; this.print_meta(node_data.meta)?; @@ -384,10 +368,9 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { } fn print_regions(&mut self, regions: &'a [RegionId]) -> PrintResult<()> { - for region in regions { - self.print_region(*region)?; - } - Ok(()) + regions + .iter() + .try_for_each(|region| self.print_region(*region)) } fn print_region(&mut self, region: RegionId) -> PrintResult<()> { @@ -422,11 +405,10 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { .get_region(region) .ok_or(PrintError::RegionNotFound(region))?; - for node_id in region_data.children { - self.print_node(*node_id)?; - } - - Ok(()) + region_data + .children + .iter() + .try_for_each(|node_id| self.print_node(*node_id)) } fn print_port_lists( @@ -460,25 +442,33 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { } } + fn print_params(&mut self, params: &'a [Param<'a>]) -> PrintResult<()> { + params.iter().try_for_each(|param| self.print_param(*param)) + } + fn print_param(&mut self, param: Param<'a>) -> PrintResult<()> { - self.print_parens(|this| match param { - Param::Implicit { name, r#type } => { - this.print_text("forall"); - this.print_text(format!("?{}", name)); - this.print_term(r#type) - } - Param::Explicit { name, r#type } => { - this.print_text("param"); - this.print_text(format!("?{}", name)); - this.print_term(r#type) - } - Param::Constraint { constraint } => { - this.print_text("where"); - this.print_term(constraint) - } + self.print_parens(|this| { + match param.sort { + ParamSort::Implicit => this.print_text("forall"), + ParamSort::Explicit => this.print_text("param"), + }; + + this.print_text(format!("?{}", param.name)); + this.print_term(param.r#type) }) } + fn print_constraints(&mut self, terms: &'a [TermId]) -> PrintResult<()> { + for term in terms { + self.print_parens(|this| { + this.print_text("where"); + this.print_term(*term) + })?; + } + + Ok(()) + } + fn print_term(&mut self, term_id: TermId) -> PrintResult<()> { let term_data = self .module @@ -598,6 +588,10 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { self.print_text("ctrl"); Ok(()) } + Term::NonLinearConstraint { term } => self.print_parens(|this| { + this.print_text("nonlinear"); + this.print_term(*term) + }), } } diff --git a/hugr-model/tests/binary.rs b/hugr-model/tests/binary.rs index 043061677..93955fe6e 100644 --- a/hugr-model/tests/binary.rs +++ b/hugr-model/tests/binary.rs @@ -51,3 +51,8 @@ pub fn test_params() { pub fn test_decl_exts() { binary_roundtrip(include_str!("fixtures/model-decl-exts.edn")); } + +#[test] +pub fn test_constraints() { + binary_roundtrip(include_str!("fixtures/model-constraints.edn")); +} diff --git a/hugr-model/tests/fixtures/model-constraints.edn b/hugr-model/tests/fixtures/model-constraints.edn new file mode 100644 index 000000000..5db6b9886 --- /dev/null +++ b/hugr-model/tests/fixtures/model-constraints.edn @@ -0,0 +1,13 @@ +(hugr 0) + +(declare-func array.replicate + (forall ?t type) + (forall ?n nat) + (where (nonlinear ?t)) + [?t] [(@ array.Array ?t ?n)] + (ext)) + +(declare-func array.copy + (forall ?t type) + (where (nonlinear ?t)) + [(@ array.Array ?t)] [(@ array.Array ?t) (@ array.Array ?t)] (ext)) diff --git a/hugr-py/src/hugr/std/_json_defs/prelude.json b/hugr-py/src/hugr/std/_json_defs/prelude.json index 014ba3ede..b48692b39 100644 --- a/hugr-py/src/hugr/std/_json_defs/prelude.json +++ b/hugr-py/src/hugr/std/_json_defs/prelude.json @@ -418,6 +418,189 @@ }, "binary": false }, + "repeat": { + "extension": "prelude", + "name": "repeat", + "description": "Creates a new array whose elements are initialised by calling the given function n times", + "signature": { + "params": [ + { + "tp": "BoundedNat", + "bound": null + }, + { + "tp": "Type", + "b": "A" + } + ], + "body": { + "input": [ + { + "t": "G", + "input": [], + "output": [ + { + "t": "V", + "i": 1, + "b": "A" + } + ], + "extension_reqs": [] + } + ], + "output": [ + { + "t": "Opaque", + "extension": "prelude", + "id": "array", + "args": [ + { + "tya": "Variable", + "idx": 0, + "cached_decl": { + "tp": "BoundedNat", + "bound": null + } + }, + { + "tya": "Type", + "ty": { + "t": "V", + "i": 1, + "b": "A" + } + } + ], + "bound": "A" + } + ], + "extension_reqs": [] + } + }, + "binary": false + }, + "scan": { + "extension": "prelude", + "name": "scan", + "description": "A combination of map and foldl. Applies a function to each element of the array with an accumulator that is passed through from start to finish. Returns the resulting array and the final state of the accumulator.", + "signature": { + "params": [ + { + "tp": "BoundedNat", + "bound": null + }, + { + "tp": "Type", + "b": "A" + }, + { + "tp": "Type", + "b": "A" + }, + { + "tp": "List", + "param": { + "tp": "Type", + "b": "A" + } + } + ], + "body": { + "input": [ + { + "t": "Opaque", + "extension": "prelude", + "id": "array", + "args": [ + { + "tya": "Variable", + "idx": 0, + "cached_decl": { + "tp": "BoundedNat", + "bound": null + } + }, + { + "tya": "Type", + "ty": { + "t": "V", + "i": 1, + "b": "A" + } + } + ], + "bound": "A" + }, + { + "t": "G", + "input": [ + { + "t": "V", + "i": 1, + "b": "A" + }, + { + "t": "R", + "i": 3, + "b": "A" + } + ], + "output": [ + { + "t": "V", + "i": 2, + "b": "A" + }, + { + "t": "R", + "i": 3, + "b": "A" + } + ], + "extension_reqs": [] + }, + { + "t": "R", + "i": 3, + "b": "A" + } + ], + "output": [ + { + "t": "Opaque", + "extension": "prelude", + "id": "array", + "args": [ + { + "tya": "Variable", + "idx": 0, + "cached_decl": { + "tp": "BoundedNat", + "bound": null + } + }, + { + "tya": "Type", + "ty": { + "t": "V", + "i": 2, + "b": "A" + } + } + ], + "bound": "A" + }, + { + "t": "R", + "i": 3, + "b": "A" + } + ], + "extension_reqs": [] + } + }, + "binary": false + }, "set": { "extension": "prelude", "name": "set", diff --git a/specification/std_extensions/prelude.json b/specification/std_extensions/prelude.json index 014ba3ede..b48692b39 100644 --- a/specification/std_extensions/prelude.json +++ b/specification/std_extensions/prelude.json @@ -418,6 +418,189 @@ }, "binary": false }, + "repeat": { + "extension": "prelude", + "name": "repeat", + "description": "Creates a new array whose elements are initialised by calling the given function n times", + "signature": { + "params": [ + { + "tp": "BoundedNat", + "bound": null + }, + { + "tp": "Type", + "b": "A" + } + ], + "body": { + "input": [ + { + "t": "G", + "input": [], + "output": [ + { + "t": "V", + "i": 1, + "b": "A" + } + ], + "extension_reqs": [] + } + ], + "output": [ + { + "t": "Opaque", + "extension": "prelude", + "id": "array", + "args": [ + { + "tya": "Variable", + "idx": 0, + "cached_decl": { + "tp": "BoundedNat", + "bound": null + } + }, + { + "tya": "Type", + "ty": { + "t": "V", + "i": 1, + "b": "A" + } + } + ], + "bound": "A" + } + ], + "extension_reqs": [] + } + }, + "binary": false + }, + "scan": { + "extension": "prelude", + "name": "scan", + "description": "A combination of map and foldl. Applies a function to each element of the array with an accumulator that is passed through from start to finish. Returns the resulting array and the final state of the accumulator.", + "signature": { + "params": [ + { + "tp": "BoundedNat", + "bound": null + }, + { + "tp": "Type", + "b": "A" + }, + { + "tp": "Type", + "b": "A" + }, + { + "tp": "List", + "param": { + "tp": "Type", + "b": "A" + } + } + ], + "body": { + "input": [ + { + "t": "Opaque", + "extension": "prelude", + "id": "array", + "args": [ + { + "tya": "Variable", + "idx": 0, + "cached_decl": { + "tp": "BoundedNat", + "bound": null + } + }, + { + "tya": "Type", + "ty": { + "t": "V", + "i": 1, + "b": "A" + } + } + ], + "bound": "A" + }, + { + "t": "G", + "input": [ + { + "t": "V", + "i": 1, + "b": "A" + }, + { + "t": "R", + "i": 3, + "b": "A" + } + ], + "output": [ + { + "t": "V", + "i": 2, + "b": "A" + }, + { + "t": "R", + "i": 3, + "b": "A" + } + ], + "extension_reqs": [] + }, + { + "t": "R", + "i": 3, + "b": "A" + } + ], + "output": [ + { + "t": "Opaque", + "extension": "prelude", + "id": "array", + "args": [ + { + "tya": "Variable", + "idx": 0, + "cached_decl": { + "tp": "BoundedNat", + "bound": null + } + }, + { + "tya": "Type", + "ty": { + "t": "V", + "i": 2, + "b": "A" + } + } + ], + "bound": "A" + }, + { + "t": "R", + "i": 3, + "b": "A" + } + ], + "extension_reqs": [] + } + }, + "binary": false + }, "set": { "extension": "prelude", "name": "set",