Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into HEAD
Browse files Browse the repository at this point in the history
  • Loading branch information
acl-cqc committed Nov 22, 2024
2 parents c076a2c + 649589c commit efdf516
Show file tree
Hide file tree
Showing 17 changed files with 1,008 additions and 197 deletions.
85 changes: 68 additions & 17 deletions hugr-core/src/export.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::{
type_param::{TypeArgVariable, TypeParam},
type_row::TypeRowBase,
CustomType, FuncTypeBase, MaybeRV, PolyFuncTypeBase, RowVariable, SumType, TypeArg,
TypeBase, TypeEnum,
TypeBase, TypeBound, TypeEnum,
},
Direction, Hugr, HugrView, IncomingPort, Node, Port,
};
Expand Down Expand Up @@ -44,8 +44,21 @@ struct Context<'a> {
bump: &'a Bump,
/// Stores the terms that we have already seen to avoid duplicates.
term_map: FxHashMap<model::Term<'a>, model::TermId>,

/// The current scope for local variables.
///
/// This is set to the id of the smallest enclosing node that defines a polymorphic type.
/// We use this when exporting local variables in terms.
local_scope: Option<model::NodeId>,

/// Constraints to be added to the local scope.
///
/// When exporting a node that defines a polymorphic type, we use this field
/// to collect the constraints that need to be added to that polymorphic
/// type. Currently this is used to record `nonlinear` constraints on uses
/// of `TypeParam::Type` with a `TypeBound::Copyable` bound.
local_constraints: Vec<model::TermId>,

/// Mapping from extension operations to their declarations.
decl_operations: FxHashMap<(ExtensionId, OpName), model::NodeId>,
}
Expand All @@ -63,6 +76,7 @@ impl<'a> Context<'a> {
term_map: FxHashMap::default(),
local_scope: None,
decl_operations: FxHashMap::default(),
local_constraints: Vec::new(),
}
}

Expand Down Expand Up @@ -173,9 +187,11 @@ impl<'a> Context<'a> {
}

fn with_local_scope<T>(&mut self, node: model::NodeId, f: impl FnOnce(&mut Self) -> T) -> T {
let old_scope = self.local_scope.replace(node);
let prev_local_scope = self.local_scope.replace(node);
let prev_local_constraints = std::mem::take(&mut self.local_constraints);
let result = f(self);
self.local_scope = old_scope;
self.local_scope = prev_local_scope;
self.local_constraints = prev_local_constraints;
result
}

Expand Down Expand Up @@ -232,10 +248,11 @@ impl<'a> Context<'a> {

OpType::FuncDefn(func) => self.with_local_scope(node_id, |this| {
let name = this.get_func_name(node).unwrap();
let (params, signature) = this.export_poly_func_type(&func.signature);
let (params, constraints, signature) = this.export_poly_func_type(&func.signature);
let decl = this.bump.alloc(model::FuncDecl {
name,
params,
constraints,
signature,
});
let extensions = this.export_ext_set(&func.signature.body().extension_reqs);
Expand All @@ -247,10 +264,11 @@ impl<'a> Context<'a> {

OpType::FuncDecl(func) => self.with_local_scope(node_id, |this| {
let name = this.get_func_name(node).unwrap();
let (params, func) = this.export_poly_func_type(&func.signature);
let (params, constraints, func) = this.export_poly_func_type(&func.signature);
let decl = this.bump.alloc(model::FuncDecl {
name,
params,
constraints,
signature: func,
});
model::Operation::DeclareFunc { decl }
Expand Down Expand Up @@ -450,10 +468,11 @@ impl<'a> Context<'a> {

let decl = self.with_local_scope(node, |this| {
let name = this.make_qualified_name(opdef.extension(), opdef.name());
let (params, r#type) = this.export_poly_func_type(poly_func_type);
let (params, constraints, r#type) = this.export_poly_func_type(poly_func_type);
let decl = this.bump.alloc(model::OperationDecl {
name,
params,
constraints,
r#type,
});
decl
Expand Down Expand Up @@ -671,22 +690,36 @@ impl<'a> Context<'a> {
regions.into_bump_slice()
}

/// Exports a polymorphic function type.
///
/// The returned triple consists of:
/// - The static parameters of the polymorphic function type.
/// - The constraints of the polymorphic function type.
/// - The function type itself.
pub fn export_poly_func_type<RV: MaybeRV>(
&mut self,
t: &PolyFuncTypeBase<RV>,
) -> (&'a [model::Param<'a>], model::TermId) {
) -> (&'a [model::Param<'a>], &'a [model::TermId], model::TermId) {
let mut params = BumpVec::with_capacity_in(t.params().len(), self.bump);
let scope = self
.local_scope
.expect("exporting poly func type outside of local scope");

for (i, param) in t.params().iter().enumerate() {
let name = self.bump.alloc_str(&i.to_string());
let r#type = self.export_type_param(param);
let param = model::Param::Implicit { name, r#type };
let r#type = self.export_type_param(param, Some(model::LocalRef::Index(scope, i as _)));
let param = model::Param {
name,
r#type,
sort: model::ParamSort::Implicit,
};
params.push(param)
}

let constraints = self.bump.alloc_slice_copy(&self.local_constraints);
let body = self.export_func_type(t.body());

(params.into_bump_slice(), body)
(params.into_bump_slice(), constraints, body)
}

pub fn export_type<RV: MaybeRV>(&mut self, t: &TypeBase<RV>) -> model::TermId {
Expand All @@ -703,7 +736,6 @@ impl<'a> Context<'a> {
}
TypeEnum::Function(func) => self.export_func_type(func),
TypeEnum::Variable(index, _) => {
// This ignores the type bound for now
let node = self.local_scope.expect("local variable out of scope");
self.make_term(model::Term::Var(model::LocalRef::Index(node, *index as _)))
}
Expand Down Expand Up @@ -794,20 +826,39 @@ impl<'a> Context<'a> {
self.make_term(model::Term::List { items, tail: None })
}

pub fn export_type_param(&mut self, t: &TypeParam) -> model::TermId {
/// Exports a `TypeParam` to a term.
///
/// The `var` argument is set when the type parameter being exported is the
/// type of a parameter to a polymorphic definition. In that case we can
/// generate a `nonlinear` constraint for the type of runtime types marked as
/// `TypeBound::Copyable`.
pub fn export_type_param(
&mut self,
t: &TypeParam,
var: Option<model::LocalRef<'static>>,
) -> model::TermId {
match t {
// This ignores the type bound for now.
TypeParam::Type { .. } => self.make_term(model::Term::Type),
// This ignores the type bound for now.
TypeParam::Type { b } => {
if let (Some(var), TypeBound::Copyable) = (var, b) {
let term = self.make_term(model::Term::Var(var));
let non_linear = self.make_term(model::Term::NonLinearConstraint { term });
self.local_constraints.push(non_linear);
}

self.make_term(model::Term::Type)
}
// This ignores the bound on the natural for now.
TypeParam::BoundedNat { .. } => self.make_term(model::Term::NatType),
TypeParam::String => self.make_term(model::Term::StrType),
TypeParam::List { param } => {
let item_type = self.export_type_param(param);
let item_type = self.export_type_param(param, None);
self.make_term(model::Term::ListType { item_type })
}
TypeParam::Tuple { params } => {
let items = self.bump.alloc_slice_fill_iter(
params.iter().map(|param| self.export_type_param(param)),
params
.iter()
.map(|param| self.export_type_param(param, None)),
);
let types = self.make_term(model::Term::List { items, tail: None });
self.make_term(model::Term::ApplyFull {
Expand Down
1 change: 1 addition & 0 deletions hugr-core/src/extension/prelude.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ lazy_static! {
NoopDef.add_to_extension(&mut prelude).unwrap();
LiftDef.add_to_extension(&mut prelude).unwrap();
array::ArrayOpDef::load_all_ops(&mut prelude).unwrap();
array::ArrayScanDef.add_to_extension(&mut prelude).unwrap();
prelude
};
/// An extension registry containing only the prelude
Expand Down
Loading

0 comments on commit efdf516

Please sign in to comment.