From 6bd094f6dcf4607ef9f58423ed45d06604728b92 Mon Sep 17 00:00:00 2001 From: Lukas Heidemann Date: Wed, 20 Nov 2024 15:20:46 +0000 Subject: [PATCH 1/2] feat: Emulate `TypeBound`s on parameters via constraints. (#1624) This PR translates some `TypeBound`s in `hugr-core` to the `nonlinear` constraint in `hugr-model`. This translation only occurs on parameters that take a runtime type directly. As a driveby change before the model stabilises, this PR also moves constraints out of the parameter lists into their own list. In its previous form this could have led to confusions about which parameter a local variable index refers to when a constraint is situated between two parameters in the list. We also remove constraints from aliases for now. Closes https://github.com/CQCL/hugr/issues/1637. --- hugr-core/src/export.rs | 85 +++++++++++--- hugr-core/src/import.rs | 107 ++++++++++++------ hugr-core/tests/model.rs | 7 ++ .../model__roundtrip_constraints.snap | 16 +++ hugr-model/capnp/hugr-v0.capnp | 34 +++--- hugr-model/src/v0/binary/read.rs | 40 +++---- hugr-model/src/v0/binary/write.rs | 27 ++--- hugr-model/src/v0/mod.rs | 54 ++++----- hugr-model/src/v0/text/hugr.pest | 16 +-- hugr-model/src/v0/text/parse.rs | 41 +++++-- hugr-model/src/v0/text/print.rs | 100 ++++++++-------- hugr-model/tests/binary.rs | 5 + .../tests/fixtures/model-constraints.edn | 13 +++ 13 files changed, 349 insertions(+), 196 deletions(-) create mode 100644 hugr-core/tests/snapshots/model__roundtrip_constraints.snap create mode 100644 hugr-model/tests/fixtures/model-constraints.edn diff --git a/hugr-core/src/export.rs b/hugr-core/src/export.rs index 32460ce5d..093368b60 100644 --- a/hugr-core/src/export.rs +++ b/hugr-core/src/export.rs @@ -7,7 +7,7 @@ use crate::{ type_param::{TypeArgVariable, TypeParam}, type_row::TypeRowBase, CustomType, FuncTypeBase, MaybeRV, PolyFuncTypeBase, RowVariable, SumType, TypeArg, - TypeBase, TypeEnum, + TypeBase, TypeBound, TypeEnum, }, Direction, Hugr, HugrView, IncomingPort, Node, Port, }; @@ -44,8 +44,21 @@ struct Context<'a> { bump: &'a Bump, /// Stores the terms that we have already seen to avoid duplicates. term_map: FxHashMap, model::TermId>, + /// The current scope for local variables. + /// + /// This is set to the id of the smallest enclosing node that defines a polymorphic type. + /// We use this when exporting local variables in terms. local_scope: Option, + + /// Constraints to be added to the local scope. + /// + /// When exporting a node that defines a polymorphic type, we use this field + /// to collect the constraints that need to be added to that polymorphic + /// type. Currently this is used to record `nonlinear` constraints on uses + /// of `TypeParam::Type` with a `TypeBound::Copyable` bound. + local_constraints: Vec, + /// Mapping from extension operations to their declarations. decl_operations: FxHashMap<(ExtensionId, OpName), model::NodeId>, } @@ -63,6 +76,7 @@ impl<'a> Context<'a> { term_map: FxHashMap::default(), local_scope: None, decl_operations: FxHashMap::default(), + local_constraints: Vec::new(), } } @@ -173,9 +187,11 @@ impl<'a> Context<'a> { } fn with_local_scope(&mut self, node: model::NodeId, f: impl FnOnce(&mut Self) -> T) -> T { - let old_scope = self.local_scope.replace(node); + let prev_local_scope = self.local_scope.replace(node); + let prev_local_constraints = std::mem::take(&mut self.local_constraints); let result = f(self); - self.local_scope = old_scope; + self.local_scope = prev_local_scope; + self.local_constraints = prev_local_constraints; result } @@ -232,10 +248,11 @@ impl<'a> Context<'a> { OpType::FuncDefn(func) => self.with_local_scope(node_id, |this| { let name = this.get_func_name(node).unwrap(); - let (params, signature) = this.export_poly_func_type(&func.signature); + let (params, constraints, signature) = this.export_poly_func_type(&func.signature); let decl = this.bump.alloc(model::FuncDecl { name, params, + constraints, signature, }); let extensions = this.export_ext_set(&func.signature.body().extension_reqs); @@ -247,10 +264,11 @@ impl<'a> Context<'a> { OpType::FuncDecl(func) => self.with_local_scope(node_id, |this| { let name = this.get_func_name(node).unwrap(); - let (params, func) = this.export_poly_func_type(&func.signature); + let (params, constraints, func) = this.export_poly_func_type(&func.signature); let decl = this.bump.alloc(model::FuncDecl { name, params, + constraints, signature: func, }); model::Operation::DeclareFunc { decl } @@ -450,10 +468,11 @@ impl<'a> Context<'a> { let decl = self.with_local_scope(node, |this| { let name = this.make_qualified_name(opdef.extension(), opdef.name()); - let (params, r#type) = this.export_poly_func_type(poly_func_type); + let (params, constraints, r#type) = this.export_poly_func_type(poly_func_type); let decl = this.bump.alloc(model::OperationDecl { name, params, + constraints, r#type, }); decl @@ -671,22 +690,36 @@ impl<'a> Context<'a> { regions.into_bump_slice() } + /// Exports a polymorphic function type. + /// + /// The returned triple consists of: + /// - The static parameters of the polymorphic function type. + /// - The constraints of the polymorphic function type. + /// - The function type itself. pub fn export_poly_func_type( &mut self, t: &PolyFuncTypeBase, - ) -> (&'a [model::Param<'a>], model::TermId) { + ) -> (&'a [model::Param<'a>], &'a [model::TermId], model::TermId) { let mut params = BumpVec::with_capacity_in(t.params().len(), self.bump); + let scope = self + .local_scope + .expect("exporting poly func type outside of local scope"); for (i, param) in t.params().iter().enumerate() { let name = self.bump.alloc_str(&i.to_string()); - let r#type = self.export_type_param(param); - let param = model::Param::Implicit { name, r#type }; + let r#type = self.export_type_param(param, Some(model::LocalRef::Index(scope, i as _))); + let param = model::Param { + name, + r#type, + sort: model::ParamSort::Implicit, + }; params.push(param) } + let constraints = self.bump.alloc_slice_copy(&self.local_constraints); let body = self.export_func_type(t.body()); - (params.into_bump_slice(), body) + (params.into_bump_slice(), constraints, body) } pub fn export_type(&mut self, t: &TypeBase) -> model::TermId { @@ -703,7 +736,6 @@ impl<'a> Context<'a> { } TypeEnum::Function(func) => self.export_func_type(func), TypeEnum::Variable(index, _) => { - // This ignores the type bound for now let node = self.local_scope.expect("local variable out of scope"); self.make_term(model::Term::Var(model::LocalRef::Index(node, *index as _))) } @@ -794,20 +826,39 @@ impl<'a> Context<'a> { self.make_term(model::Term::List { items, tail: None }) } - pub fn export_type_param(&mut self, t: &TypeParam) -> model::TermId { + /// Exports a `TypeParam` to a term. + /// + /// The `var` argument is set when the type parameter being exported is the + /// type of a parameter to a polymorphic definition. In that case we can + /// generate a `nonlinear` constraint for the type of runtime types marked as + /// `TypeBound::Copyable`. + pub fn export_type_param( + &mut self, + t: &TypeParam, + var: Option>, + ) -> model::TermId { match t { - // This ignores the type bound for now. - TypeParam::Type { .. } => self.make_term(model::Term::Type), - // This ignores the type bound for now. + TypeParam::Type { b } => { + if let (Some(var), TypeBound::Copyable) = (var, b) { + let term = self.make_term(model::Term::Var(var)); + let non_linear = self.make_term(model::Term::NonLinearConstraint { term }); + self.local_constraints.push(non_linear); + } + + self.make_term(model::Term::Type) + } + // This ignores the bound on the natural for now. TypeParam::BoundedNat { .. } => self.make_term(model::Term::NatType), TypeParam::String => self.make_term(model::Term::StrType), TypeParam::List { param } => { - let item_type = self.export_type_param(param); + let item_type = self.export_type_param(param, None); self.make_term(model::Term::ListType { item_type }) } TypeParam::Tuple { params } => { let items = self.bump.alloc_slice_fill_iter( - params.iter().map(|param| self.export_type_param(param)), + params + .iter() + .map(|param| self.export_type_param(param, None)), ); let types = self.make_term(model::Term::List { items, tail: None }); self.make_term(model::Term::ApplyFull { diff --git a/hugr-core/src/import.rs b/hugr-core/src/import.rs index e6a53cb2f..7619ad44a 100644 --- a/hugr-core/src/import.rs +++ b/hugr-core/src/import.rs @@ -115,8 +115,8 @@ struct Context<'a> { /// A map from `NodeId` to the imported `Node`. nodes: FxHashMap, - /// The types of the local variables that are currently in scope. - local_variables: FxIndexMap<&'a str, model::TermId>, + /// The local variables that are currently in scope. + local_variables: FxIndexMap<&'a str, LocalVar>, custom_name_cache: FxHashMap<&'a str, (ExtensionId, SmolStr)>, } @@ -155,20 +155,20 @@ impl<'a> Context<'a> { .ok_or_else(|| model::ModelError::RegionNotFound(region_id).into()) } - /// Looks up a [`LocalRef`] within the current scope and returns its index and type. + /// Looks up a [`LocalRef`] within the current scope. fn resolve_local_ref( &self, local_ref: &model::LocalRef, - ) -> Result<(usize, model::TermId), ImportError> { + ) -> Result<(usize, LocalVar), ImportError> { let term = match local_ref { model::LocalRef::Index(_, index) => self .local_variables .get_index(*index as usize) - .map(|(_, term)| (*index as usize, *term)), + .map(|(_, v)| (*index as usize, *v)), model::LocalRef::Named(name) => self .local_variables .get_full(name) - .map(|(index, _, term)| (index, *term)), + .map(|(index, _, v)| (index, *v)), }; term.ok_or_else(|| model::ModelError::InvalidLocal(local_ref.to_string()).into()) @@ -898,41 +898,49 @@ impl<'a> Context<'a> { self.with_local_socpe(|ctx| { let mut imported_params = Vec::with_capacity(decl.params.len()); - for param in decl.params { - // TODO: `PolyFuncType` should be able to handle constraints - // and distinguish between implicit and explicit parameters. - match param { - model::Param::Implicit { name, r#type } => { - imported_params.push(ctx.import_type_param(*r#type)?); - ctx.local_variables.insert(name, *r#type); - } - model::Param::Explicit { name, r#type } => { - imported_params.push(ctx.import_type_param(*r#type)?); - ctx.local_variables.insert(name, *r#type); - } - model::Param::Constraint { constraint: _ } => { - return Err(error_unsupported!("constraints")); + ctx.local_variables.extend( + decl.params + .iter() + .map(|param| (param.name, LocalVar::new(param.r#type))), + ); + + for constraint in decl.constraints { + match ctx.get_term(*constraint)? { + model::Term::NonLinearConstraint { term } => { + let model::Term::Var(var) = ctx.get_term(*term)? else { + return Err(error_unsupported!( + "constraint on term that is not a variable" + )); + }; + + let var = ctx.resolve_local_ref(var)?.0; + ctx.local_variables[var].bound = TypeBound::Copyable; } + _ => return Err(error_unsupported!("constraint other than copy or discard")), } } + for (index, param) in decl.params.iter().enumerate() { + // NOTE: `PolyFuncType` only has explicit type parameters at present. + let bound = ctx.local_variables[index].bound; + imported_params.push(ctx.import_type_param(param.r#type, bound)?); + } + let body = ctx.import_func_type::(decl.signature)?; in_scope(ctx, PolyFuncTypeBase::new(imported_params, body)) }) } /// Import a [`TypeParam`] from a term that represents a static type. - fn import_type_param(&mut self, term_id: model::TermId) -> Result { + fn import_type_param( + &mut self, + term_id: model::TermId, + bound: TypeBound, + ) -> Result { match self.get_term(term_id)? { model::Term::Wildcard => Err(error_uninferred!("wildcard")), - model::Term::Type => { - // As part of the migration from `TypeBound`s to constraints, we pretend that all - // `TypeBound`s are copyable. - Ok(TypeParam::Type { - b: TypeBound::Copyable, - }) - } + model::Term::Type => Ok(TypeParam::Type { b: bound }), model::Term::StaticType => Err(error_unsupported!("`type` as `TypeParam`")), model::Term::Constraint => Err(error_unsupported!("`constraint` as `TypeParam`")), @@ -944,7 +952,9 @@ impl<'a> Context<'a> { model::Term::FuncType { .. } => Err(error_unsupported!("`(fn ...)` as `TypeParam`")), model::Term::ListType { item_type } => { - let param = Box::new(self.import_type_param(*item_type)?); + // At present `hugr-model` has no way to express that the item + // type of a list must be copyable. Therefore we import it as `Any`. + let param = Box::new(self.import_type_param(*item_type, TypeBound::Any)?); Ok(TypeParam::List { param }) } @@ -958,7 +968,10 @@ impl<'a> Context<'a> { | model::Term::List { .. } | model::Term::ExtSet { .. } | model::Term::Adt { .. } - | model::Term::Control { .. } => Err(model::ModelError::TypeError(term_id).into()), + | model::Term::Control { .. } + | model::Term::NonLinearConstraint { .. } => { + Err(model::ModelError::TypeError(term_id).into()) + } model::Term::ControlType => { Err(error_unsupported!("type of control types as `TypeParam`")) @@ -966,7 +979,7 @@ impl<'a> Context<'a> { } } - /// Import a `TypeArg` froma term that represents a static type or value. + /// Import a `TypeArg` from a term that represents a static type or value. fn import_type_arg(&mut self, term_id: model::TermId) -> Result { match self.get_term(term_id)? { model::Term::Wildcard => Err(error_uninferred!("wildcard")), @@ -975,8 +988,8 @@ impl<'a> Context<'a> { } model::Term::Var(var) => { - let (index, var_type) = self.resolve_local_ref(var)?; - let decl = self.import_type_param(var_type)?; + let (index, var) = self.resolve_local_ref(var)?; + let decl = self.import_type_param(var.r#type, var.bound)?; Ok(TypeArg::new_var_use(index, decl)) } @@ -1014,7 +1027,10 @@ impl<'a> Context<'a> { model::Term::FuncType { .. } | model::Term::Adt { .. } - | model::Term::Control { .. } => Err(model::ModelError::TypeError(term_id).into()), + | model::Term::Control { .. } + | model::Term::NonLinearConstraint { .. } => { + Err(model::ModelError::TypeError(term_id).into()) + } } } @@ -1115,7 +1131,10 @@ impl<'a> Context<'a> { | model::Term::List { .. } | model::Term::Control { .. } | model::Term::ControlType - | model::Term::Nat(_) => Err(model::ModelError::TypeError(term_id).into()), + | model::Term::Nat(_) + | model::Term::NonLinearConstraint { .. } => { + Err(model::ModelError::TypeError(term_id).into()) + } } } @@ -1291,3 +1310,21 @@ impl<'a> Names<'a> { Ok(Self { items }) } } + +/// Information about a local variable. +#[derive(Debug, Clone, Copy)] +struct LocalVar { + /// The type of the variable. + r#type: model::TermId, + /// The type bound of the variable. + bound: TypeBound, +} + +impl LocalVar { + pub fn new(r#type: model::TermId) -> Self { + Self { + r#type, + bound: TypeBound::Any, + } + } +} diff --git a/hugr-core/tests/model.rs b/hugr-core/tests/model.rs index 611eda660..d9ef0d2c9 100644 --- a/hugr-core/tests/model.rs +++ b/hugr-core/tests/model.rs @@ -58,3 +58,10 @@ pub fn test_roundtrip_params() { "../../hugr-model/tests/fixtures/model-params.edn" ))); } + +#[test] +pub fn test_roundtrip_constraints() { + insta::assert_snapshot!(roundtrip(include_str!( + "../../hugr-model/tests/fixtures/model-constraints.edn" + ))); +} diff --git a/hugr-core/tests/snapshots/model__roundtrip_constraints.snap b/hugr-core/tests/snapshots/model__roundtrip_constraints.snap new file mode 100644 index 000000000..f085c4785 --- /dev/null +++ b/hugr-core/tests/snapshots/model__roundtrip_constraints.snap @@ -0,0 +1,16 @@ +--- +source: hugr-core/tests/model.rs +expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-constraints.edn\"))" +--- +(hugr 0) + +(declare-func array.replicate + (forall ?0 type) + (forall ?1 nat) + (where (nonlinear ?0)) + [?0] [(@ array.Array ?0 ?1)] (ext)) + +(declare-func array.copy + (forall ?0 type) + (where (nonlinear ?0)) + [(@ array.Array ?0)] [(@ array.Array ?0) (@ array.Array ?0)] (ext)) diff --git a/hugr-model/capnp/hugr-v0.capnp b/hugr-model/capnp/hugr-v0.capnp index 95db81205..94341beba 100644 --- a/hugr-model/capnp/hugr-v0.capnp +++ b/hugr-model/capnp/hugr-v0.capnp @@ -56,13 +56,15 @@ struct Operation { struct FuncDefn { name @0 :Text; params @1 :List(Param); - signature @2 :TermId; + constraints @2 :List(TermId); + signature @3 :TermId; } struct FuncDecl { name @0 :Text; params @1 :List(Param); - signature @2 :TermId; + constraints @2 :List(TermId); + signature @3 :TermId; } struct AliasDefn { @@ -81,13 +83,15 @@ struct Operation { struct ConstructorDecl { name @0 :Text; params @1 :List(Param); - type @2 :TermId; + constraints @2 :List(TermId); + type @3 :TermId; } struct OperationDecl { name @0 :Text; params @1 :List(Param); - type @2 :TermId; + constraints @2 :List(TermId); + type @3 :TermId; } } @@ -157,6 +161,7 @@ struct Term { funcType @17 :FuncType; control @18 :TermId; controlType @19 :Void; + nonLinearConstraint @20 :TermId; } struct Apply { @@ -187,19 +192,12 @@ struct Term { } struct Param { - union { - implicit @0 :Implicit; - explicit @1 :Explicit; - constraint @2 :TermId; - } - - struct Implicit { - name @0 :Text; - type @1 :TermId; - } + name @0 :Text; + type @1 :TermId; + sort @2 :ParamSort; +} - struct Explicit { - name @0 :Text; - type @1 :TermId; - } +enum ParamSort { + implicit @0; + explicit @1; } diff --git a/hugr-model/src/v0/binary/read.rs b/hugr-model/src/v0/binary/read.rs index 681bd4ea9..5381a7dc8 100644 --- a/hugr-model/src/v0/binary/read.rs +++ b/hugr-model/src/v0/binary/read.rs @@ -140,10 +140,12 @@ fn read_operation<'a>( let reader = reader?; let name = bump.alloc_str(reader.get_name()?.to_str()?); let params = read_list!(bump, reader, get_params, read_param); + let constraints = read_scalar_list!(bump, reader, get_constraints, model::TermId); let signature = model::TermId(reader.get_signature()); let decl = bump.alloc(model::FuncDecl { name, params, + constraints, signature, }); model::Operation::DefineFunc { decl } @@ -152,10 +154,12 @@ fn read_operation<'a>( let reader = reader?; let name = bump.alloc_str(reader.get_name()?.to_str()?); let params = read_list!(bump, reader, get_params, read_param); + let constraints = read_scalar_list!(bump, reader, get_constraints, model::TermId); let signature = model::TermId(reader.get_signature()); let decl = bump.alloc(model::FuncDecl { name, params, + constraints, signature, }); model::Operation::DeclareFunc { decl } @@ -189,10 +193,12 @@ fn read_operation<'a>( let reader = reader?; let name = bump.alloc_str(reader.get_name()?.to_str()?); let params = read_list!(bump, reader, get_params, read_param); + let constraints = read_scalar_list!(bump, reader, get_constraints, model::TermId); let r#type = model::TermId(reader.get_type()); let decl = bump.alloc(model::ConstructorDecl { name, params, + constraints, r#type, }); model::Operation::DeclareConstructor { decl } @@ -201,10 +207,12 @@ fn read_operation<'a>( let reader = reader?; let name = bump.alloc_str(reader.get_name()?.to_str()?); let params = read_list!(bump, reader, get_params, read_param); + let constraints = read_scalar_list!(bump, reader, get_constraints, model::TermId); let r#type = model::TermId(reader.get_type()); let decl = bump.alloc(model::OperationDecl { name, params, + constraints, r#type, }); model::Operation::DeclareOperation { decl } @@ -332,6 +340,10 @@ fn read_term<'a>(bump: &'a Bump, reader: hugr_capnp::term::Reader) -> ReadResult Which::Control(values) => model::Term::Control { values: model::TermId(values), }, + + Which::NonLinearConstraint(term) => model::Term::NonLinearConstraint { + term: model::TermId(term), + }, }) } @@ -348,23 +360,13 @@ fn read_param<'a>( bump: &'a Bump, reader: hugr_capnp::param::Reader, ) -> ReadResult> { - use hugr_capnp::param::Which; - Ok(match reader.which()? { - Which::Implicit(reader) => { - let reader = reader?; - let name = bump.alloc_str(reader.get_name()?.to_str()?); - let r#type = model::TermId(reader.get_type()); - model::Param::Implicit { name, r#type } - } - Which::Explicit(reader) => { - let reader = reader?; - let name = bump.alloc_str(reader.get_name()?.to_str()?); - let r#type = model::TermId(reader.get_type()); - model::Param::Explicit { name, r#type } - } - Which::Constraint(constraint) => { - let constraint = model::TermId(constraint); - model::Param::Constraint { constraint } - } - }) + let name = bump.alloc_str(reader.get_name()?.to_str()?); + let r#type = model::TermId(reader.get_type()); + + let sort = match reader.get_sort()? { + hugr_capnp::ParamSort::Implicit => model::ParamSort::Implicit, + hugr_capnp::ParamSort::Explicit => model::ParamSort::Explicit, + }; + + Ok(model::Param { name, r#type, sort }) } diff --git a/hugr-model/src/v0/binary/write.rs b/hugr-model/src/v0/binary/write.rs index a4b64d646..f3a0a14d2 100644 --- a/hugr-model/src/v0/binary/write.rs +++ b/hugr-model/src/v0/binary/write.rs @@ -60,12 +60,14 @@ fn write_operation(mut builder: hugr_capnp::operation::Builder, operation: &mode let mut builder = builder.init_func_defn(); builder.set_name(decl.name); write_list!(builder, init_params, write_param, decl.params); + let _ = builder.set_constraints(model::TermId::unwrap_slice(decl.constraints)); builder.set_signature(decl.signature.0); } model::Operation::DeclareFunc { decl } => { let mut builder = builder.init_func_decl(); builder.set_name(decl.name); write_list!(builder, init_params, write_param, decl.params); + let _ = builder.set_constraints(model::TermId::unwrap_slice(decl.constraints)); builder.set_signature(decl.signature.0); } @@ -87,12 +89,14 @@ fn write_operation(mut builder: hugr_capnp::operation::Builder, operation: &mode let mut builder = builder.init_constructor_decl(); builder.set_name(decl.name); write_list!(builder, init_params, write_param, decl.params); + let _ = builder.set_constraints(model::TermId::unwrap_slice(decl.constraints)); builder.set_type(decl.r#type.0); } model::Operation::DeclareOperation { decl } => { let mut builder = builder.init_operation_decl(); builder.set_name(decl.name); write_list!(builder, init_params, write_param, decl.params); + let _ = builder.set_constraints(model::TermId::unwrap_slice(decl.constraints)); builder.set_type(decl.r#type.0); } @@ -101,19 +105,12 @@ fn write_operation(mut builder: hugr_capnp::operation::Builder, operation: &mode } fn write_param(mut builder: hugr_capnp::param::Builder, param: &model::Param) { - match param { - model::Param::Implicit { name, r#type } => { - let mut builder = builder.init_implicit(); - builder.set_name(name); - builder.set_type(r#type.0); - } - model::Param::Explicit { name, r#type } => { - let mut builder = builder.init_explicit(); - builder.set_name(name); - builder.set_type(r#type.0); - } - model::Param::Constraint { constraint } => builder.set_constraint(constraint.0), - } + builder.set_name(param.name); + builder.set_type(param.r#type.0); + builder.set_sort(match param.sort { + model::ParamSort::Implicit => hugr_capnp::ParamSort::Implicit, + model::ParamSort::Explicit => hugr_capnp::ParamSort::Explicit, + }); } fn write_global_ref(mut builder: hugr_capnp::global_ref::Builder, global_ref: &model::GlobalRef) { @@ -212,5 +209,9 @@ fn write_term(mut builder: hugr_capnp::term::Builder, term: &model::Term) { builder.set_outputs(outputs.0); builder.set_extensions(extensions.0); } + + model::Term::NonLinearConstraint { term } => { + builder.set_non_linear_constraint(term.0); + } } } diff --git a/hugr-model/src/v0/mod.rs b/hugr-model/src/v0/mod.rs index cb8713b32..16c7cb6c6 100644 --- a/hugr-model/src/v0/mod.rs +++ b/hugr-model/src/v0/mod.rs @@ -397,6 +397,8 @@ pub struct FuncDecl<'a> { pub name: &'a str, /// The static parameters of the function. pub params: &'a [Param<'a>], + /// The constraints on the static parameters. + pub constraints: &'a [TermId], /// The signature of the function. pub signature: TermId, } @@ -419,6 +421,8 @@ pub struct ConstructorDecl<'a> { pub name: &'a str, /// The static parameters of the constructor. pub params: &'a [Param<'a>], + /// The constraints on the static parameters. + pub constraints: &'a [TermId], /// The type of the constructed term. pub r#type: TermId, } @@ -430,6 +434,8 @@ pub struct OperationDecl<'a> { pub name: &'a str, /// The static parameters of the operation. pub params: &'a [Param<'a>], + /// The constraints on the static parameters. + pub constraints: &'a [TermId], /// The type of the operation. This must be a function type. pub r#type: TermId, } @@ -662,6 +668,12 @@ pub enum Term<'a> { /// /// `ctrl : static` ControlType, + + /// Constraint that requires a runtime type to be copyable and discardable. + NonLinearConstraint { + /// The runtime type that must be copyable and discardable. + term: TermId, + }, } /// A parameter to a function or alias. @@ -669,33 +681,23 @@ pub enum Term<'a> { /// Parameter names must be unique within a parameter list. /// Implicit and explicit parameters share a namespace. #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub enum Param<'a> { - /// An implicit parameter that should be inferred, unless a full application form is used +pub struct Param<'a> { + /// The name of the parameter. + pub name: &'a str, + /// The type of the parameter. + pub r#type: TermId, + /// The sort of the parameter (implicit or explicit). + pub sort: ParamSort, +} + +/// The sort of a parameter (implicit or explicit). +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub enum ParamSort { + /// The parameter is implicit and should be inferred, unless a full application form is used /// (see [`Term::ApplyFull`] and [`Operation::CustomFull`]). - Implicit { - /// The name of the parameter. - name: &'a str, - /// The type of the parameter. - /// - /// This must be a term of type `static`. - r#type: TermId, - }, - /// An explicit parameter that should always be provided. - Explicit { - /// The name of the parameter. - name: &'a str, - /// The type of the parameter. - /// - /// This must be a term of type `static`. - r#type: TermId, - }, - /// A constraint that should be satisfied by other parameters in a parameter list. - Constraint { - /// The constraint to be satisfied. - /// - /// This must be a term of type `constraint`. - constraint: TermId, - }, + Implicit, + /// The parameter is explicit and should always be provided. + Explicit, } /// Errors that can occur when traversing and interpreting the model. diff --git a/hugr-model/src/v0/text/hugr.pest b/hugr-model/src/v0/text/hugr.pest index 132d78567..d05e3d774 100644 --- a/hugr-model/src/v0/text/hugr.pest +++ b/hugr-model/src/v0/text/hugr.pest @@ -56,16 +56,16 @@ node_tag = { "(" ~ "tag" ~ tag ~ port_lists? ~ signature? ~ meta* node_custom = { "(" ~ (term_apply | term_apply_full) ~ port_lists? ~ signature? ~ meta* ~ region* ~ ")" } signature = { "(" ~ "signature" ~ term ~ ")" } -func_header = { symbol ~ param* ~ term ~ term ~ term } +func_header = { symbol ~ param* ~ where_clause* ~ term ~ term ~ term } alias_header = { symbol ~ param* ~ term } -ctr_header = { symbol ~ param* ~ term } -operation_header = { symbol ~ param* ~ term } +ctr_header = { symbol ~ param* ~ where_clause* ~ term } +operation_header = { symbol ~ param* ~ where_clause* ~ term } -param = { param_implicit | param_explicit | param_constraint } +param = { param_implicit | param_explicit } -param_implicit = { "(" ~ "forall" ~ term_var ~ term ~ ")" } -param_explicit = { "(" ~ "param" ~ term_var ~ term ~ ")" } -param_constraint = { "(" ~ "where" ~ term ~ ")" } +param_implicit = { "(" ~ "forall" ~ term_var ~ term ~ ")" } +param_explicit = { "(" ~ "param" ~ term_var ~ term ~ ")" } +where_clause = { "(" ~ "where" ~ term ~ ")" } region = { region_dfg | region_cfg } region_dfg = { "(" ~ "dfg" ~ port_lists? ~ signature? ~ meta* ~ node* ~ ")" } @@ -92,6 +92,7 @@ term = { | term_ctrl_type | term_apply_full | term_apply + | term_non_linear } term_wildcard = { "_" } @@ -114,3 +115,4 @@ term_adt = { "(" ~ "adt" ~ term ~ ")" } term_func_type = { "(" ~ "fn" ~ term ~ term ~ term ~ ")" } term_ctrl = { "(" ~ "ctrl" ~ term ~ ")" } term_ctrl_type = { "ctrl" } +term_non_linear = { "(" ~ "nonlinear" ~ term ~ ")" } diff --git a/hugr-model/src/v0/text/parse.rs b/hugr-model/src/v0/text/parse.rs index fa486454b..370dbeac0 100644 --- a/hugr-model/src/v0/text/parse.rs +++ b/hugr-model/src/v0/text/parse.rs @@ -7,7 +7,7 @@ use thiserror::Error; use crate::v0::{ AliasDecl, ConstructorDecl, FuncDecl, GlobalRef, LinkRef, LocalRef, MetaItem, Module, Node, - NodeId, Operation, OperationDecl, Param, Region, RegionId, RegionKind, Term, TermId, + NodeId, Operation, OperationDecl, Param, ParamSort, Region, RegionId, RegionKind, Term, TermId, }; mod pest_parser { @@ -209,6 +209,11 @@ impl<'a> ParseContext<'a> { Term::Control { values } } + Rule::term_non_linear => { + let term = self.parse_term(inner.next().unwrap())?; + Term::NonLinearConstraint { term } + } + r => unreachable!("term: {:?}", r), }; @@ -544,6 +549,7 @@ impl<'a> ParseContext<'a> { let mut inner = pair.into_inner(); let name = self.parse_symbol(&mut inner)?; let params = self.parse_params(&mut inner)?; + let constraints = self.parse_constraints(&mut inner)?; let inputs = self.parse_term(inner.next().unwrap())?; let outputs = self.parse_term(inner.next().unwrap())?; @@ -559,6 +565,7 @@ impl<'a> ParseContext<'a> { Ok(self.bump.alloc(FuncDecl { name, params, + constraints, signature: func, })) } @@ -584,11 +591,13 @@ impl<'a> ParseContext<'a> { let mut inner = pair.into_inner(); let name = self.parse_symbol(&mut inner)?; let params = self.parse_params(&mut inner)?; + let constraints = self.parse_constraints(&mut inner)?; let r#type = self.parse_term(inner.next().unwrap())?; Ok(self.bump.alloc(ConstructorDecl { name, params, + constraints, r#type, })) } @@ -599,11 +608,13 @@ impl<'a> ParseContext<'a> { let mut inner = pair.into_inner(); let name = self.parse_symbol(&mut inner)?; let params = self.parse_params(&mut inner)?; + let constraints = self.parse_constraints(&mut inner)?; let r#type = self.parse_term(inner.next().unwrap())?; Ok(self.bump.alloc(OperationDecl { name, params, + constraints, r#type, })) } @@ -619,18 +630,21 @@ impl<'a> ParseContext<'a> { let mut inner = param.into_inner(); let name = &inner.next().unwrap().as_str()[1..]; let r#type = self.parse_term(inner.next().unwrap())?; - Param::Implicit { name, r#type } + Param { + name, + r#type, + sort: ParamSort::Implicit, + } } Rule::param_explicit => { let mut inner = param.into_inner(); let name = &inner.next().unwrap().as_str()[1..]; let r#type = self.parse_term(inner.next().unwrap())?; - Param::Explicit { name, r#type } - } - Rule::param_constraint => { - let mut inner = param.into_inner(); - let constraint = self.parse_term(inner.next().unwrap())?; - Param::Constraint { constraint } + Param { + name, + r#type, + sort: ParamSort::Explicit, + } } _ => unreachable!(), }; @@ -641,6 +655,17 @@ impl<'a> ParseContext<'a> { Ok(self.bump.alloc_slice_copy(¶ms)) } + fn parse_constraints(&mut self, pairs: &mut Pairs<'a, Rule>) -> ParseResult<&'a [TermId]> { + let mut constraints = Vec::new(); + + for pair in filter_rule(pairs, Rule::where_clause) { + let constraint = self.parse_term(pair.into_inner().next().unwrap())?; + constraints.push(constraint); + } + + Ok(self.bump.alloc_slice_copy(&constraints)) + } + fn parse_signature(&mut self, pairs: &mut Pairs<'a, Rule>) -> ParseResult> { let Some(Rule::signature) = pairs.peek().map(|p| p.as_rule()) else { return Ok(None); diff --git a/hugr-model/src/v0/text/print.rs b/hugr-model/src/v0/text/print.rs index 01b9d7195..512f6d1e4 100644 --- a/hugr-model/src/v0/text/print.rs +++ b/hugr-model/src/v0/text/print.rs @@ -2,8 +2,8 @@ use pretty::{Arena, DocAllocator, RefDoc}; use std::borrow::Cow; use crate::v0::{ - GlobalRef, LinkRef, LocalRef, MetaItem, ModelError, Module, NodeId, Operation, Param, RegionId, - RegionKind, Term, TermId, + GlobalRef, LinkRef, LocalRef, MetaItem, ModelError, Module, NodeId, Operation, Param, + ParamSort, RegionId, RegionKind, Term, TermId, }; type PrintError = ModelError; @@ -122,15 +122,7 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { f: impl FnOnce(&mut Self) -> PrintResult, ) -> PrintResult { let locals = std::mem::take(&mut self.locals); - - for param in params { - match param { - Param::Implicit { name, .. } => self.locals.push(name), - Param::Explicit { name, .. } => self.locals.push(name), - Param::Constraint { .. } => {} - } - } - + self.locals.extend(params.iter().map(|param| param.name)); let result = f(self); self.locals = locals; result @@ -178,9 +170,8 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { this.print_text(decl.name); }); - for param in decl.params { - this.print_param(*param)?; - } + this.print_params(decl.params)?; + this.print_constraints(decl.constraints)?; match self.module.get_term(decl.signature) { Some(Term::FuncType { @@ -208,9 +199,8 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { this.print_text(decl.name); }); - for param in decl.params { - this.print_param(*param)?; - } + this.print_params(decl.params)?; + this.print_constraints(decl.constraints)?; match self.module.get_term(decl.signature) { Some(Term::FuncType { @@ -303,9 +293,7 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { this.print_text(decl.name); }); - for param in decl.params { - this.print_param(*param)?; - } + this.print_params(decl.params)?; this.print_term(decl.r#type)?; this.print_term(*value)?; @@ -318,9 +306,7 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { this.print_text(decl.name); }); - for param in decl.params { - this.print_param(*param)?; - } + this.print_params(decl.params)?; this.print_term(decl.r#type)?; this.print_meta(node_data.meta)?; @@ -333,9 +319,8 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { this.print_text(decl.name); }); - for param in decl.params { - this.print_param(*param)?; - } + this.print_params(decl.params)?; + this.print_constraints(decl.constraints)?; this.print_term(decl.r#type)?; this.print_meta(node_data.meta)?; @@ -348,9 +333,8 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { this.print_text(decl.name); }); - for param in decl.params { - this.print_param(*param)?; - } + this.print_params(decl.params)?; + this.print_constraints(decl.constraints)?; this.print_term(decl.r#type)?; this.print_meta(node_data.meta)?; @@ -384,10 +368,9 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { } fn print_regions(&mut self, regions: &'a [RegionId]) -> PrintResult<()> { - for region in regions { - self.print_region(*region)?; - } - Ok(()) + regions + .iter() + .try_for_each(|region| self.print_region(*region)) } fn print_region(&mut self, region: RegionId) -> PrintResult<()> { @@ -422,11 +405,10 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { .get_region(region) .ok_or(PrintError::RegionNotFound(region))?; - for node_id in region_data.children { - self.print_node(*node_id)?; - } - - Ok(()) + region_data + .children + .iter() + .try_for_each(|node_id| self.print_node(*node_id)) } fn print_port_lists( @@ -460,25 +442,33 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { } } + fn print_params(&mut self, params: &'a [Param<'a>]) -> PrintResult<()> { + params.iter().try_for_each(|param| self.print_param(*param)) + } + fn print_param(&mut self, param: Param<'a>) -> PrintResult<()> { - self.print_parens(|this| match param { - Param::Implicit { name, r#type } => { - this.print_text("forall"); - this.print_text(format!("?{}", name)); - this.print_term(r#type) - } - Param::Explicit { name, r#type } => { - this.print_text("param"); - this.print_text(format!("?{}", name)); - this.print_term(r#type) - } - Param::Constraint { constraint } => { - this.print_text("where"); - this.print_term(constraint) - } + self.print_parens(|this| { + match param.sort { + ParamSort::Implicit => this.print_text("forall"), + ParamSort::Explicit => this.print_text("param"), + }; + + this.print_text(format!("?{}", param.name)); + this.print_term(param.r#type) }) } + fn print_constraints(&mut self, terms: &'a [TermId]) -> PrintResult<()> { + for term in terms { + self.print_parens(|this| { + this.print_text("where"); + this.print_term(*term) + })?; + } + + Ok(()) + } + fn print_term(&mut self, term_id: TermId) -> PrintResult<()> { let term_data = self .module @@ -598,6 +588,10 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { self.print_text("ctrl"); Ok(()) } + Term::NonLinearConstraint { term } => self.print_parens(|this| { + this.print_text("nonlinear"); + this.print_term(*term) + }), } } diff --git a/hugr-model/tests/binary.rs b/hugr-model/tests/binary.rs index 043061677..93955fe6e 100644 --- a/hugr-model/tests/binary.rs +++ b/hugr-model/tests/binary.rs @@ -51,3 +51,8 @@ pub fn test_params() { pub fn test_decl_exts() { binary_roundtrip(include_str!("fixtures/model-decl-exts.edn")); } + +#[test] +pub fn test_constraints() { + binary_roundtrip(include_str!("fixtures/model-constraints.edn")); +} diff --git a/hugr-model/tests/fixtures/model-constraints.edn b/hugr-model/tests/fixtures/model-constraints.edn new file mode 100644 index 000000000..5db6b9886 --- /dev/null +++ b/hugr-model/tests/fixtures/model-constraints.edn @@ -0,0 +1,13 @@ +(hugr 0) + +(declare-func array.replicate + (forall ?t type) + (forall ?n nat) + (where (nonlinear ?t)) + [?t] [(@ array.Array ?t ?n)] + (ext)) + +(declare-func array.copy + (forall ?t type) + (where (nonlinear ?t)) + [(@ array.Array ?t)] [(@ array.Array ?t) (@ array.Array ?t)] (ext)) From 649589c9e3f1fbd9cfff53a2adb8e1f9649fbe87 Mon Sep 17 00:00:00 2001 From: Mark Koch <48097969+mark-koch@users.noreply.github.com> Date: Thu, 21 Nov 2024 10:09:04 +0000 Subject: [PATCH 2/2] feat: Add array `repeat` and `scan` ops (#1633) Closes #1627 --- hugr-core/src/extension/prelude.rs | 1 + hugr-core/src/extension/prelude/array.rs | 293 ++++++++++++++++++- hugr-py/src/hugr/std/_json_defs/prelude.json | 183 ++++++++++++ specification/std_extensions/prelude.json | 183 ++++++++++++ 4 files changed, 659 insertions(+), 1 deletion(-) diff --git a/hugr-core/src/extension/prelude.rs b/hugr-core/src/extension/prelude.rs index ca338eae3..8e0eb0f4b 100644 --- a/hugr-core/src/extension/prelude.rs +++ b/hugr-core/src/extension/prelude.rs @@ -101,6 +101,7 @@ lazy_static! { NoopDef.add_to_extension(&mut prelude).unwrap(); LiftDef.add_to_extension(&mut prelude).unwrap(); array::ArrayOpDef::load_all_ops(&mut prelude).unwrap(); + array::ArrayScanDef.add_to_extension(&mut prelude).unwrap(); prelude }; /// An extension registry containing only the prelude diff --git a/hugr-core/src/extension/prelude/array.rs b/hugr-core/src/extension/prelude/array.rs index a15bf23cc..6013039d4 100644 --- a/hugr-core/src/extension/prelude/array.rs +++ b/hugr-core/src/extension/prelude/array.rs @@ -1,3 +1,6 @@ +use std::str::FromStr; + +use itertools::Itertools; use strum_macros::EnumIter; use strum_macros::EnumString; use strum_macros::IntoStaticStr; @@ -17,8 +20,10 @@ use crate::ops::ExtensionOp; use crate::ops::NamedOp; use crate::ops::OpName; use crate::type_row; +use crate::types::FuncTypeBase; use crate::types::FuncValueType; +use crate::types::RowVariable; use crate::types::TypeBound; use crate::types::Type; @@ -28,6 +33,7 @@ use crate::extension::SignatureError; use crate::types::PolyFuncTypeRV; use crate::types::type_param::TypeArg; +use crate::types::TypeRV; use crate::Extension; use super::PRELUDE_ID; @@ -46,6 +52,7 @@ pub enum ArrayOpDef { pop_left, pop_right, discard_empty, + repeat, } /// Static parameters for array operations. Includes array size. Type is part of the type scheme. @@ -118,6 +125,14 @@ impl ArrayOpDef { let standard_params = vec![TypeParam::max_nat(), TypeBound::Any.into()]; match self { + repeat => { + let func = + Type::new_function(FuncValueType::new(type_row![], elem_ty_var.clone())); + PolyFuncTypeRV::new( + standard_params, + FuncValueType::new(vec![func], array_ty.clone()), + ) + } get => { let params = vec![TypeParam::max_nat(), TypeBound::Copyable.into()]; let copy_elem_ty = Type::new_var_use(1, TypeBound::Copyable); @@ -179,6 +194,10 @@ impl MakeOpDef for ArrayOpDef { fn description(&self) -> String { match self { ArrayOpDef::new_array => "Create a new array from elements", + ArrayOpDef::repeat => { + "Creates a new array whose elements are initialised by calling \ + the given function n times" + } ArrayOpDef::get => "Get an element from an array", ArrayOpDef::set => "Set an element in an array", ArrayOpDef::swap => "Swap two elements in an array", @@ -246,7 +265,7 @@ impl MakeExtensionOp for ArrayOp { ); vec![ty_arg] } - new_array | pop_left | pop_right | get | set | swap => { + new_array | repeat | pop_left | pop_right | get | set | swap => { vec![TypeArg::BoundedNat { n: self.size }, ty_arg] } } @@ -312,6 +331,192 @@ pub fn new_array_op(element_ty: Type, size: u64) -> ExtensionOp { op.to_extension_op().unwrap() } +/// Name of the operation for the combined map/fold operation +pub const ARRAY_SCAN_OP_ID: OpName = OpName::new_inline("scan"); + +/// Definition of the array scan op. +#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)] +pub struct ArrayScanDef; + +impl NamedOp for ArrayScanDef { + fn name(&self) -> OpName { + ARRAY_SCAN_OP_ID + } +} + +impl FromStr for ArrayScanDef { + type Err = (); + + fn from_str(s: &str) -> Result { + if s == ArrayScanDef.name() { + Ok(Self) + } else { + Err(()) + } + } +} + +impl ArrayScanDef { + /// To avoid recursion when defining the extension, take the type definition as an argument. + fn signature_from_def(&self, array_def: &TypeDef) -> SignatureFunc { + // array, (T1, *A -> T2, *A), -> array, *A + let params = vec![ + TypeParam::max_nat(), + TypeBound::Any.into(), + TypeBound::Any.into(), + TypeParam::new_list(TypeBound::Any), + ]; + let n = TypeArg::new_var_use(0, TypeParam::max_nat()); + let t1 = Type::new_var_use(1, TypeBound::Any); + let t2 = Type::new_var_use(2, TypeBound::Any); + let s = TypeRV::new_row_var_use(3, TypeBound::Any); + PolyFuncTypeRV::new( + params, + FuncTypeBase::::new( + vec![ + instantiate(array_def, n.clone(), t1.clone()).into(), + Type::new_function(FuncTypeBase::::new( + vec![t1.into(), s.clone()], + vec![t2.clone().into(), s.clone()], + )) + .into(), + s.clone(), + ], + vec![instantiate(array_def, n, t2).into(), s], + ), + ) + .into() + } +} + +impl MakeOpDef for ArrayScanDef { + fn from_def(op_def: &OpDef) -> Result + where + Self: Sized, + { + crate::extension::simple_op::try_from_name(op_def.name(), op_def.extension()) + } + + fn signature(&self) -> SignatureFunc { + self.signature_from_def(array_type_def()) + } + + fn extension(&self) -> ExtensionId { + PRELUDE_ID + } + + fn description(&self) -> String { + "A combination of map and foldl. Applies a function to each element \ + of the array with an accumulator that is passed through from start to \ + finish. Returns the resulting array and the final state of the \ + accumulator." + .into() + } + + /// Add an operation implemented as a [MakeOpDef], which can provide the data + /// required to define an [OpDef], to an extension. + // + // This method is re-defined here since we need to pass the array type def while + // computing the signature, to avoid recursive loops initializing the extension. + fn add_to_extension( + &self, + extension: &mut Extension, + ) -> Result<(), crate::extension::ExtensionBuildError> { + let sig = self.signature_from_def(extension.get_type(ARRAY_TYPE_NAME).unwrap()); + let def = extension.add_op(self.name(), self.description(), sig)?; + + self.post_opdef(def); + + Ok(()) + } +} + +/// Definition of the array scan op. +#[derive(Clone, Debug, PartialEq)] +pub struct ArrayScan { + /// The element type of the input array. + src_ty: Type, + /// The target element type of the output array. + tgt_ty: Type, + /// The accumulator types. + acc_tys: Vec, + /// Size of the array. + size: u64, +} + +impl ArrayScan { + fn new(src_ty: Type, tgt_ty: Type, acc_tys: Vec, size: u64) -> Self { + ArrayScan { + src_ty, + tgt_ty, + acc_tys, + size, + } + } +} + +impl NamedOp for ArrayScan { + fn name(&self) -> OpName { + ARRAY_SCAN_OP_ID + } +} + +impl MakeExtensionOp for ArrayScan { + fn from_extension_op(ext_op: &ExtensionOp) -> Result + where + Self: Sized, + { + let def = ArrayScanDef::from_def(ext_op.def())?; + def.instantiate(ext_op.args()) + } + + fn type_args(&self) -> Vec { + vec![ + TypeArg::BoundedNat { n: self.size }, + self.src_ty.clone().into(), + self.tgt_ty.clone().into(), + TypeArg::Sequence { + elems: self.acc_tys.clone().into_iter().map_into().collect(), + }, + ] + } +} + +impl MakeRegisteredOp for ArrayScan { + fn extension_id(&self) -> ExtensionId { + PRELUDE_ID + } + + fn registry<'s, 'r: 's>(&'s self) -> &'r crate::extension::ExtensionRegistry { + &PRELUDE_REGISTRY + } +} + +impl HasDef for ArrayScan { + type Def = ArrayScanDef; +} + +impl HasConcrete for ArrayScanDef { + type Concrete = ArrayScan; + + fn instantiate(&self, type_args: &[TypeArg]) -> Result { + match type_args { + [TypeArg::BoundedNat { n }, TypeArg::Type { ty: src_ty }, TypeArg::Type { ty: tgt_ty }, TypeArg::Sequence { elems: acc_tys }] => + { + let acc_tys: Result<_, OpLoadError> = acc_tys + .iter() + .map(|acc_ty| match acc_ty { + TypeArg::Type { ty } => Ok(ty.clone()), + _ => Err(SignatureError::InvalidTypeArgs.into()), + }) + .collect(); + Ok(ArrayScan::new(src_ty.clone(), tgt_ty.clone(), acc_tys?, *n)) + } + _ => Err(SignatureError::InvalidTypeArgs.into()), + } + } +} + #[cfg(test)] mod tests { use strum::IntoEnumIterator; @@ -320,6 +525,7 @@ mod tests { builder::{inout_sig, DFGBuilder, Dataflow, DataflowHugr}, extension::prelude::{BOOL_T, QB_T}, ops::{OpTrait, OpType}, + types::Signature, }; use super::*; @@ -459,4 +665,89 @@ mod tests { ) ); } + + #[test] + fn test_repeat() { + let size = 2; + let element_ty = QB_T; + let op = ArrayOpDef::repeat.to_concrete(element_ty.clone(), size); + + let optype: OpType = op.into(); + + let sig = optype.dataflow_signature().unwrap(); + + assert_eq!( + sig.io(), + ( + &vec![Type::new_function(Signature::new(vec![], vec![QB_T]))].into(), + &vec![array_type(size, element_ty.clone())].into(), + ) + ); + } + + #[test] + fn test_scan_def() { + let op = ArrayScan::new(BOOL_T, QB_T, vec![USIZE_T], 2); + let optype: OpType = op.clone().into(); + let new_op: ArrayScan = optype.cast().unwrap(); + assert_eq!(new_op, op); + } + + #[test] + fn test_scan_map() { + let size = 2; + let src_ty = QB_T; + let tgt_ty = BOOL_T; + + let op = ArrayScan::new(src_ty.clone(), tgt_ty.clone(), vec![], size); + let optype: OpType = op.into(); + let sig = optype.dataflow_signature().unwrap(); + + assert_eq!( + sig.io(), + ( + &vec![ + array_type(size, src_ty.clone()), + Type::new_function(Signature::new(vec![src_ty], vec![tgt_ty.clone()])) + ] + .into(), + &vec![array_type(size, tgt_ty)].into(), + ) + ); + } + + #[test] + fn test_scan_accs() { + let size = 2; + let src_ty = QB_T; + let tgt_ty = BOOL_T; + let acc_ty1 = USIZE_T; + let acc_ty2 = QB_T; + + let op = ArrayScan::new( + src_ty.clone(), + tgt_ty.clone(), + vec![acc_ty1.clone(), acc_ty2.clone()], + size, + ); + let optype: OpType = op.into(); + let sig = optype.dataflow_signature().unwrap(); + + assert_eq!( + sig.io(), + ( + &vec![ + array_type(size, src_ty.clone()), + Type::new_function(Signature::new( + vec![src_ty, acc_ty1.clone(), acc_ty2.clone()], + vec![tgt_ty.clone(), acc_ty1.clone(), acc_ty2.clone()] + )), + acc_ty1.clone(), + acc_ty2.clone() + ] + .into(), + &vec![array_type(size, tgt_ty), acc_ty1, acc_ty2].into(), + ) + ); + } } diff --git a/hugr-py/src/hugr/std/_json_defs/prelude.json b/hugr-py/src/hugr/std/_json_defs/prelude.json index 014ba3ede..b48692b39 100644 --- a/hugr-py/src/hugr/std/_json_defs/prelude.json +++ b/hugr-py/src/hugr/std/_json_defs/prelude.json @@ -418,6 +418,189 @@ }, "binary": false }, + "repeat": { + "extension": "prelude", + "name": "repeat", + "description": "Creates a new array whose elements are initialised by calling the given function n times", + "signature": { + "params": [ + { + "tp": "BoundedNat", + "bound": null + }, + { + "tp": "Type", + "b": "A" + } + ], + "body": { + "input": [ + { + "t": "G", + "input": [], + "output": [ + { + "t": "V", + "i": 1, + "b": "A" + } + ], + "extension_reqs": [] + } + ], + "output": [ + { + "t": "Opaque", + "extension": "prelude", + "id": "array", + "args": [ + { + "tya": "Variable", + "idx": 0, + "cached_decl": { + "tp": "BoundedNat", + "bound": null + } + }, + { + "tya": "Type", + "ty": { + "t": "V", + "i": 1, + "b": "A" + } + } + ], + "bound": "A" + } + ], + "extension_reqs": [] + } + }, + "binary": false + }, + "scan": { + "extension": "prelude", + "name": "scan", + "description": "A combination of map and foldl. Applies a function to each element of the array with an accumulator that is passed through from start to finish. Returns the resulting array and the final state of the accumulator.", + "signature": { + "params": [ + { + "tp": "BoundedNat", + "bound": null + }, + { + "tp": "Type", + "b": "A" + }, + { + "tp": "Type", + "b": "A" + }, + { + "tp": "List", + "param": { + "tp": "Type", + "b": "A" + } + } + ], + "body": { + "input": [ + { + "t": "Opaque", + "extension": "prelude", + "id": "array", + "args": [ + { + "tya": "Variable", + "idx": 0, + "cached_decl": { + "tp": "BoundedNat", + "bound": null + } + }, + { + "tya": "Type", + "ty": { + "t": "V", + "i": 1, + "b": "A" + } + } + ], + "bound": "A" + }, + { + "t": "G", + "input": [ + { + "t": "V", + "i": 1, + "b": "A" + }, + { + "t": "R", + "i": 3, + "b": "A" + } + ], + "output": [ + { + "t": "V", + "i": 2, + "b": "A" + }, + { + "t": "R", + "i": 3, + "b": "A" + } + ], + "extension_reqs": [] + }, + { + "t": "R", + "i": 3, + "b": "A" + } + ], + "output": [ + { + "t": "Opaque", + "extension": "prelude", + "id": "array", + "args": [ + { + "tya": "Variable", + "idx": 0, + "cached_decl": { + "tp": "BoundedNat", + "bound": null + } + }, + { + "tya": "Type", + "ty": { + "t": "V", + "i": 2, + "b": "A" + } + } + ], + "bound": "A" + }, + { + "t": "R", + "i": 3, + "b": "A" + } + ], + "extension_reqs": [] + } + }, + "binary": false + }, "set": { "extension": "prelude", "name": "set", diff --git a/specification/std_extensions/prelude.json b/specification/std_extensions/prelude.json index 014ba3ede..b48692b39 100644 --- a/specification/std_extensions/prelude.json +++ b/specification/std_extensions/prelude.json @@ -418,6 +418,189 @@ }, "binary": false }, + "repeat": { + "extension": "prelude", + "name": "repeat", + "description": "Creates a new array whose elements are initialised by calling the given function n times", + "signature": { + "params": [ + { + "tp": "BoundedNat", + "bound": null + }, + { + "tp": "Type", + "b": "A" + } + ], + "body": { + "input": [ + { + "t": "G", + "input": [], + "output": [ + { + "t": "V", + "i": 1, + "b": "A" + } + ], + "extension_reqs": [] + } + ], + "output": [ + { + "t": "Opaque", + "extension": "prelude", + "id": "array", + "args": [ + { + "tya": "Variable", + "idx": 0, + "cached_decl": { + "tp": "BoundedNat", + "bound": null + } + }, + { + "tya": "Type", + "ty": { + "t": "V", + "i": 1, + "b": "A" + } + } + ], + "bound": "A" + } + ], + "extension_reqs": [] + } + }, + "binary": false + }, + "scan": { + "extension": "prelude", + "name": "scan", + "description": "A combination of map and foldl. Applies a function to each element of the array with an accumulator that is passed through from start to finish. Returns the resulting array and the final state of the accumulator.", + "signature": { + "params": [ + { + "tp": "BoundedNat", + "bound": null + }, + { + "tp": "Type", + "b": "A" + }, + { + "tp": "Type", + "b": "A" + }, + { + "tp": "List", + "param": { + "tp": "Type", + "b": "A" + } + } + ], + "body": { + "input": [ + { + "t": "Opaque", + "extension": "prelude", + "id": "array", + "args": [ + { + "tya": "Variable", + "idx": 0, + "cached_decl": { + "tp": "BoundedNat", + "bound": null + } + }, + { + "tya": "Type", + "ty": { + "t": "V", + "i": 1, + "b": "A" + } + } + ], + "bound": "A" + }, + { + "t": "G", + "input": [ + { + "t": "V", + "i": 1, + "b": "A" + }, + { + "t": "R", + "i": 3, + "b": "A" + } + ], + "output": [ + { + "t": "V", + "i": 2, + "b": "A" + }, + { + "t": "R", + "i": 3, + "b": "A" + } + ], + "extension_reqs": [] + }, + { + "t": "R", + "i": 3, + "b": "A" + } + ], + "output": [ + { + "t": "Opaque", + "extension": "prelude", + "id": "array", + "args": [ + { + "tya": "Variable", + "idx": 0, + "cached_decl": { + "tp": "BoundedNat", + "bound": null + } + }, + { + "tya": "Type", + "ty": { + "t": "V", + "i": 2, + "b": "A" + } + } + ], + "bound": "A" + }, + { + "t": "R", + "i": 3, + "b": "A" + } + ], + "extension_reqs": [] + } + }, + "binary": false + }, "set": { "extension": "prelude", "name": "set",