Skip to content

Commit

Permalink
Fixed bugs with extension sets, and small improvements.
Browse files Browse the repository at this point in the history
  • Loading branch information
zrho committed Oct 4, 2024
1 parent e18066d commit 8070fd1
Show file tree
Hide file tree
Showing 6 changed files with 68 additions and 58 deletions.
59 changes: 31 additions & 28 deletions hugr-core/src/export.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ use crate::{
use bumpalo::{collections::Vec as BumpVec, Bump};
use hugr_model::v0::{self as model};
use indexmap::IndexSet;
use smol_str::ToSmolStr;

pub(crate) const OP_FUNC_CALL_INDIRECT: &str = "func.call-indirect";
const TERM_PARAM_TUPLE: &str = "param.tuple";
Expand Down Expand Up @@ -635,41 +634,45 @@ impl<'a> Context<'a> {
// Until we have a better representation for extension sets, we therefore
// need to try and parse each extension as a number to determine if it is
// a variable or an extension.
let mut extensions = Vec::new();
let mut variables = Vec::new();
println!("ext set: {:?}", t);

// NOTE: This overprovisions the capacity since some of the entries of the row
// may be variables. Since we panic when there is more than one variable, this
// wastes at most one slot, which is far cheaper than having to allocate
// a temporary vector.
//
// Also `ExtensionSet` has no way of reporting its size, so we have to count
// the elements by iterating over them...
let capacity = t.iter().count();
let mut extensions = BumpVec::with_capacity_in(capacity, self.bump);
let mut rest = None;

for ext in t.iter() {
if let Ok(index) = ext.parse::<usize>() {
variables.push({
// Extension sets in the model support at most one variable. This is a
// deliberate limitation so that extension sets behave like polymorphic rows.
// The type theory of such rows and how to apply them to model (co)effects
// is well understood.
//
// Extension sets in `hugr-core` at this point have no such restriction.
// However, it appears that so far we never actually use extension sets with
// multiple variables, except for extension sets that are generated through
// property testing.
if rest.is_some() {
// TODO: We won't need this anymore once we have a core representation
// that ensures that extension sets have at most one variable.
panic!("Extension set with multiple variables")
}

rest = Some(
self.module
.insert_term(model::Term::Var(model::LocalRef::Index(index as _)))
});
.insert_term(model::Term::Var(model::LocalRef::Index(index as _))),
);
} else {
extensions.push(ext.to_smolstr());
extensions.push(self.bump.alloc_str(ext) as &str);
}
}

// Extension sets in the model support at most one variable. This is a
// deliberate limitation so that extension sets behave like polymorphic rows.
// The type theory of such rows and how to apply them to model (co)effects
// is well understood.
//
// Extension sets in `hugr-core` at this point have no such restriction.
// However, it appears that so far we never actually use extension sets with
// multiple variables, except for extension sets that are generated through
// property testing.
let rest = match variables.as_slice() {
[] => None,
[var] => Some(*var),
_ => {
// TODO: We won't need this anymore once we have a core representation
// that ensures that extension sets have at most one variable.
panic!("Extension set with multiple variables")
}
};

let mut extensions = BumpVec::with_capacity_in(extensions.len(), self.bump);
extensions.extend(t.iter().map(|ext| self.bump.alloc_str(ext) as &str));
let extensions = extensions.into_bump_slice();

self.module
Expand Down
31 changes: 19 additions & 12 deletions hugr-core/src/import.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ struct Context<'a> {
}

impl<'a> Context<'a> {
/// Get the types of the given ports and assemble them into a `TypeRow`.
fn get_port_types(&mut self, ports: &[model::Port]) -> Result<TypeRow, ImportError> {
let types = ports
.iter()
Expand Down Expand Up @@ -198,8 +199,9 @@ impl<'a> Context<'a> {
/// Associate links with the ports of the given node in the given direction.
fn record_links(&mut self, node: Node, direction: Direction, ports: &'a [model::Port]) {
let optype = self.hugr.get_optype(node);
let port_count = optype.port_count(direction);
assert!(ports.len() <= port_count);

// NOTE: `OpType::port_count` copies the signature, which significantly slows down the import.
debug_assert!(ports.len() <= optype.port_count(direction));

for (model_port, port) in ports.iter().zip(self.hugr.node_ports(node, direction)) {
self.link_ports
Expand Down Expand Up @@ -508,6 +510,11 @@ impl<'a> Context<'a> {

let (extension, name) = self.import_custom_name(name)?;

// TODO: Currently we do not have the description or any other metadata for
// the custom op. This will improve with declarative extensions being able
// to declare operations as a node, in which case the description will be attached
// to that node as metadata.

let optype = OpType::OpaqueOp(OpaqueOp::new(
extension,
name,
Expand Down Expand Up @@ -583,7 +590,7 @@ impl<'a> Context<'a> {
) -> Result<(), ImportError> {
let region_data = self.get_region(region)?;

if !matches!(region_data.kind, model::RegionKind::DataFlow) {
if region_data.kind != model::RegionKind::DataFlow {
return Err(model::ModelError::InvalidRegions(node_id).into());
}

Expand Down Expand Up @@ -629,7 +636,7 @@ impl<'a> Context<'a> {
parent: Node,
) -> Result<Node, ImportError> {
let node_data = self.get_node(node_id)?;
assert!(matches!(node_data.operation, model::Operation::TailLoop));
debug_assert_eq!(node_data.operation, model::Operation::TailLoop);

let [region] = node_data.regions else {
return Err(model::ModelError::InvalidRegions(node_id).into());
Expand All @@ -641,6 +648,7 @@ impl<'a> Context<'a> {
let (just_inputs, just_outputs) = {
let mut sum_rows = sum_rows.into_iter();

// NOTE: This cannot fail, since otherwise `import_adt_and_rest` would already have failed.
let term = region_data.targets[0].r#type.unwrap();

let Some(just_inputs) = sum_rows.next() else {
Expand Down Expand Up @@ -673,7 +681,7 @@ impl<'a> Context<'a> {
parent: Node,
) -> Result<Node, ImportError> {
let node_data = self.get_node(node_id)?;
assert!(matches!(node_data.operation, model::Operation::Conditional));
debug_assert_eq!(node_data.operation, model::Operation::Conditional);

let (sum_rows, other_inputs) = self.import_adt_and_rest(node_id, node_data.inputs)?;
let outputs = self.get_port_types(node_data.outputs)?;
Expand Down Expand Up @@ -820,7 +828,7 @@ impl<'a> Context<'a> {
) -> Result<(), ImportError> {
let region_data = self.get_region(region)?;

if !matches!(region_data.kind, model::RegionKind::ControlFlow) {
if region_data.kind != model::RegionKind::ControlFlow {
return Err(model::ModelError::InvalidRegions(node_id).into());
}

Expand All @@ -841,7 +849,7 @@ impl<'a> Context<'a> {
parent: Node,
) -> Result<Node, ImportError> {
let node_data = self.get_node(node_id)?;
assert!(matches!(node_data.operation, model::Operation::Block));
debug_assert_eq!(node_data.operation, model::Operation::Block);

let [region] = node_data.regions else {
return Err(model::ModelError::InvalidRegions(node_id).into());
Expand Down Expand Up @@ -923,8 +931,7 @@ impl<'a> Context<'a> {
model::Term::StrType => Ok(TypeParam::String),
model::Term::ExtSetType => Ok(TypeParam::Extensions),

// TODO: What do we do about the bounds on naturals?
model::Term::NatType => todo!(),
model::Term::NatType => Ok(TypeParam::max_nat()),

model::Term::Nat(_)
| model::Term::Str(_)
Expand Down Expand Up @@ -1101,16 +1108,16 @@ impl<'a> Context<'a> {
let model::Term::FuncType {
inputs,
outputs,
extensions: _,
extensions,
} = term
else {
return Err(model::ModelError::TypeError(term_id).into());
};

let inputs = self.import_type_row::<RV>(*inputs)?;
let outputs = self.import_type_row::<RV>(*outputs)?;
// TODO: extensions
Ok(FuncTypeBase::new(inputs, outputs))
let extensions = self.import_extension_set(*extensions)?;
Ok(FuncTypeBase::new(inputs, outputs).with_extension_delta(extensions))
}

fn import_closed_list(
Expand Down
9 changes: 5 additions & 4 deletions hugr-core/tests/fixtures/model-call.edn
Original file line number Diff line number Diff line change
@@ -1,22 +1,23 @@
(hugr 0)

(declare-func example.callee
[(@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext)
(forall ?ext ext-set)
[(@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext arithmetic.int . ?ext)
(meta doc.title "Callee")
(meta doc.description "This is a function declaration."))

(define-func example.caller
[(@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext)
[(@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext arithmetic.int)
(meta doc.title "Caller")
(meta doc.description "This defines a function that calls the function which we declared earlier.")
(dfg
[(%3 (@ arithmetic.int.types.int))]
[(%4 (@ arithmetic.int.types.int))]
(call (@ example.callee) [(%3 (@ arithmetic.int.types.int))] [(%4 (@ arithmetic.int.types.int))])))
(call (@ example.callee (ext)) [(%3 (@ arithmetic.int.types.int))] [(%4 (@ arithmetic.int.types.int))])))

(define-func example.load
[] [(fn [(@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext))] (ext)
(dfg
[]
[(%5 (fn [(@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext)))]
(load-func (@ example.caller) [] [(%5 (fn [(@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext)))])))
(load-func (@ example.caller) [] [(%5 (fn [(@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext arithmetic.int)))])))
13 changes: 9 additions & 4 deletions hugr-core/tests/snapshots/model__roundtrip_call.snap
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,20 @@ expression: "roundtrip(include_str!(\"fixtures/model-call.edn\"))"
(hugr 0)

(declare-func example.callee
[(@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext))
(forall ?0 ext-set)
[(@ arithmetic.int.types.int)]
[(@ arithmetic.int.types.int)]
(ext arithmetic.int . ?0))

(define-func example.caller
[(@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext)
[(@ arithmetic.int.types.int)]
[(@ arithmetic.int.types.int)]
(ext arithmetic.int)
(dfg
[(%0 (@ arithmetic.int.types.int))]
[(%1 (@ arithmetic.int.types.int))]
(call
(@ example.callee)
(@ example.callee (ext))
[(%0 (@ arithmetic.int.types.int))]
[(%1 (@ arithmetic.int.types.int))])))

Expand All @@ -35,4 +40,4 @@ expression: "roundtrip(include_str!(\"fixtures/model-call.edn\"))"
(fn
[(@ arithmetic.int.types.int)]
[(@ arithmetic.int.types.int)]
(ext)))])))
(ext arithmetic.int)))])))
3 changes: 2 additions & 1 deletion hugr-model/src/v0/text/hugr.pest
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
WHITESPACE = _{ " " | "\t" | "\r" | "\n" }
COMMENT = _{ ";" ~ (!("\n") ~ ANY)* ~ "\n" }
identifier = @{ (ASCII_ALPHA | "_" | "-") ~ (ASCII_ALPHANUMERIC | "_" | "-")* }
ext_name = @{ identifier ~ ("." ~ identifier)* }
symbol = @{ identifier ~ ("." ~ identifier)+ }
tag = @{ (ASCII_NONZERO_DIGIT ~ ASCII_DIGIT*) | "0" }

Expand Down Expand Up @@ -97,7 +98,7 @@ term_str = { string }
term_str_type = { "str" }
term_nat = { (ASCII_DIGIT)+ }
term_nat_type = { "nat" }
term_ext_set = { "(" ~ "ext" ~ identifier* ~ (list_tail ~ term)? ~ ")" }
term_ext_set = { "(" ~ "ext" ~ ext_name* ~ (list_tail ~ term)? ~ ")" }
term_ext_set_type = { "ext-set" }
term_adt = { "(" ~ "adt" ~ term ~ ")" }
term_func_type = { "(" ~ "fn" ~ term ~ term ~ term ~ ")" }
Expand Down
11 changes: 2 additions & 9 deletions hugr-model/src/v0/text/parse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ impl<'a> ParseContext<'a> {
let mut extensions = Vec::new();
let mut rest = None;

for token in filter_rule(&mut inner, Rule::identifier) {
for token in filter_rule(&mut inner, Rule::ext_name) {
extensions.push(token.as_str());
}

Expand Down Expand Up @@ -525,14 +525,7 @@ impl<'a> ParseContext<'a> {

let inputs = self.parse_term(inner.next().unwrap())?;
let outputs = self.parse_term(inner.next().unwrap())?;

let extensions = match inner.peek().map(|p| p.as_rule()) {
Some(Rule::term_ext_set) => self.parse_term(inner.next().unwrap())?,
_ => self.module.insert_term(Term::ExtSet {
extensions: &[],
rest: None,
}),
};
let extensions = self.parse_term(inner.next().unwrap())?;

// Assemble the inputs, outputs and extensions into a function type.
let func = self.module.insert_term(Term::FuncType {
Expand Down

0 comments on commit 8070fd1

Please sign in to comment.