diff --git a/hugr-core/Cargo.toml b/hugr-core/Cargo.toml index 84b83677f..ce779787d 100644 --- a/hugr-core/Cargo.toml +++ b/hugr-core/Cargo.toml @@ -20,6 +20,7 @@ workspace = true extension_inference = [] declarative = ["serde_yaml"] model_unstable = ["hugr-model"] +default = ["model_unstable"] [lib] bench = false diff --git a/hugr-core/src/export.rs b/hugr-core/src/export.rs index 06be24fa9..36b65e064 100644 --- a/hugr-core/src/export.rs +++ b/hugr-core/src/export.rs @@ -14,11 +14,8 @@ use crate::{ use bumpalo::{collections::String as BumpString, collections::Vec as BumpVec, Bump}; use fxhash::FxHashMap; use hugr_model::v0::{self as model}; -use indexmap::IndexSet; use std::fmt::Write; -type FxIndexSet = IndexSet; - pub(crate) const OP_FUNC_CALL_INDIRECT: &str = "func.call-indirect"; const TERM_PARAM_TUPLE: &str = "param.tuple"; const TERM_JSON: &str = "prelude.json"; @@ -37,9 +34,6 @@ struct Context<'a> { hugr: &'a Hugr, /// The module that is being built. module: model::Module<'a>, - /// Mapping from ports to link indices. - /// This only includes the minimum port among groups of linked ports. - links: FxIndexSet<(Node, Port)>, /// The arena in which the model is allocated. bump: &'a Bump, /// Stores the terms that we have already seen to avoid duplicates. @@ -61,6 +55,16 @@ struct Context<'a> { /// Mapping from extension operations to their declarations. decl_operations: FxHashMap<(ExtensionId, OpName), model::NodeId>, + + links: model::scope::LinkTable<(Node, Port)>, + + /// The symbol table tracking symbols that are currently in scope. + symbols: model::scope::SymbolTable<'a>, + + /// Mapping from implicit imports to their node ids. + implicit_imports: FxHashMap<&'a str, model::NodeId>, + + node_indices: FxHashMap, } impl<'a> Context<'a> { @@ -72,49 +76,68 @@ impl<'a> Context<'a> { hugr, module, bump, - links: IndexSet::default(), term_map: FxHashMap::default(), local_scope: None, decl_operations: FxHashMap::default(), local_constraints: Vec::new(), + symbols: model::scope::SymbolTable::default(), + implicit_imports: FxHashMap::default(), + node_indices: FxHashMap::default(), + links: model::scope::LinkTable::default(), } } /// Exports the root module of the HUGR graph. pub fn export_root(&mut self) { + self.module.root = self.module.insert_region(model::Region::default()); + self.symbols.enter(self.module.root); + self.links.enter(self.module.root); + let hugr_children = self.hugr.children(self.hugr.root()); let mut children = Vec::with_capacity(hugr_children.size_hint().0); - for child in self.hugr.children(self.hugr.root()) { - children.push(self.export_node(child)); + for child in hugr_children.clone() { + children.push(self.export_node_shallow(child)); } - children.extend(self.decl_operations.values().copied()); + for (child, child_node_id) in hugr_children.zip(children.iter().copied()) { + self.export_node_deep(child, child_node_id); + } + + let mut all_children = BumpVec::with_capacity_in( + children.len() + self.decl_operations.len() + self.implicit_imports.len(), + self.bump, + ); + + all_children.extend(self.implicit_imports.drain().map(|(_, id)| id)); + all_children.extend(self.decl_operations.values().copied()); + all_children.extend(children); - let root = self.module.insert_region(model::Region { + let (links, ports) = self.links.exit(); + self.symbols.exit(); + + self.module.regions[self.module.root.index()] = model::Region { kind: model::RegionKind::Module, sources: &[], targets: &[], - children: self.bump.alloc_slice_copy(&children), + children: all_children.into_bump_slice(), meta: &[], // TODO: Export metadata signature: None, - }); - - self.module.root = root; + scope: Some(model::RegionScope { links, ports }), + }; } /// Returns the edge id for a given port, creating a new edge if necessary. /// /// Any two ports that are linked will be represented by the same link. - fn get_link_id(&mut self, node: Node, port: impl Into) -> model::LinkId { + fn get_link_index(&mut self, node: Node, port: impl Into) -> model::LinkIndex { // To ensure that linked ports are represented by the same edge, we take the minimum port // among all the linked ports, including the one we started with. let port = port.into(); let linked_ports = self.hugr.linked_ports(node, port); let all_ports = std::iter::once((node, port)).chain(linked_ports); let repr = all_ports.min().unwrap(); - let edge = self.links.insert_full(repr).0 as _; - model::LinkId(edge) + self.links.use_link(repr) } pub fn make_ports( @@ -122,12 +145,12 @@ impl<'a> Context<'a> { node: Node, direction: Direction, num_ports: usize, - ) -> &'a [model::LinkRef<'a>] { + ) -> &'a [model::LinkIndex] { let ports = self.hugr.node_ports(node, direction); let mut links = BumpVec::with_capacity_in(ports.size_hint().0, self.bump); for port in ports.take(num_ports) { - links.push(model::LinkRef::Id(self.get_link_id(node, port))); + links.push(self.get_link_index(node, port)); } links.into_bump_slice() @@ -160,8 +183,9 @@ impl<'a> Context<'a> { &mut self, extension: &IdentList, name: impl AsRef, - ) -> model::GlobalRef<'a> { - model::GlobalRef::Named(self.make_qualified_name(extension, name)) + ) -> model::NodeId { + let symbol = self.make_qualified_name(extension, name); + self.resolve_symbol(symbol) } /// Get the node that declares or defines the function associated with the given @@ -195,12 +219,31 @@ impl<'a> Context<'a> { result } - pub fn export_node(&mut self, node: Node) -> model::NodeId { + fn export_node_shallow(&mut self, node: Node) -> model::NodeId { + let node_id = self.module.insert_node(model::Node::default()); + self.node_indices.insert(node, node_id); + + let symbol = match self.hugr.get_optype(node) { + OpType::FuncDefn(func_defn) => Some(func_defn.name.as_str()), + OpType::FuncDecl(func_decl) => Some(func_decl.name.as_str()), + OpType::AliasDecl(alias_decl) => Some(alias_decl.name.as_str()), + OpType::AliasDefn(alias_defn) => Some(alias_defn.name.as_str()), + _ => None, + }; + + if let Some(symbol) = symbol { + self.symbols + .insert(symbol, node_id) + .expect("duplicate symbol"); + } + + node_id + } + + fn export_node_deep(&mut self, node: Node, node_id: model::NodeId) { // We insert a dummy node with the invalid operation at this point to reserve // the node id. This is necessary to establish the correct node id for the // local scope introduced by some operations. We will overwrite this node later. - let node_id = self.module.insert_node(model::Node::default()); - let mut params: &[_] = &[]; let mut regions: &[_] = &[]; @@ -221,12 +264,12 @@ impl<'a> Context<'a> { let extensions = self.export_ext_set(&dfg.signature.extension_reqs); regions = self .bump - .alloc_slice_copy(&[self.export_dfg(node, extensions)]); + .alloc_slice_copy(&[self.export_dfg(node, extensions, false)]); model::Operation::Dfg } OpType::CFG(_) => { - regions = self.bump.alloc_slice_copy(&[self.export_cfg(node)]); + regions = self.bump.alloc_slice_copy(&[self.export_cfg(node, false)]); model::Operation::Cfg } @@ -242,7 +285,7 @@ impl<'a> Context<'a> { let extensions = self.export_ext_set(&block.extension_delta); regions = self .bump - .alloc_slice_copy(&[self.export_dfg(node, extensions)]); + .alloc_slice_copy(&[self.export_dfg(node, extensions, false)]); model::Operation::Block } @@ -258,7 +301,7 @@ impl<'a> Context<'a> { let extensions = this.export_ext_set(&func.signature.body().extension_reqs); regions = this .bump - .alloc_slice_copy(&[this.export_dfg(node, extensions)]); + .alloc_slice_copy(&[this.export_dfg(node, extensions, true)]); model::Operation::DefineFunc { decl } }), @@ -300,26 +343,25 @@ impl<'a> Context<'a> { OpType::Call(call) => { // TODO: If the node is not connected to a function, we should do better than panic. let node = self.connected_function(node).unwrap(); - let name = model::GlobalRef::Named(self.get_func_name(node).unwrap()); - + let symbol = self.node_indices[&node]; let mut args = BumpVec::new_in(self.bump); args.extend(call.type_args.iter().map(|arg| self.export_type_arg(arg))); let args = args.into_bump_slice(); - let func = self.make_term(model::Term::ApplyFull { global: name, args }); + let func = self.make_term(model::Term::ApplyFull { symbol, args }); model::Operation::CallFunc { func } } OpType::LoadFunction(load) => { // TODO: If the node is not connected to a function, we should do better than panic. let node = self.connected_function(node).unwrap(); - let name = model::GlobalRef::Named(self.get_func_name(node).unwrap()); + let symbol = self.node_indices[&node]; let mut args = BumpVec::new_in(self.bump); args.extend(load.type_args.iter().map(|arg| self.export_type_arg(arg))); let args = args.into_bump_slice(); - let func = self.make_term(model::Term::ApplyFull { global: name, args }); + let func = self.make_term(model::Term::ApplyFull { symbol, args }); model::Operation::LoadFunc { func } } @@ -327,7 +369,7 @@ impl<'a> Context<'a> { OpType::LoadConstant(_) => todo!("Export load constant?"), OpType::CallIndirect(_) => model::Operation::CustomFull { - operation: model::GlobalRef::Named(OP_FUNC_CALL_INDIRECT), + operation: self.resolve_symbol(OP_FUNC_CALL_INDIRECT), }, OpType::Tag(tag) => model::Operation::Tag { tag: tag.tag as _ }, @@ -336,7 +378,7 @@ impl<'a> Context<'a> { let extensions = self.export_ext_set(&tail_loop.extension_delta); regions = self .bump - .alloc_slice_copy(&[self.export_dfg(node, extensions)]); + .alloc_slice_copy(&[self.export_dfg(node, extensions, false)]); model::Operation::TailLoop } @@ -361,7 +403,7 @@ impl<'a> Context<'a> { // as that of the node. This might change in the future. let extensions = self.export_ext_set(&op.extension_delta()); - if let Some(region) = self.export_dfg_if_present(node, extensions) { + if let Some(region) = self.export_dfg_if_present(node, extensions, true) { regions = self.bump.alloc_slice_copy(&[region]); } @@ -381,7 +423,7 @@ impl<'a> Context<'a> { // as that of the node. This might change in the future. let extensions = self.export_ext_set(&op.extension_delta()); - if let Some(region) = self.export_dfg_if_present(node, extensions) { + if let Some(region) = self.export_dfg_if_present(node, extensions, true) { regions = self.bump.alloc_slice_copy(&[region]); } @@ -417,8 +459,7 @@ impl<'a> Context<'a> { None => &[], }; - // Replace the placeholder node with the actual node. - *self.module.get_node_mut(node_id).unwrap() = model::Node { + self.module.nodes[node_id.index()] = model::Node { operation, inputs, outputs, @@ -427,8 +468,6 @@ impl<'a> Context<'a> { meta, signature, }; - - node_id } /// Export an `OpDef` as an operation declaration. @@ -438,7 +477,7 @@ impl<'a> Context<'a> { /// of the operation. The node is added to the `decl_operations` map so that /// at the end of the export, the operation declaration nodes can be added /// to the module as children of the module region. - pub fn export_opdef(&mut self, opdef: &OpDef) -> model::GlobalRef<'a> { + pub fn export_opdef(&mut self, opdef: &OpDef) -> model::NodeId { use std::collections::hash_map::Entry; let poly_func_type = match opdef.signature_func() { @@ -450,9 +489,7 @@ impl<'a> Context<'a> { let entry = self.decl_operations.entry(key); let node = match entry { - Entry::Occupied(occupied_entry) => { - return model::GlobalRef::Direct(*occupied_entry.get()) - } + Entry::Occupied(occupied_entry) => return *occupied_entry.get(), Entry::Vacant(vacant_entry) => { *vacant_entry.insert(self.module.insert_node(model::Node { operation: model::Operation::Invalid, @@ -502,7 +539,7 @@ impl<'a> Context<'a> { node_data.operation = model::Operation::DeclareOperation { decl }; node_data.meta = meta; - model::GlobalRef::Direct(node) + node } /// Export the signature of a `DataflowBlock`. Here we can't use `OpType::dataflow_signature` @@ -545,18 +582,41 @@ impl<'a> Context<'a> { &mut self, node: Node, extensions: model::TermId, + link_scope_closed: bool, ) -> Option { if self.hugr.children(node).next().is_none() { None } else { - Some(self.export_dfg(node, extensions)) + Some(self.export_dfg(node, extensions, link_scope_closed)) } } /// Creates a data flow region from the given node's children. /// /// `Input` and `Output` nodes are used to determine the source and target ports of the region. - pub fn export_dfg(&mut self, node: Node, extensions: model::TermId) -> model::RegionId { + pub fn export_dfg( + &mut self, + node: Node, + extensions: model::TermId, + link_scope_closed: bool, + ) -> model::RegionId { + let region = self.module.insert_region(model::Region::default()); + + self.symbols.enter(region); + if link_scope_closed { + self.links.enter(region); + } + + let region_children = { + let children = self.hugr.children(node); + let mut region_children = + BumpVec::with_capacity_in(children.size_hint().0 - 2, self.bump); + for child in children.skip(2) { + region_children.push(self.export_node_shallow(child)); + } + region_children.into_bump_slice() + }; + let mut children = self.hugr.children(node); // The first child is an `Input` node, which we use to determine the region's sources. @@ -574,10 +634,8 @@ impl<'a> Context<'a> { let targets = self.make_ports(output_node, Direction::Incoming, output_op.types.len()); // Export the remaining children of the node. - let mut region_children = BumpVec::with_capacity_in(children.size_hint().0, self.bump); - - for child in children { - region_children.push(self.export_node(child)); + for (child, child_node_id) in children.zip(region_children.iter().copied()) { + self.export_node_deep(child, child_node_id); } let signature = { @@ -591,62 +649,112 @@ impl<'a> Context<'a> { })) }; - self.module.insert_region(model::Region { + let scope = if link_scope_closed { + let (links, ports) = self.links.exit(); + Some(model::RegionScope { links, ports }) + } else { + None + }; + self.symbols.exit(); + + self.module.regions[region.index()] = model::Region { kind: model::RegionKind::DataFlow, sources, targets, - children: region_children.into_bump_slice(), + children: region_children, meta: &[], // TODO: Export metadata signature, - }) + scope, + }; + + region } /// Creates a control flow region from the given node's children. - pub fn export_cfg(&mut self, node: Node) -> model::RegionId { - let mut children = self.hugr.children(node); - let mut region_children = BumpVec::with_capacity_in(children.size_hint().0 + 1, self.bump); + pub fn export_cfg(&mut self, node: Node, link_scope_closed: bool) -> model::RegionId { + let region = self.module.insert_region(model::Region::default()); + self.symbols.enter(region); + + if link_scope_closed { + self.links.enter(region); + } + + let region_children = { + let children = self.hugr.children(node); + let mut region_children = + BumpVec::with_capacity_in(children.size_hint().0 - 1, self.bump); + + // First export the children shallowly to allocate their IDs and register symbols. + for (i, child) in children.enumerate() { + // The second node is the exit block, which is not exported as a node itself. + if i == 1 { + continue; + } + + region_children.push(self.export_node_shallow(child)); + } + + region_children.into_bump_slice() + }; + + let mut children_iter = self.hugr.children(node); + let mut region_children_iter = region_children.iter().copied(); // The first child is the entry block. // We create a source port on the control flow region and connect it to the // first input port of the exported entry block. - let entry_block = children.next().unwrap(); + let source = { + let entry_block = children_iter.next().unwrap(); + let entry_node_id = region_children_iter.next().unwrap(); - let OpType::DataflowBlock(_) = self.hugr.get_optype(entry_block) else { - panic!("expected a `DataflowBlock` node as the first child node"); - }; + let OpType::DataflowBlock(_) = self.hugr.get_optype(entry_block) else { + panic!("expected a `DataflowBlock` node as the first child node"); + }; - let source = model::LinkRef::Id(self.get_link_id(entry_block, IncomingPort::from(0))); - region_children.push(self.export_node(entry_block)); + self.export_node_deep(entry_block, entry_node_id); + self.get_link_index(entry_block, IncomingPort::from(0)) + }; - // The last child is the exit block. + // The second child is the exit block. // Contrary to the entry block, the exit block does not have a dataflow subgraph. // We therefore do not export the block itself, but simply use its output ports // as the target ports of the control flow region. - let exit_block = children.next_back().unwrap(); - - // Export the remaining children of the node, except for the last one. - for child in children { - region_children.push(self.export_node(child)); - } + let exit_block = children_iter.next_back().unwrap(); let OpType::ExitBlock(_) = self.hugr.get_optype(exit_block) else { - panic!("expected an `ExitBlock` node as the last child node"); + panic!("expected an `ExitBlock` node as the second child node"); }; + // Export the remaining children of the node, except for the last one. + for (child, child_node_id) in children_iter.zip(region_children_iter) { + self.export_node_deep(child, child_node_id); + } + let targets = self.make_ports(exit_block, Direction::Incoming, 1); // Get the signature of the control flow region. // This is the same as the signature of the parent node. let signature = Some(self.export_func_type(&self.hugr.signature(node).unwrap())); - self.module.insert_region(model::Region { + let scope = if link_scope_closed { + let (links, ports) = self.links.exit(); + Some(model::RegionScope { links, ports }) + } else { + None + }; + self.symbols.exit(); + + self.module.regions[region.index()] = model::Region { kind: model::RegionKind::ControlFlow, sources: self.bump.alloc_slice_copy(&[source]), targets, - children: region_children.into_bump_slice(), + children: region_children, meta: &[], // TODO: Export metadata signature, - }) + scope, + }; + + region } /// Export the `Case` node children of a `Conditional` node as data flow regions. @@ -660,7 +768,7 @@ impl<'a> Context<'a> { }; let extensions = self.export_ext_set(&case_op.signature.extension_reqs); - regions.push(self.export_dfg(child, extensions)); + regions.push(self.export_dfg(child, extensions, false)); } regions.into_bump_slice() @@ -683,7 +791,7 @@ impl<'a> Context<'a> { for (i, param) in t.params().iter().enumerate() { let name = self.bump.alloc_str(&i.to_string()); - let r#type = self.export_type_param(param, Some(model::LocalRef::Index(scope, i as _))); + let r#type = self.export_type_param(param, Some((scope, i as _))); let param = model::Param { name, r#type, @@ -706,14 +814,17 @@ impl<'a> Context<'a> { match t { TypeEnum::Extension(ext) => self.export_custom_type(ext), TypeEnum::Alias(alias) => { - let name = model::GlobalRef::Named(self.bump.alloc_str(alias.name())); + let global = self.resolve_symbol(self.bump.alloc_str(alias.name())); let args = &[]; - self.make_term(model::Term::ApplyFull { global: name, args }) + self.make_term(model::Term::ApplyFull { + symbol: global, + args, + }) } TypeEnum::Function(func) => self.export_func_type(func), TypeEnum::Variable(index, _) => { let node = self.local_scope.expect("local variable out of scope"); - self.make_term(model::Term::Var(model::LocalRef::Index(node, *index as _))) + self.make_term(model::Term::Var(model::VarId(node, *index as _))) } TypeEnum::RowVar(rv) => self.export_row_var(rv.as_rv()), TypeEnum::Sum(sum) => self.export_sum_type(sum), @@ -732,12 +843,12 @@ impl<'a> Context<'a> { } pub fn export_custom_type(&mut self, t: &CustomType) -> model::TermId { - let global = self.make_named_global_ref(t.extension(), t.name()); + let symbol = self.make_named_global_ref(t.extension(), t.name()); let args = self .bump .alloc_slice_fill_iter(t.args().iter().map(|p| self.export_type_arg(p))); - let term = model::Term::ApplyFull { global, args }; + let term = model::Term::ApplyFull { symbol, args }; self.make_term(term) } @@ -762,15 +873,12 @@ impl<'a> Context<'a> { pub fn export_type_arg_var(&mut self, var: &TypeArgVariable) -> model::TermId { let node = self.local_scope.expect("local variable out of scope"); - self.make_term(model::Term::Var(model::LocalRef::Index( - node, - var.index() as _, - ))) + self.make_term(model::Term::Var(model::VarId(node, var.index() as _))) } pub fn export_row_var(&mut self, t: &RowVariable) -> model::TermId { let node = self.local_scope.expect("local variable out of scope"); - self.make_term(model::Term::Var(model::LocalRef::Index(node, t.0 as _))) + self.make_term(model::Term::Var(model::VarId(node, t.0 as _))) } pub fn export_sum_type(&mut self, t: &SumType) -> model::TermId { @@ -834,12 +942,12 @@ impl<'a> Context<'a> { pub fn export_type_param( &mut self, t: &TypeParam, - var: Option>, + var: Option<(model::NodeId, model::VarIndex)>, ) -> model::TermId { match t { TypeParam::Type { b } => { - if let (Some(var), TypeBound::Copyable) = (var, b) { - let term = self.make_term(model::Term::Var(var)); + if let (Some((node, index)), TypeBound::Copyable) = (var, b) { + let term = self.make_term(model::Term::Var(model::VarId(node, index))); let non_linear = self.make_term(model::Term::NonLinearConstraint { term }); self.local_constraints.push(non_linear); } @@ -860,8 +968,9 @@ impl<'a> Context<'a> { .map(|param| model::ListPart::Item(self.export_type_param(param, None))), ); let types = self.make_term(model::Term::List { parts }); + let symbol = self.resolve_symbol(TERM_PARAM_TUPLE); self.make_term(model::Term::ApplyFull { - global: model::GlobalRef::Named(TERM_PARAM_TUPLE), + symbol, args: self.bump.alloc_slice_copy(&[types]), }) } @@ -879,10 +988,9 @@ impl<'a> Context<'a> { for ext in ext_set.iter() { // `ExtensionSet`s represent variables by extension names that parse to integers. match ext.parse::() { - Ok(var) => { + Ok(index) => { let node = self.local_scope.expect("local variable out of scope"); - let local_ref = model::LocalRef::Index(node, var); - let term = self.make_term(model::Term::Var(local_ref)); + let term = self.make_term(model::Term::Var(model::VarId(node, index))); parts.push(model::ExtSetPart::Splice(term)); } Err(_) => parts.push(model::ExtSetPart::Extension(self.bump.alloc_str(ext))), @@ -913,11 +1021,26 @@ impl<'a> Context<'a> { let value = serde_json::to_string(value).expect("json values are always serializable"); let value = self.make_term(model::Term::Str(self.bump.alloc_str(&value))); let value = self.bump.alloc_slice_copy(&[value]); + let symbol = self.resolve_symbol(TERM_JSON); self.make_term(model::Term::ApplyFull { - global: model::GlobalRef::Named(TERM_JSON), + symbol, args: value, }) } + + fn resolve_symbol(&mut self, name: &'a str) -> model::NodeId { + let result = self.symbols.resolve(name); + + match result { + Ok(node) => node, + Err(_) => *self.implicit_imports.entry(name).or_insert_with(|| { + self.module.insert_node(model::Node { + operation: model::Operation::Import { name }, + ..model::Node::default() + }) + }), + } + } } #[cfg(test)] diff --git a/hugr-core/src/import.rs b/hugr-core/src/import.rs index 002160840..ee2237ffa 100644 --- a/hugr-core/src/import.rs +++ b/hugr-core/src/import.rs @@ -22,16 +22,13 @@ use crate::{ Direction, Hugr, HugrView, Node, Port, }; use fxhash::FxHashMap; -use hugr_model::v0::{self as model, GlobalRef}; -use indexmap::IndexMap; +use hugr_model::v0::{self as model}; use itertools::Either; use smol_str::{SmolStr, ToSmolStr}; use thiserror::Error; const TERM_JSON: &str = "prelude.json"; -type FxIndexMap = IndexMap; - /// Error during import. #[derive(Debug, Clone, Error)] pub enum ImportError { @@ -76,22 +73,18 @@ pub fn import_hugr( module: &model::Module, extensions: &ExtensionRegistry, ) -> Result { - let names = Names::new(module)?; - // TODO: Module should know about the number of edges, so that we can use a vector here. // For now we use a hashmap, which will be slower. - let edge_ports = FxHashMap::default(); - let mut ctx = Context { module, - names, hugr: Hugr::new(OpType::Module(Module {})), - link_ports: edge_ports, + link_ports: FxHashMap::default(), static_edges: Vec::new(), extensions, nodes: FxHashMap::default(), - local_variables: IndexMap::default(), + local_vars: FxHashMap::default(), custom_name_cache: FxHashMap::default(), + region_scope: model::RegionId::default(), }; ctx.import_root()?; @@ -105,31 +98,28 @@ struct Context<'a> { /// The module being imported. module: &'a model::Module<'a>, - names: Names<'a>, - /// The HUGR graph being constructed. hugr: Hugr, /// The ports that are part of each link. This is used to connect the ports at the end of the /// import process. - link_ports: FxHashMap, Vec<(Node, Port)>>, + link_ports: FxHashMap<(model::RegionId, model::LinkIndex), Vec<(Node, Port)>>, /// Pairs of nodes that should be connected by a static edge. /// These are collected during the import process and connected at the end. static_edges: Vec<(model::NodeId, model::NodeId)>, - // /// The `(Node, Port)` pairs for each `PortId` in the module. - // imported_ports: Vec>, /// The ambient extension registry to use for importing. extensions: &'a ExtensionRegistry, /// A map from `NodeId` to the imported `Node`. nodes: FxHashMap, - /// The local variables that are currently in scope. - local_variables: FxIndexMap<&'a str, LocalVar>, + local_vars: FxHashMap, custom_name_cache: FxHashMap<&'a str, (ExtensionId, SmolStr)>, + + region_scope: model::RegionId, } impl<'a> Context<'a> { @@ -166,25 +156,6 @@ impl<'a> Context<'a> { .ok_or_else(|| model::ModelError::RegionNotFound(region_id).into()) } - /// Looks up a [`LocalRef`] within the current scope. - fn resolve_local_ref( - &self, - local_ref: &model::LocalRef, - ) -> Result<(usize, LocalVar), ImportError> { - let term = match local_ref { - model::LocalRef::Index(_, index) => self - .local_variables - .get_index(*index as usize) - .map(|(_, v)| (*index as usize, *v)), - model::LocalRef::Named(name) => self - .local_variables - .get_full(name) - .map(|(index, _, v)| (index, *v)), - }; - - term.ok_or_else(|| model::ModelError::InvalidLocal(local_ref.to_string()).into()) - } - fn make_node( &mut self, node_id: model::NodeId, @@ -209,13 +180,16 @@ impl<'a> Context<'a> { } /// Associate links with the ports of the given node in the given direction. - fn record_links(&mut self, node: Node, direction: Direction, links: &'a [model::LinkRef<'a>]) { + fn record_links(&mut self, node: Node, direction: Direction, links: &'a [model::LinkIndex]) { let optype = self.hugr.get_optype(node); // NOTE: `OpType::port_count` copies the signature, which significantly slows down the import. debug_assert!(links.len() <= optype.port_count(direction)); for (link, port) in links.iter().zip(self.hugr.node_ports(node, direction)) { - self.link_ports.entry(*link).or_default().push((node, port)); + self.link_ports + .entry((self.region_scope, *link)) + .or_default() + .push((node, port)); } } @@ -242,8 +216,9 @@ impl<'a> Context<'a> { if inputs.is_empty() || outputs.is_empty() { return Err(error_unsupported!( - "link {} is missing either an input or an output port", - link_id + "link {}#{} is missing either an input or an output port", + link_id.0, + link_id.1 )); } @@ -279,60 +254,13 @@ impl<'a> Context<'a> { Ok(()) } - fn with_local_socpe( - &mut self, - f: impl FnOnce(&mut Self) -> Result, - ) -> Result { - let previous = std::mem::take(&mut self.local_variables); - let result = f(self); - self.local_variables = previous; - result - } - - fn resolve_global_ref( - &self, - global_ref: &model::GlobalRef, - ) -> Result { - match global_ref { - model::GlobalRef::Direct(node_id) => Ok(*node_id), - model::GlobalRef::Named(name) => { - let item = self - .names - .items - .get(name) - .ok_or_else(|| model::ModelError::InvalidGlobal(global_ref.to_string()))?; - - match item { - NamedItem::FuncDecl(node) => Ok(*node), - NamedItem::FuncDefn(node) => Ok(*node), - NamedItem::CtrDecl(node) => Ok(*node), - NamedItem::OperationDecl(node) => Ok(*node), - } - } - } - } - - fn get_global_name(&self, global_ref: model::GlobalRef<'a>) -> Result<&'a str, ImportError> { - match global_ref { - model::GlobalRef::Direct(node_id) => { - let node_data = self.get_node(node_id)?; - - let name = match node_data.operation { - model::Operation::DefineFunc { decl } => decl.name, - model::Operation::DeclareFunc { decl } => decl.name, - model::Operation::DefineAlias { decl, .. } => decl.name, - model::Operation::DeclareAlias { decl } => decl.name, - model::Operation::DeclareConstructor { decl } => decl.name, - model::Operation::DeclareOperation { decl } => decl.name, - _ => { - return Err(model::ModelError::InvalidGlobal(global_ref.to_string()).into()); - } - }; - - Ok(name) - } - model::GlobalRef::Named(name) => Ok(name), - } + fn get_global_name(&self, node_id: model::NodeId) -> Result<&'a str, ImportError> { + let node_data = self.get_node(node_id)?; + let name = node_data + .operation + .symbol() + .ok_or(model::ModelError::InvalidSymbol(node_id))?; + Ok(name) } fn get_func_signature( @@ -345,11 +273,12 @@ impl<'a> Context<'a> { _ => return Err(model::ModelError::UnexpectedOperation(func_node).into()), }; - self.import_poly_func_type(*decl, |_, signature| Ok(signature)) + self.import_poly_func_type(func_node, *decl, |_, signature| Ok(signature)) } /// Import the root region of the module. fn import_root(&mut self) -> Result<(), ImportError> { + self.region_scope = self.module.root; let region_data = self.get_region(self.module.root)?; for node in region_data.children { @@ -400,7 +329,7 @@ impl<'a> Context<'a> { } model::Operation::DefineFunc { decl } => { - self.import_poly_func_type(*decl, |ctx, signature| { + self.import_poly_func_type(node_id, *decl, |ctx, signature| { let optype = OpType::FuncDefn(FuncDefn { name: decl.name.to_string(), signature, @@ -419,7 +348,7 @@ impl<'a> Context<'a> { } model::Operation::DeclareFunc { decl } => { - self.import_poly_func_type(*decl, |ctx, signature| { + self.import_poly_func_type(node_id, *decl, |ctx, signature| { let optype = OpType::FuncDecl(FuncDecl { name: decl.name.to_string(), signature, @@ -432,19 +361,18 @@ impl<'a> Context<'a> { } model::Operation::CallFunc { func } => { - let model::Term::ApplyFull { global: name, args } = self.get_term(func)? else { + let model::Term::ApplyFull { symbol, args } = self.get_term(func)? else { return Err(model::ModelError::TypeError(func).into()); }; - let func_node = self.resolve_global_ref(name)?; - let func_sig = self.get_func_signature(func_node)?; + let func_sig = self.get_func_signature(*symbol)?; let type_args = args .iter() .map(|term| self.import_type_arg(*term)) .collect::, _>>()?; - self.static_edges.push((func_node, node_id)); + self.static_edges.push((*symbol, node_id)); let optype = OpType::Call(Call::try_new(func_sig, type_args, self.extensions)?); let node = self.make_node(node_id, optype, parent)?; @@ -452,19 +380,18 @@ impl<'a> Context<'a> { } model::Operation::LoadFunc { func } => { - let model::Term::ApplyFull { global: name, args } = self.get_term(func)? else { + let model::Term::ApplyFull { symbol, args } = self.get_term(func)? else { return Err(model::ModelError::TypeError(func).into()); }; - let func_node = self.resolve_global_ref(name)?; - let func_sig = self.get_func_signature(func_node)?; + let func_sig = self.get_func_signature(*symbol)?; let type_args = args .iter() .map(|term| self.import_type_arg(*term)) .collect::, _>>()?; - self.static_edges.push((func_node, node_id)); + self.static_edges.push((*symbol, node_id)); let optype = OpType::LoadFunction(LoadFunction::try_new( func_sig, @@ -485,16 +412,16 @@ impl<'a> Context<'a> { Ok(Some(node)) } - model::Operation::CustomFull { - operation: GlobalRef::Named(name), - } if name == OP_FUNC_CALL_INDIRECT => { - let signature = self.get_node_signature(node_id)?; - let optype = OpType::CallIndirect(CallIndirect { signature }); - let node = self.make_node(node_id, optype, parent)?; - Ok(Some(node)) - } - model::Operation::CustomFull { operation } => { + let name = self.get_global_name(operation)?; + + if name == OP_FUNC_CALL_INDIRECT { + let signature = self.get_node_signature(node_id)?; + let optype = OpType::CallIndirect(CallIndirect { signature }); + let node = self.make_node(node_id, optype, parent)?; + return Ok(Some(node)); + } + let signature = self.get_node_signature(node_id)?; let args = node_data .params @@ -502,7 +429,6 @@ impl<'a> Context<'a> { .map(|param| self.import_type_arg(*param)) .collect::, _>>()?; - let name = self.get_global_name(operation)?; let (extension, name) = self.import_custom_name(name)?; // TODO: Currently we do not have the description or any other metadata for @@ -533,7 +459,7 @@ impl<'a> Context<'a> { "custom operation with implicit parameters" )), - model::Operation::DefineAlias { decl, value } => self.with_local_socpe(|ctx| { + model::Operation::DefineAlias { decl, value } => { if !decl.params.is_empty() { return Err(error_unsupported!( "parameters or constraints in alias definition" @@ -542,14 +468,14 @@ impl<'a> Context<'a> { let optype = OpType::AliasDefn(AliasDefn { name: decl.name.to_smolstr(), - definition: ctx.import_type(value)?, + definition: self.import_type(value)?, }); - let node = ctx.make_node(node_id, optype, parent)?; + let node = self.make_node(node_id, optype, parent)?; Ok(Some(node)) - }), + } - model::Operation::DeclareAlias { decl } => self.with_local_socpe(|ctx| { + model::Operation::DeclareAlias { decl } => { if !decl.params.is_empty() { return Err(error_unsupported!( "parameters or constraints in alias declaration" @@ -561,9 +487,9 @@ impl<'a> Context<'a> { bound: TypeBound::Copyable, }); - let node = ctx.make_node(node_id, optype, parent)?; + let node = self.make_node(node_id, optype, parent)?; Ok(Some(node)) - }), + } model::Operation::Tag { tag } => { let signature = node_data @@ -582,6 +508,8 @@ impl<'a> Context<'a> { Ok(Some(node)) } + model::Operation::Import { .. } => Ok(None), + model::Operation::DeclareConstructor { .. } => Ok(None), model::Operation::DeclareOperation { .. } => Ok(None), } @@ -595,6 +523,11 @@ impl<'a> Context<'a> { ) -> Result<(), ImportError> { let region_data = self.get_region(region)?; + let prev_region = self.region_scope; + if region_data.scope.is_some() { + self.region_scope = region; + } + if region_data.kind != model::RegionKind::DataFlow { return Err(model::ModelError::InvalidRegions(node_id).into()); } @@ -627,6 +560,8 @@ impl<'a> Context<'a> { self.import_node(*child, node)?; } + self.region_scope = prev_region; + Ok(()) } @@ -759,6 +694,11 @@ impl<'a> Context<'a> { return Err(model::ModelError::InvalidRegions(node_id).into()); } + let prev_region = self.region_scope; + if region_data.scope.is_some() { + self.region_scope = region; + } + let (region_source, region_targets, _) = self.get_func_type( region_data .signature @@ -865,6 +805,8 @@ impl<'a> Context<'a> { self.record_links(exit, Direction::Incoming, region_data.targets); } + self.region_scope = prev_region; + Ok(()) } @@ -903,43 +845,43 @@ impl<'a> Context<'a> { fn import_poly_func_type( &mut self, + node: model::NodeId, decl: model::FuncDecl<'a>, in_scope: impl FnOnce(&mut Self, PolyFuncTypeBase) -> Result, ) -> Result { - self.with_local_socpe(|ctx| { - let mut imported_params = Vec::with_capacity(decl.params.len()); + let mut imported_params = Vec::with_capacity(decl.params.len()); - ctx.local_variables.extend( - decl.params - .iter() - .map(|param| (param.name, LocalVar::new(param.r#type))), - ); + for (index, param) in decl.params.iter().enumerate() { + self.local_vars + .insert(model::VarId(node, index as _), LocalVar::new(param.r#type)); + } - for constraint in decl.constraints { - match ctx.get_term(*constraint)? { - model::Term::NonLinearConstraint { term } => { - let model::Term::Var(var) = ctx.get_term(*term)? else { - return Err(error_unsupported!( - "constraint on term that is not a variable" - )); - }; - - let var = ctx.resolve_local_ref(var)?.0; - ctx.local_variables[var].bound = TypeBound::Copyable; - } - _ => return Err(error_unsupported!("constraint other than copy or discard")), + for constraint in decl.constraints { + match self.get_term(*constraint)? { + model::Term::NonLinearConstraint { term } => { + let model::Term::Var(var) = self.get_term(*term)? else { + return Err(error_unsupported!( + "constraint on term that is not a variable" + )); + }; + + self.local_vars + .get_mut(var) + .ok_or(model::ModelError::InvalidVar(*var))? + .bound = TypeBound::Copyable; } + _ => return Err(error_unsupported!("constraint other than copy or discard")), } + } - for (index, param) in decl.params.iter().enumerate() { - // NOTE: `PolyFuncType` only has explicit type parameters at present. - let bound = ctx.local_variables[index].bound; - imported_params.push(ctx.import_type_param(param.r#type, bound)?); - } + for (index, param) in decl.params.iter().enumerate() { + // NOTE: `PolyFuncType` only has explicit type parameters at present. + let bound = self.local_vars[&model::VarId(node, index as _)].bound; + imported_params.push(self.import_type_param(param.r#type, bound)?); + } - let body = ctx.import_func_type::(decl.signature)?; - in_scope(ctx, PolyFuncTypeBase::new(imported_params, body)) - }) + let body = self.import_func_type::(decl.signature)?; + in_scope(self, PolyFuncTypeBase::new(imported_params, body)) } /// Import a [`TypeParam`] from a term that represents a static type. @@ -955,7 +897,7 @@ impl<'a> Context<'a> { model::Term::StaticType => Err(error_unsupported!("`type` as `TypeParam`")), model::Term::Constraint => Err(error_unsupported!("`constraint` as `TypeParam`")), - model::Term::Var(_) => Err(error_unsupported!("type variable as `TypeParam`")), + model::Term::Var { .. } => Err(error_unsupported!("type variable as `TypeParam`")), model::Term::Apply { .. } => Err(error_unsupported!("custom type as `TypeParam`")), model::Term::ApplyFull { .. } => Err(error_unsupported!("custom type as `TypeParam`")), @@ -999,9 +941,12 @@ impl<'a> Context<'a> { } model::Term::Var(var) => { - let (index, var) = self.resolve_local_ref(var)?; - let decl = self.import_type_param(var.r#type, var.bound)?; - Ok(TypeArg::new_var_use(index, decl)) + let var_info = self + .local_vars + .get(var) + .ok_or(model::ModelError::InvalidVar(*var))?; + let decl = self.import_type_param(var_info.r#type, var_info.bound)?; + Ok(TypeArg::new_var_use(var.1 as _, decl)) } model::Term::List { .. } => { @@ -1056,9 +1001,8 @@ impl<'a> Context<'a> { match self.get_term(term_id)? { model::Term::Wildcard => return Err(error_uninferred!("wildcard")), - model::Term::Var(var) => { - let (index, _) = self.resolve_local_ref(var)?; - es.insert_type_var(index); + model::Term::Var(model::VarId(_, index)) => { + es.insert_type_var(*index as _); } model::Term::ExtSet { parts } => { @@ -1095,13 +1039,13 @@ impl<'a> Context<'a> { Err(error_uninferred!("application with implicit parameters")) } - model::Term::ApplyFull { global: name, args } => { + model::Term::ApplyFull { symbol, args } => { let args = args .iter() .map(|arg| self.import_type_arg(*arg)) .collect::, _>>()?; - let name = self.get_global_name(*name)?; + let name = self.get_global_name(*symbol)?; let (extension, id) = self.import_custom_name(name)?; let extension_ref = @@ -1123,10 +1067,8 @@ impl<'a> Context<'a> { ))) } - model::Term::Var(var) => { - // We pretend that all `TypeBound`s are copyable. - let (index, _) = self.resolve_local_ref(var)?; - Ok(TypeBase::new_var_use(index, TypeBound::Copyable)) + model::Term::Var(model::VarId(_, index)) => { + Ok(TypeBase::new_var_use(*index as _, TypeBound::Copyable)) } model::Term::FuncType { .. } => { @@ -1259,9 +1201,8 @@ impl<'a> Context<'a> { } } } - model::Term::Var(var) => { - let (index, _) = ctx.resolve_local_ref(var)?; - let var = RV::try_from_rv(RowVariable(index, TypeBound::Any)) + model::Term::Var(model::VarId(_, index)) => { + let var = RV::try_from_rv(RowVariable(*index as _, TypeBound::Any)) .map_err(|_| model::ModelError::TypeError(term_id))?; types.push(TypeBase::new(TypeEnum::RowVar(var))); } @@ -1302,13 +1243,14 @@ impl<'a> Context<'a> { term_id: model::TermId, ) -> Result { let (global, args) = match self.get_term(term_id)? { - model::Term::Apply { global, args } | model::Term::ApplyFull { global, args } => { - (global, args) + model::Term::Apply { symbol, args } | model::Term::ApplyFull { symbol, args } => { + (symbol, args) } _ => return Err(model::ModelError::TypeError(term_id).into()), }; - if global != &GlobalRef::Named(TERM_JSON) { + let global = self.get_global_name(*global)?; + if global != TERM_JSON { return Err(model::ModelError::TypeError(term_id).into()); } @@ -1327,51 +1269,6 @@ impl<'a> Context<'a> { } } -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] -enum NamedItem { - FuncDecl(model::NodeId), - FuncDefn(model::NodeId), - CtrDecl(model::NodeId), - OperationDecl(model::NodeId), -} - -struct Names<'a> { - items: FxHashMap<&'a str, NamedItem>, -} - -impl<'a> Names<'a> { - pub fn new(module: &model::Module<'a>) -> Result { - let mut items = FxHashMap::default(); - - for (node_id, node_data) in module.nodes.iter().enumerate() { - let node_id = model::NodeId(node_id as _); - - let item = match node_data.operation { - model::Operation::DefineFunc { decl } => { - Some((decl.name, NamedItem::FuncDecl(node_id))) - } - model::Operation::DeclareFunc { decl } => { - Some((decl.name, NamedItem::FuncDefn(node_id))) - } - model::Operation::DeclareConstructor { decl } => { - Some((decl.name, NamedItem::CtrDecl(node_id))) - } - model::Operation::DeclareOperation { decl } => { - Some((decl.name, NamedItem::OperationDecl(node_id))) - } - _ => None, - }; - - if let Some((name, item)) = item { - // TODO: Deal with duplicates - items.insert(name, item); - } - } - - Ok(Self { items }) - } -} - /// Information about a local variable. #[derive(Debug, Clone, Copy)] struct LocalVar { diff --git a/hugr-core/tests/model.rs b/hugr-core/tests/model.rs index 7ae4010f4..35306e27d 100644 --- a/hugr-core/tests/model.rs +++ b/hugr-core/tests/model.rs @@ -7,6 +7,7 @@ use hugr_model::v0 as model; fn roundtrip(source: &str) -> String { let bump = bumpalo::Bump::new(); let parsed_model = model::text::parse(source, &bump).unwrap(); + println!("{:#?}", parsed_model); let imported_hugr = import_hugr(&parsed_model.module, &std_reg()).unwrap(); let exported_model = export_hugr(&imported_hugr, &bump); model::text::print_to_string(&exported_model, 80).unwrap() diff --git a/hugr-core/tests/snapshots/model__roundtrip_add.snap b/hugr-core/tests/snapshots/model__roundtrip_add.snap index b7de139fe..7ffec5ef9 100644 --- a/hugr-core/tests/snapshots/model__roundtrip_add.snap +++ b/hugr-core/tests/snapshots/model__roundtrip_add.snap @@ -1,9 +1,13 @@ --- source: hugr-core/tests/model.rs -expression: "roundtrip(include_str!(\"fixtures/model-add.edn\"))" +expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-add.edn\"))" --- (hugr 0) +(import arithmetic.int.iadd) + +(import arithmetic.int.types.int) + (define-func example.add [(@ arithmetic.int.types.int) (@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] diff --git a/hugr-core/tests/snapshots/model__roundtrip_alias.snap b/hugr-core/tests/snapshots/model__roundtrip_alias.snap index c279c5d6a..27fdd4740 100644 --- a/hugr-core/tests/snapshots/model__roundtrip_alias.snap +++ b/hugr-core/tests/snapshots/model__roundtrip_alias.snap @@ -1,9 +1,11 @@ --- source: hugr-core/tests/model.rs -expression: "roundtrip(include_str!(\"fixtures/model-alias.edn\"))" +expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-alias.edn\"))" --- (hugr 0) +(import arithmetic.int.types.int) + (declare-alias local.float type) (define-alias local.int type (@ arithmetic.int.types.int)) diff --git a/hugr-core/tests/snapshots/model__roundtrip_call.snap b/hugr-core/tests/snapshots/model__roundtrip_call.snap index 5ddc4eb32..2b37b5a20 100644 --- a/hugr-core/tests/snapshots/model__roundtrip_call.snap +++ b/hugr-core/tests/snapshots/model__roundtrip_call.snap @@ -4,6 +4,10 @@ expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-call --- (hugr 0) +(import prelude.json) + +(import arithmetic.int.types.int) + (declare-func example.callee (forall ?0 ext-set) [(@ arithmetic.int.types.int)] diff --git a/hugr-core/tests/snapshots/model__roundtrip_cfg.snap b/hugr-core/tests/snapshots/model__roundtrip_cfg.snap index 41a8f0d62..e39f0d37d 100644 --- a/hugr-core/tests/snapshots/model__roundtrip_cfg.snap +++ b/hugr-core/tests/snapshots/model__roundtrip_cfg.snap @@ -13,14 +13,14 @@ expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-cfg. (cfg [%0] [%1] (signature (fn [?0] [?0] (ext))) (cfg - [%2] [%8] + [%4] [%8] (signature (fn [?0] [?0] (ext))) - (block [%2] [%5] + (block [%4] [%5] (signature (fn [(ctrl [?0])] [(ctrl [?0])] (ext))) (dfg - [%3] [%4] + [%2] [%3] (signature (fn [?0] [(adt [[?0]])] (ext))) - (tag 0 [%3] [%4] (signature (fn [?0] [(adt [[?0]])] (ext)))))) + (tag 0 [%2] [%3] (signature (fn [?0] [(adt [[?0]])] (ext)))))) (block [%5] [%8] (signature (fn [(ctrl [?0])] [(ctrl [?0])] (ext))) (dfg diff --git a/hugr-core/tests/snapshots/model__roundtrip_cond.snap b/hugr-core/tests/snapshots/model__roundtrip_cond.snap index fe55e965f..92ab0cb4d 100644 --- a/hugr-core/tests/snapshots/model__roundtrip_cond.snap +++ b/hugr-core/tests/snapshots/model__roundtrip_cond.snap @@ -1,9 +1,13 @@ --- source: hugr-core/tests/model.rs -expression: "roundtrip(include_str!(\"fixtures/model-cond.edn\"))" +expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-cond.edn\"))" --- (hugr 0) +(import arithmetic.int.types.int) + +(import arithmetic.int.ineg) + (define-func example.cond [(adt [[] []]) (@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] diff --git a/hugr-core/tests/snapshots/model__roundtrip_constraints.snap b/hugr-core/tests/snapshots/model__roundtrip_constraints.snap index 291c2de48..d7cb2bf01 100644 --- a/hugr-core/tests/snapshots/model__roundtrip_constraints.snap +++ b/hugr-core/tests/snapshots/model__roundtrip_constraints.snap @@ -4,6 +4,8 @@ expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-cons --- (hugr 0) +(import prelude.Array) + (declare-func array.replicate (forall ?0 type) (forall ?1 nat) diff --git a/hugr-model/capnp/hugr-v0.capnp b/hugr-model/capnp/hugr-v0.capnp index 366de92eb..b3bb1f0f2 100644 --- a/hugr-model/capnp/hugr-v0.capnp +++ b/hugr-model/capnp/hugr-v0.capnp @@ -15,6 +15,9 @@ using NodeId = UInt32; # The id of a `Link`. using LinkId = UInt32; +# The index of a `Link`. +using LinkIndex = UInt32; + struct Module { root @0 :RegionId; nodes @1 :List(Node); @@ -24,8 +27,8 @@ struct Module { struct Node { operation @0 :Operation; - inputs @1 :List(LinkRef); - outputs @2 :List(LinkRef); + inputs @1 :List(LinkIndex); + outputs @2 :List(LinkIndex); params @3 :List(TermId); regions @4 :List(RegionId); meta @5 :List(MetaItem); @@ -42,8 +45,8 @@ struct Operation { funcDecl @5 :FuncDecl; aliasDefn @6 :AliasDefn; aliasDecl @7 :AliasDecl; - custom @8 :GlobalRef; - customFull @9 :GlobalRef; + custom @8 :NodeId; + customFull @9 :NodeId; tag @10 :UInt16; tailLoop @11 :Void; conditional @12 :Void; @@ -51,6 +54,7 @@ struct Operation { loadFunc @14 :TermId; constructorDecl @15 :ConstructorDecl; operationDecl @16 :OperationDecl; + import @17 :Text; } struct FuncDefn { @@ -97,13 +101,22 @@ struct Operation { struct Region { kind @0 :RegionKind; - sources @1 :List(LinkRef); - targets @2 :List(LinkRef); + sources @1 :List(LinkIndex); + targets @2 :List(LinkIndex); children @3 :List(NodeId); meta @4 :List(MetaItem); signature @5 :OptionalTermId; + scope @6 :RegionScope; +} + +struct RegionScope { + links @0 :UInt32; + ports @1 :UInt32; } +# Either `0` for an open scope, or the number of links in the closed scope incremented by `1`. +using LinkScope = UInt32; + enum RegionKind { dataFlow @0; controlFlow @1; @@ -115,37 +128,16 @@ struct MetaItem { value @1 :UInt32; } -struct LinkRef { - union { - id @0 :LinkId; - named @1 :Text; - } -} - -struct GlobalRef { - union { - node @0 :NodeId; - named @1 :Text; - } -} - -struct LocalRef { - union { - direct :group { - index @0 :UInt16; - node @1 :NodeId; - } - named @2 :Text; - } -} - struct Term { union { wildcard @0 :Void; runtimeType @1 :Void; staticType @2 :Void; constraint @3 :Void; - variable @4 :LocalRef; + variable :group { + variableNode @4 :NodeId; + variableIndex @21 :UInt16; + } apply @5 :Apply; applyFull @6 :ApplyFull; quote @7 :TermId; @@ -165,12 +157,12 @@ struct Term { } struct Apply { - global @0 :GlobalRef; + symbol @0 :NodeId; args @1 :List(TermId); } struct ApplyFull { - global @0 :GlobalRef; + symbol @0 :NodeId; args @1 :List(TermId); } diff --git a/hugr-model/src/v0/binary/read.rs b/hugr-model/src/v0/binary/read.rs index 2dfe67efc..b14ca4482 100644 --- a/hugr-model/src/v0/binary/read.rs +++ b/hugr-model/src/v0/binary/read.rs @@ -71,8 +71,8 @@ fn read_module<'a>( fn read_node<'a>(bump: &'a Bump, reader: hugr_capnp::node::Reader) -> ReadResult> { let operation = read_operation(bump, reader.get_operation()?)?; - let inputs = read_list!(bump, reader, get_inputs, read_link_ref); - let outputs = read_list!(bump, reader, get_outputs, read_link_ref); + let inputs = read_scalar_list!(bump, reader, get_inputs, model::LinkIndex); + let outputs = read_scalar_list!(bump, reader, get_outputs, model::LinkIndex); let params = read_scalar_list!(bump, reader, get_params, model::TermId); let regions = read_scalar_list!(bump, reader, get_regions, model::RegionId); let meta = read_list!(bump, reader, get_meta, read_meta_item); @@ -89,43 +89,6 @@ fn read_node<'a>(bump: &'a Bump, reader: hugr_capnp::node::Reader) -> ReadResult }) } -fn read_local_ref<'a>( - bump: &'a Bump, - reader: hugr_capnp::local_ref::Reader, -) -> ReadResult> { - use hugr_capnp::local_ref::Which; - Ok(match reader.which()? { - Which::Direct(reader) => { - let index = reader.get_index(); - let node = model::NodeId(reader.get_node()); - model::LocalRef::Index(node, index) - } - Which::Named(name) => model::LocalRef::Named(bump.alloc_str(name?.to_str()?)), - }) -} - -fn read_global_ref<'a>( - bump: &'a Bump, - reader: hugr_capnp::global_ref::Reader, -) -> ReadResult> { - use hugr_capnp::global_ref::Which; - Ok(match reader.which()? { - Which::Node(node) => model::GlobalRef::Direct(model::NodeId(node)), - Which::Named(name) => model::GlobalRef::Named(bump.alloc_str(name?.to_str()?)), - }) -} - -fn read_link_ref<'a>( - bump: &'a Bump, - reader: hugr_capnp::link_ref::Reader, -) -> ReadResult> { - use hugr_capnp::link_ref::Which; - Ok(match reader.which()? { - Which::Id(id) => model::LinkRef::Id(model::LinkId(id)), - Which::Named(name) => model::LinkRef::Named(bump.alloc_str(name?.to_str()?)), - }) -} - fn read_operation<'a>( bump: &'a Bump, reader: hugr_capnp::operation::Reader, @@ -217,11 +180,11 @@ fn read_operation<'a>( }); model::Operation::DeclareOperation { decl } } - Which::Custom(name) => model::Operation::Custom { - operation: read_global_ref(bump, name?)?, + Which::Custom(operation) => model::Operation::Custom { + operation: model::NodeId(operation), }, - Which::CustomFull(name) => model::Operation::CustomFull { - operation: read_global_ref(bump, name?)?, + Which::CustomFull(operation) => model::Operation::CustomFull { + operation: model::NodeId(operation), }, Which::Tag(tag) => model::Operation::Tag { tag }, Which::TailLoop(()) => model::Operation::TailLoop, @@ -232,6 +195,9 @@ fn read_operation<'a>( Which::LoadFunc(func) => model::Operation::LoadFunc { func: model::TermId(func), }, + Which::Import(name) => model::Operation::Import { + name: bump.alloc_str(name?.to_str()?), + }, }) } @@ -245,12 +211,18 @@ fn read_region<'a>( hugr_capnp::RegionKind::Module => model::RegionKind::Module, }; - let sources = read_list!(bump, reader, get_sources, read_link_ref); - let targets = read_list!(bump, reader, get_targets, read_link_ref); + let sources = read_scalar_list!(bump, reader, get_sources, model::LinkIndex); + let targets = read_scalar_list!(bump, reader, get_targets, model::LinkIndex); let children = read_scalar_list!(bump, reader, get_children, model::NodeId); let meta = read_list!(bump, reader, get_meta, read_meta_item); let signature = reader.get_signature().checked_sub(1).map(model::TermId); + let scope = if reader.has_scope() { + Some(read_region_scope(reader.get_scope()?)?) + } else { + None + }; + Ok(model::Region { kind, sources, @@ -258,9 +230,16 @@ fn read_region<'a>( children, meta, signature, + scope, }) } +fn read_region_scope(reader: hugr_capnp::region_scope::Reader) -> ReadResult { + let links = reader.get_links(); + let ports = reader.get_ports(); + Ok(model::RegionScope { links, ports }) +} + fn read_term<'a>(bump: &'a Bump, reader: hugr_capnp::term::Reader) -> ReadResult> { use hugr_capnp::term::Which; Ok(match reader.which()? { @@ -274,20 +253,25 @@ fn read_term<'a>(bump: &'a Bump, reader: hugr_capnp::term::Reader) -> ReadResult Which::NatType(()) => model::Term::NatType, Which::ExtSetType(()) => model::Term::ExtSetType, Which::ControlType(()) => model::Term::ControlType, - Which::Variable(local_ref) => model::Term::Var(read_local_ref(bump, local_ref?)?), + + Which::Variable(reader) => { + let node = model::NodeId(reader.get_variable_node()); + let index = reader.get_variable_index(); + model::Term::Var(model::VarId(node, index)) + } Which::Apply(reader) => { let reader = reader?; - let global = read_global_ref(bump, reader.get_global()?)?; + let symbol = model::NodeId(reader.get_symbol()); let args = read_scalar_list!(bump, reader, get_args, model::TermId); - model::Term::Apply { global, args } + model::Term::Apply { symbol, args } } Which::ApplyFull(reader) => { let reader = reader?; - let global = read_global_ref(bump, reader.get_global()?)?; + let symbol = model::NodeId(reader.get_symbol()); let args = read_scalar_list!(bump, reader, get_args, model::TermId); - model::Term::ApplyFull { global, args } + model::Term::ApplyFull { symbol, args } } Which::Quote(r#type) => model::Term::Quote { diff --git a/hugr-model/src/v0/binary/write.rs b/hugr-model/src/v0/binary/write.rs index aa377e2ec..ea495db54 100644 --- a/hugr-model/src/v0/binary/write.rs +++ b/hugr-model/src/v0/binary/write.rs @@ -31,8 +31,8 @@ fn write_module(mut builder: hugr_capnp::module::Builder, module: &model::Module fn write_node(mut builder: hugr_capnp::node::Builder, node: &model::Node) { write_operation(builder.reborrow().init_operation(), &node.operation); - write_list!(builder, init_inputs, write_link_ref, node.inputs); - write_list!(builder, init_outputs, write_link_ref, node.outputs); + let _ = builder.set_inputs(model::LinkIndex::unwrap_slice(node.inputs)); + let _ = builder.set_outputs(model::LinkIndex::unwrap_slice(node.outputs)); write_list!(builder, init_meta, write_meta_item, node.meta); let _ = builder.set_params(model::TermId::unwrap_slice(node.params)); let _ = builder.set_regions(model::RegionId::unwrap_slice(node.regions)); @@ -47,11 +47,9 @@ fn write_operation(mut builder: hugr_capnp::operation::Builder, operation: &mode model::Operation::TailLoop => builder.set_tail_loop(()), model::Operation::Conditional => builder.set_conditional(()), model::Operation::Tag { tag } => builder.set_tag(*tag), - model::Operation::Custom { operation } => { - write_global_ref(builder.init_custom(), operation) - } + model::Operation::Custom { operation } => builder.set_custom(operation.0), model::Operation::CustomFull { operation } => { - write_global_ref(builder.init_custom_full(), operation) + builder.set_custom_full(operation.0); } model::Operation::CallFunc { func } => builder.set_call_func(func.0), model::Operation::LoadFunc { func } => builder.set_load_func(func.0), @@ -100,6 +98,10 @@ fn write_operation(mut builder: hugr_capnp::operation::Builder, operation: &mode builder.set_type(decl.r#type.0); } + model::Operation::Import { name } => { + builder.set_import(*name); + } + model::Operation::Invalid => builder.set_invalid(()), } } @@ -113,31 +115,6 @@ fn write_param(mut builder: hugr_capnp::param::Builder, param: &model::Param) { }); } -fn write_global_ref(mut builder: hugr_capnp::global_ref::Builder, global_ref: &model::GlobalRef) { - match global_ref { - model::GlobalRef::Direct(node) => builder.set_node(node.0), - model::GlobalRef::Named(name) => builder.set_named(name), - } -} - -fn write_link_ref(mut builder: hugr_capnp::link_ref::Builder, link_ref: &model::LinkRef) { - match link_ref { - model::LinkRef::Id(id) => builder.set_id(id.0), - model::LinkRef::Named(name) => builder.set_named(name), - } -} - -fn write_local_ref(mut builder: hugr_capnp::local_ref::Builder, local_ref: &model::LocalRef) { - match local_ref { - model::LocalRef::Index(node, index) => { - let mut builder = builder.init_direct(); - builder.set_node(node.0); - builder.set_index(*index); - } - model::LocalRef::Named(name) => builder.set_named(name), - } -} - fn write_meta_item(mut builder: hugr_capnp::meta_item::Builder, meta_item: &model::MetaItem) { builder.set_name(meta_item.name); builder.set_value(meta_item.value.0) @@ -150,11 +127,20 @@ fn write_region(mut builder: hugr_capnp::region::Builder, region: &model::Region model::RegionKind::Module => hugr_capnp::RegionKind::Module, }); - write_list!(builder, init_sources, write_link_ref, region.sources); - write_list!(builder, init_targets, write_link_ref, region.targets); + let _ = builder.set_sources(model::LinkIndex::unwrap_slice(region.sources)); + let _ = builder.set_targets(model::LinkIndex::unwrap_slice(region.targets)); let _ = builder.set_children(model::NodeId::unwrap_slice(region.children)); write_list!(builder, init_meta, write_meta_item, region.meta); builder.set_signature(region.signature.map_or(0, |t| t.0 + 1)); + + if let Some(scope) = ®ion.scope { + write_region_scope(builder.init_scope(), scope); + } +} + +fn write_region_scope(mut builder: hugr_capnp::region_scope::Builder, scope: &model::RegionScope) { + builder.set_links(scope.links); + builder.set_ports(scope.ports); } fn write_term(mut builder: hugr_capnp::term::Builder, term: &model::Term) { @@ -163,7 +149,11 @@ fn write_term(mut builder: hugr_capnp::term::Builder, term: &model::Term) { model::Term::Type => builder.set_runtime_type(()), model::Term::StaticType => builder.set_static_type(()), model::Term::Constraint => builder.set_constraint(()), - model::Term::Var(local_ref) => write_local_ref(builder.init_variable(), local_ref), + model::Term::Var(model::VarId(node, index)) => { + let mut builder = builder.init_variable(); + builder.set_variable_node(node.0); + builder.set_variable_index(*index); + } model::Term::ListType { item_type } => builder.set_list_type(item_type.0), model::Term::Str(value) => builder.set_string(value), model::Term::StrType => builder.set_string_type(()), @@ -175,15 +165,15 @@ fn write_term(mut builder: hugr_capnp::term::Builder, term: &model::Term) { model::Term::Control { values } => builder.set_control(values.0), model::Term::ControlType => builder.set_control_type(()), - model::Term::Apply { global, args } => { + model::Term::Apply { symbol, args } => { let mut builder = builder.init_apply(); - write_global_ref(builder.reborrow().init_global(), global); + builder.set_symbol(symbol.0); let _ = builder.set_args(model::TermId::unwrap_slice(args)); } - model::Term::ApplyFull { global, args } => { + model::Term::ApplyFull { symbol, args } => { let mut builder = builder.init_apply_full(); - write_global_ref(builder.reborrow().init_global(), global); + builder.set_symbol(symbol.0); let _ = builder.set_args(model::TermId::unwrap_slice(args)); } diff --git a/hugr-model/src/v0/mod.rs b/hugr-model/src/v0/mod.rs index 2b0dc1eaf..c13d8a589 100644 --- a/hugr-model/src/v0/mod.rs +++ b/hugr-model/src/v0/mod.rs @@ -91,6 +91,7 @@ use smol_str::SmolStr; use thiserror::Error; pub mod binary; +pub mod scope; pub mod text; macro_rules! define_index { @@ -132,7 +133,7 @@ macro_rules! define_index { } define_index! { - /// Index of a node in a hugr graph. + /// Id of a node in a hugr graph. #[derive(Debug, derive_more::Display, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default)] pub struct NodeId(pub u32); } @@ -140,21 +141,31 @@ define_index! { define_index! { /// Index of a link in a hugr graph. #[derive(Debug, derive_more::Display, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default)] - pub struct LinkId(pub u32); + pub struct LinkIndex(pub u32); } define_index! { - /// Index of a region in a hugr graph. + /// Id of a region in a hugr graph. #[derive(Debug, derive_more::Display, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default)] pub struct RegionId(pub u32); } define_index! { - /// Index of a term in a hugr graph. + /// Id of a term in a hugr graph. #[derive(Debug, derive_more::Display, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default)] pub struct TermId(pub u32); } +/// The id of a link consisting of its region and the link index. +#[derive(Debug, derive_more::Display, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[display("{_0}#{_1}")] +pub struct LinkId(pub RegionId, pub LinkIndex); + +/// The id of a variable consisting of its node and the variable index. +#[derive(Debug, derive_more::Display, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[display("{_0}#{_1}")] +pub struct VarId(pub NodeId, pub VarIndex); + /// A module consisting of a hugr graph together with terms. #[derive(Debug, Clone, Default, PartialEq, Eq, Hash)] pub struct Module<'a> { @@ -233,9 +244,9 @@ pub struct Node<'a> { /// The operation that the node performs. pub operation: Operation<'a>, /// The input ports of the node. - pub inputs: &'a [LinkRef<'a>], + pub inputs: &'a [LinkIndex], /// The output ports of the node. - pub outputs: &'a [LinkRef<'a>], + pub outputs: &'a [LinkIndex], /// The parameters of the node. pub params: &'a [TermId], /// The regions of the node. @@ -290,8 +301,8 @@ pub enum Operation<'a> { /// becomes known by resolving the reference, the node can be transformed into a [`Operation::CustomFull`] /// by inferring terms for the implicit parameters or at least filling them in with a wildcard term. Custom { - /// The name of the custom operation. - operation: GlobalRef<'a>, + /// The symbol of the custom operation. + operation: NodeId, }, /// Custom operation with full parameters. /// @@ -299,8 +310,8 @@ pub enum Operation<'a> { /// Since this can be tedious to write, the [`Operation::Custom`] variant can be used to indicate that /// the implicit parameters should be inferred. CustomFull { - /// The name of the custom operation. - operation: GlobalRef<'a>, + /// The symbol of the custom operation. + operation: NodeId, }, /// Alias definitions. DefineAlias { @@ -358,17 +369,39 @@ pub enum Operation<'a> { /// The declaration of the operation. decl: &'a OperationDecl<'a>, }, + + /// Import a symbol. + Import { + /// The name of the symbol to be imported. + name: &'a str, + }, +} + +impl<'a> Operation<'a> { + /// Returns the symbol introduced by the operation, if any. + pub fn symbol(&self) -> Option<&'a str> { + match self { + Operation::DefineFunc { decl } => Some(decl.name), + Operation::DeclareFunc { decl } => Some(decl.name), + Operation::DefineAlias { decl, .. } => Some(decl.name), + Operation::DeclareAlias { decl } => Some(decl.name), + Operation::DeclareConstructor { decl } => Some(decl.name), + Operation::DeclareOperation { decl } => Some(decl.name), + Operation::Import { name } => Some(name), + _ => None, + } + } } /// A region in the hugr. -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, Hash, Default)] pub struct Region<'a> { /// The kind of the region. See [`RegionKind`] for details. pub kind: RegionKind, /// The source ports of the region. - pub sources: &'a [LinkRef<'a>], + pub sources: &'a [LinkIndex], /// The target ports of the region. - pub targets: &'a [LinkRef<'a>], + pub targets: &'a [LinkIndex], /// The nodes in the region. The order of the nodes is not significant. pub children: &'a [NodeId], /// The metadata attached to the region. @@ -377,12 +410,34 @@ pub struct Region<'a> { /// /// Can be `None` to indicate that the region signature should be inferred. pub signature: Option, + /// Information about the scope defined by this region, if the region is closed. + pub scope: Option, +} + +/// Information about the scope defined by a closed region. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct RegionScope { + /// The number of links in the scope. + pub links: u32, + /// The number of ports in the scope. + pub ports: u32, +} + +/// The link scope of a region. +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default)] +pub enum LinkScope { + /// The region is open and shares its scope of links with its parent region. + #[default] + Open, + /// The region is closed and has its own scope with the given number of links. + Closed(u32), } /// The kind of a region. -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default)] pub enum RegionKind { /// Data flow region. + #[default] DataFlow = 0, /// Control flow region. ControlFlow = 1, @@ -449,63 +504,8 @@ pub struct MetaItem<'a> { pub value: TermId, } -/// A reference to a global variable. -/// -/// Global variables are defined in nodes. -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub enum GlobalRef<'a> { - /// Reference to the global that is defined by the given node. - Direct(NodeId), - /// Reference to the global with the given name. - Named(&'a str), -} - -impl std::fmt::Display for GlobalRef<'_> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - GlobalRef::Direct(id) => write!(f, ":{}", id.index()), - GlobalRef::Named(name) => write!(f, "{}", name), - } - } -} - -/// A reference to a local variable. -/// -/// Local variables are defined as parameters to nodes. -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub enum LocalRef<'a> { - /// Reference to the local variable by its parameter index and its defining node. - Index(NodeId, u16), - /// Reference to the local variable by its name. - Named(&'a str), -} - -impl std::fmt::Display for LocalRef<'_> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - LocalRef::Index(node, index) => write!(f, "?:{}:{}", node.index(), index), - LocalRef::Named(name) => write!(f, "?{}", name), - } - } -} - -/// A reference to a link. -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub enum LinkRef<'a> { - /// Reference to the link by its id. - Id(LinkId), - /// Reference to the link by its name. - Named(&'a str), -} - -impl std::fmt::Display for LinkRef<'_> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - LinkRef::Id(id) => write!(f, "%:{})", id.index()), - LinkRef::Named(name) => write!(f, "%{}", name), - } - } -} +/// An index of a variable within a node's parameter list. +pub type VarIndex = u16; /// A term in the compile time meta language. #[derive(Debug, Clone, PartialEq, Eq, Hash, Default)] @@ -530,7 +530,7 @@ pub enum Term<'a> { Constraint, /// A local variable. - Var(LocalRef<'a>), + Var(VarId), /// A symbolic function application. /// @@ -540,8 +540,8 @@ pub enum Term<'a> { /// /// `(GLOBAL ARG-0 ... ARG-n)` Apply { - /// Reference to the global declaration to apply. - global: GlobalRef<'a>, + /// Reference to the symbol to apply. + symbol: NodeId, /// Arguments to the function, covering only the explicit parameters. args: &'a [TermId], }, @@ -553,8 +553,8 @@ pub enum Term<'a> { /// /// `(@GLOBAL ARG-0 ... ARG-n)` ApplyFull { - /// Reference to the global declaration to apply. - global: GlobalRef<'a>, + /// Reference to the symbol to apply. + symbol: NodeId, /// Arguments to the function, covering both implicit and explicit parameters. args: &'a [TermId], }, @@ -718,13 +718,12 @@ pub enum ModelError { /// There is a reference to a region that does not exist. #[error("region not found: {0}")] RegionNotFound(RegionId), - /// There is a local reference that does not resolve. - #[error("local variable invalid: {0}")] - InvalidLocal(String), - /// There is a global reference that does not resolve to a node - /// that defines a global variable. - #[error("global variable invalid: {0}")] - InvalidGlobal(String), + /// Invalid variable reference. + #[error("variable {0} invalid")] + InvalidVar(VarId), + /// Invalid symbol reference. + #[error("symbol reference {0} invalid")] + InvalidSymbol(NodeId), /// The model contains an operation in a place where it is not allowed. #[error("unexpected operation on node: {0}")] UnexpectedOperation(NodeId), diff --git a/hugr-model/src/v0/scope/link.rs b/hugr-model/src/v0/scope/link.rs new file mode 100644 index 000000000..b93ef6ac3 --- /dev/null +++ b/hugr-model/src/v0/scope/link.rs @@ -0,0 +1,113 @@ +use std::hash::{BuildHasherDefault, Hash}; + +use fxhash::FxHasher; +use indexmap::IndexSet; + +use crate::v0::{LinkIndex, RegionId}; + +type FxIndexSet = IndexSet>; + +/// Table for tracking links between ports. +/// +/// Two ports are connected when they share the same link. Links are named and +/// scoped via isolated regions. The names of links must be unique within a +/// single isolated region. Links from one isolated region are not visible in +/// another. Links do not have a unique point of declaration. +/// +/// # Examples +/// +/// ``` +/// # pub use hugr_model::v0::RegionId; +/// # pub use hugr_model::v0::scope::LinkTable; +/// let mut links = LinkTable::new(); +/// links.enter(RegionId(0)); +/// let foo_0 = links.use_link("foo"); +/// let bar_0 = links.use_link("bar"); +/// assert_eq!(foo_0, links.use_link("foo")); +/// assert_eq!(bar_0, links.use_link("bar")); +/// let (num_links, num_ports) = links.exit(); +/// assert_eq!(num_links, 2); +/// assert_eq!(num_ports, 4); +/// ``` +#[derive(Debug, Clone)] +pub struct LinkTable { + links: FxIndexSet<(RegionId, K)>, + scopes: Vec, +} + +impl LinkTable +where + K: Copy + Eq + Hash, +{ + /// Create a new empty link table. + pub fn new() -> Self { + Self { + links: FxIndexSet::default(), + scopes: Vec::new(), + } + } + + /// Enter a new scope for the given region. + pub fn enter(&mut self, region: RegionId) { + self.scopes.push(LinkScope { + link_stack: self.links.len(), + link_count: 0, + port_count: 0, + region, + }); + } + + /// Exit a previously entered scope, returning the number of links and ports in the scope. + pub fn exit(&mut self) -> (u32, u32) { + let scope = self.scopes.pop().unwrap(); + self.links.drain(scope.link_stack..); + debug_assert_eq!(self.links.len(), scope.link_stack); + (scope.link_count, scope.port_count) + } + + /// Resolve a link key to a link index, adding one more port to the current scope. + /// + /// If the key has not been used in the current scope before, it will be added to the link table. + /// + /// # Panics + /// + /// Panics if there are no open scopes. + pub fn use_link(&mut self, key: K) -> LinkIndex { + let scope = self.scopes.last_mut().unwrap(); + let (map_index, inserted) = self.links.insert_full((scope.region, key)); + + if inserted { + scope.link_count += 1; + } + + scope.port_count += 1; + LinkIndex::new(map_index - scope.link_stack) + } + + /// Reset the link table to an empty state while preserving allocated memory. + pub fn clear(&mut self) { + self.links.clear(); + self.scopes.clear(); + } +} + +impl Default for LinkTable +where + K: Copy + Eq + Hash, +{ + fn default() -> Self { + Self::new() + } +} + +#[derive(Debug, Clone)] +struct LinkScope { + /// The length of `LinkTable::links` when the scope was opened. + link_stack: usize, + /// The number of links in this scope. + link_count: u32, + /// The number of ports in this scope. + port_count: u32, + /// The region that introduces this scope. + region: RegionId, +} diff --git a/hugr-model/src/v0/scope/mod.rs b/hugr-model/src/v0/scope/mod.rs new file mode 100644 index 000000000..546d97d61 --- /dev/null +++ b/hugr-model/src/v0/scope/mod.rs @@ -0,0 +1,8 @@ +//! Utilities for working with scoped symbols, variables and links. +mod link; +mod symbol; +mod vars; + +pub use link::LinkTable; +pub use symbol::{DuplicateSymbolError, SymbolTable, UnknownSymbolError}; +pub use vars::{DuplicateVarError, UnknownVarError, VarTable}; diff --git a/hugr-model/src/v0/scope/symbol.rs b/hugr-model/src/v0/scope/symbol.rs new file mode 100644 index 000000000..5c54d930a --- /dev/null +++ b/hugr-model/src/v0/scope/symbol.rs @@ -0,0 +1,208 @@ +use std::{borrow::Cow, hash::BuildHasherDefault}; + +use fxhash::FxHasher; +use indexmap::IndexMap; +use thiserror::Error; + +use crate::v0::{NodeId, RegionId}; + +type FxIndexMap = IndexMap>; + +/// Symbol binding table that keeps track of symbol resolution and scoping. +/// +/// Nodes may introduce a symbol so that other parts of the IR can refer to the +/// node. Symbols have an associated name and are scoped via regions. A symbol +/// can shadow another symbol with the same name from an outer region, but +/// within any single region each symbol name must be unique. +/// +/// When a symbol is referred to directly by the id of the node, the symbol must +/// be in scope at the point of reference as if the reference was by name. This +/// guarantees that transformations between directly indexed and named formats +/// are always valid. +/// +/// # Examples +/// +/// ``` +/// # pub use hugr_model::v0::{NodeId, RegionId}; +/// # pub use hugr_model::v0::scope::SymbolTable; +/// let mut symbols = SymbolTable::new(); +/// symbols.enter(RegionId(0)); +/// symbols.insert("foo", NodeId(0)).unwrap(); +/// assert_eq!(symbols.resolve("foo").unwrap(), NodeId(0)); +/// symbols.enter(RegionId(1)); +/// assert_eq!(symbols.resolve("foo").unwrap(), NodeId(0)); +/// symbols.insert("foo", NodeId(1)).unwrap(); +/// assert_eq!(symbols.resolve("foo").unwrap(), NodeId(1)); +/// assert!(!symbols.is_visible(NodeId(0))); +/// symbols.exit(); +/// assert_eq!(symbols.resolve("foo").unwrap(), NodeId(0)); +/// assert!(symbols.is_visible(NodeId(0))); +/// assert!(!symbols.is_visible(NodeId(1))); +/// ``` +#[derive(Debug, Clone)] +pub struct SymbolTable<'a> { + symbols: FxIndexMap<&'a str, BindingIndex>, + bindings: FxIndexMap, + scopes: FxIndexMap, +} + +impl<'a> SymbolTable<'a> { + /// Create a new symbol table. + pub fn new() -> Self { + Self { + symbols: FxIndexMap::default(), + bindings: FxIndexMap::default(), + scopes: FxIndexMap::default(), + } + } + + /// Enter a new scope for the given region. + pub fn enter(&mut self, region: RegionId) { + self.scopes.insert( + region, + Scope { + binding_stack: self.bindings.len(), + }, + ); + } + + /// Exit a previously entered scope. + /// + /// # Panics + /// + /// Panics if there are no remaining open scopes. + pub fn exit(&mut self) { + let (_, scope) = self.scopes.pop().unwrap(); + + for _ in scope.binding_stack..self.bindings.len() { + let (_, binding) = self.bindings.pop().unwrap(); + + if let Some(shadows) = binding.shadows { + self.symbols[binding.symbol_index] = shadows; + } else { + let last = self.symbols.pop(); + debug_assert_eq!(last.unwrap().1, self.bindings.len()); + } + } + } + + /// Insert a new symbol into the current scope. + /// + /// # Errors + /// + /// Returns an error if the symbol is already defined in the current scope. + /// In the case of an error the table remains unchanged. + /// + /// # Panics + /// + /// Panics if there is no current scope. + pub fn insert(&mut self, name: &'a str, node: NodeId) -> Result<(), DuplicateSymbolError> { + let scope_depth = self.scopes.len() as u16 - 1; + let (symbol_index, shadowed) = self.symbols.insert_full(name, self.bindings.len()); + + if let Some(shadowed) = shadowed { + let (shadowed_node, shadowed_binding) = self.bindings.get_index(shadowed).unwrap(); + if shadowed_binding.scope_depth == scope_depth { + self.symbols.insert(name, shadowed); + return Err(DuplicateSymbolError(name.into(), node, *shadowed_node)); + } + } + + self.bindings.insert( + node, + Binding { + scope_depth, + shadows: shadowed, + symbol_index, + }, + ); + + Ok(()) + } + + /// Check whether a symbol is currently visible in the current scope. + pub fn is_visible(&self, node: NodeId) -> bool { + let Some(binding) = self.bindings.get(&node) else { + return false; + }; + + // Check that the symbol has not been shadowed at this point. + self.symbols[binding.symbol_index] == binding.symbol_index + } + + /// Tries to resolve a symbol name in the current scope. + pub fn resolve(&self, name: &'a str) -> Result { + let index = *self + .symbols + .get(name) + .ok_or(UnknownSymbolError(name.into()))?; + + // NOTE: The unwrap is safe because the `symbols` map + // points to valid indices in the `bindings` map. + let (node, _) = self.bindings.get_index(index).unwrap(); + Ok(*node) + } + + /// Returns the depth of the given region, if it corresponds to a currently open scope. + pub fn region_to_depth(&self, region: RegionId) -> Option { + Some(self.scopes.get_index_of(®ion)? as _) + } + + /// Returns the region corresponding to the scope at the given depth. + pub fn depth_to_region(&self, depth: ScopeDepth) -> Option { + let (region, _) = self.scopes.get_index(depth as _)?; + Some(*region) + } + + /// Resets the symbol table to its initial state while maintaining its + /// allocated memory. + pub fn clear(&mut self) { + self.symbols.clear(); + self.bindings.clear(); + self.scopes.clear(); + } +} + +impl Default for SymbolTable<'_> { + fn default() -> Self { + Self::new() + } +} + +#[derive(Debug, Clone, Copy)] +struct Binding { + /// The depth of the scope in which this binding is defined. + scope_depth: ScopeDepth, + + /// The index of the binding that is shadowed by this one, if any. + shadows: Option, + + /// The index of this binding's symbol in the symbol table. + /// + /// The symbol table always points to the currently visible binding for a + /// symbol. Therefore this index is only valid if this binding is not shadowed. + /// In particular, we detect shadowing by checking if the entry in the symbol + /// table at this index does indeed point to this binding. + symbol_index: SymbolIndex, +} + +#[derive(Debug, Clone, Copy)] +struct Scope { + /// The length of the `bindings` stack when this scope was entered. + binding_stack: usize, +} + +type BindingIndex = usize; +type SymbolIndex = usize; + +pub type ScopeDepth = u16; + +/// Error that occurs when trying to resolve an unknown symbol. +#[derive(Debug, Clone, Error)] +#[error("symbol name `{0}` not found in this scope")] +pub struct UnknownSymbolError<'a>(pub Cow<'a, str>); + +/// Error that occurs when trying to introduce a symbol that is already defined in the current scope. +#[derive(Debug, Clone, Error)] +#[error("symbol `{0}` is already defined in this scope")] +pub struct DuplicateSymbolError<'a>(pub Cow<'a, str>, pub NodeId, pub NodeId); diff --git a/hugr-model/src/v0/scope/vars.rs b/hugr-model/src/v0/scope/vars.rs new file mode 100644 index 000000000..b649bc491 --- /dev/null +++ b/hugr-model/src/v0/scope/vars.rs @@ -0,0 +1,155 @@ +use fxhash::FxHasher; +use indexmap::IndexSet; +use std::hash::BuildHasherDefault; +use thiserror::Error; + +use crate::v0::{NodeId, VarId}; + +type FxIndexSet = IndexSet>; + +/// Table for keeping track of node parameters. +/// +/// Variables refer to the parameters of a node which introduces a symbol. +/// Variables have an associated name and are scoped via nodes. The types of +/// parameters of a node may only refer to earlier parameters in the same node +/// in the order they are defined. A variable name must be unique within a +/// single node. Each node that introduces a symbol introduces a new isolated +/// scope for variables. +/// +/// # Examples +/// +/// ``` +/// # pub use hugr_model::v0::{NodeId, VarId}; +/// # pub use hugr_model::v0::scope::VarTable; +/// let mut vars = VarTable::new(); +/// vars.enter(NodeId(0)); +/// vars.insert("foo").unwrap(); +/// assert_eq!(vars.resolve("foo").unwrap(), VarId(NodeId(0), 0)); +/// assert!(!vars.is_visible(VarId(NodeId(0), 1))); +/// vars.insert("bar").unwrap(); +/// assert!(vars.is_visible(VarId(NodeId(0), 1))); +/// assert_eq!(vars.resolve("bar").unwrap(), VarId(NodeId(0), 1)); +/// vars.enter(NodeId(1)); +/// assert!(vars.resolve("foo").is_err()); +/// assert!(!vars.is_visible(VarId(NodeId(0), 0))); +/// vars.exit(); +/// assert_eq!(vars.resolve("foo").unwrap(), VarId(NodeId(0), 0)); +/// assert!(vars.is_visible(VarId(NodeId(0), 0))); +/// ``` +#[derive(Debug, Clone)] +pub struct VarTable<'a> { + vars: FxIndexSet<(NodeId, &'a str)>, + scopes: Vec, +} + +impl<'a> VarTable<'a> { + /// Create a new empty variable table. + pub fn new() -> Self { + Self { + vars: FxIndexSet::default(), + scopes: Vec::new(), + } + } + + /// Enter a new scope for the given node. + pub fn enter(&mut self, node: NodeId) { + self.scopes.push(VarScope { + node, + var_count: 0, + var_stack: self.vars.len(), + }) + } + + /// Exit a previously entered scope. + /// + /// # Panics + /// + /// Panics if there are no open scopes. + pub fn exit(&mut self) { + let scope = self.scopes.pop().unwrap(); + self.vars.drain(scope.var_stack..); + } + + /// Resolve a variable name to a node and variable index. + /// + /// # Errors + /// + /// Returns an error if the variable is not defined in the current scope. + /// + /// # Panics + /// + /// Panics if there are no open scopes. + pub fn resolve(&self, name: &'a str) -> Result> { + let scope = self.scopes.last().unwrap(); + let (set_index, _) = self + .vars + .get_full(&(scope.node, name)) + .ok_or(UnknownVarError(scope.node, name))?; + let var_index = (set_index - scope.var_stack) as u16; + Ok(VarId(scope.node, var_index)) + } + + /// Check if a variable is visible in the current scope. + /// + /// # Panics + /// + /// Panics if there are no open scopes. + pub fn is_visible(&self, var: VarId) -> bool { + let scope = self.scopes.last().unwrap(); + scope.node == var.0 && var.1 < scope.var_count + } + + /// Insert a new variable into the current scope. + /// + /// # Errors + /// + /// Returns an error if the variable is already defined in the current scope. + /// + /// # Panics + /// + /// Panics if there are no open scopes. + pub fn insert(&mut self, name: &'a str) -> Result> { + let scope = self.scopes.last_mut().unwrap(); + let inserted = self.vars.insert((scope.node, name)); + + if !inserted { + return Err(DuplicateVarError(scope.node, name)); + } + + let var_index = scope.var_count; + scope.var_count += 1; + Ok(VarId(scope.node, var_index)) + } + + /// Reset the variable table to an empty state while preserving the allocations. + pub fn clear(&mut self) { + self.vars.clear(); + self.scopes.clear(); + } +} + +impl Default for VarTable<'_> { + fn default() -> Self { + Self::new() + } +} + +#[derive(Debug, Clone)] +struct VarScope { + /// The node that introduces this scope. + node: NodeId, + /// The number of variables in this scope. + var_count: u16, + /// The length of `VarTable::vars` when the scope was opened. + var_stack: usize, +} + +/// Error that occurs when a node defines two parameters with the same name. +#[derive(Debug, Clone, Error)] +#[error("node {0} already has a variable named `{1}`")] +pub struct DuplicateVarError<'a>(NodeId, &'a str); + +/// Error that occurs when a variable is not defined in the current scope. +#[derive(Debug, Clone, Error)] +#[error("can not resolve variable `{1}` in node {0}")] +pub struct UnknownVarError<'a>(NodeId, &'a str); diff --git a/hugr-model/src/v0/text/hugr.pest b/hugr-model/src/v0/text/hugr.pest index fc52b8271..4fd34f223 100644 --- a/hugr-model/src/v0/text/hugr.pest +++ b/hugr-model/src/v0/text/hugr.pest @@ -34,6 +34,7 @@ node = { | node_tail_loop | node_cond | node_tag + | node_import | node_custom } @@ -51,6 +52,7 @@ node_declare_operation = { "(" ~ "declare-operation" ~ operation_header ~ meta* node_tail_loop = { "(" ~ "tail-loop" ~ port_lists? ~ signature? ~ meta* ~ region* ~ ")" } node_cond = { "(" ~ "cond" ~ port_lists? ~ signature? ~ meta* ~ region* ~ ")" } node_tag = { "(" ~ "tag" ~ tag ~ port_lists? ~ signature? ~ meta* ~ region* ~ ")" } +node_import = { "(" ~ "import" ~ symbol ~ meta* ~ ")" } node_custom = { "(" ~ (term_apply | term_apply_full) ~ port_lists? ~ signature? ~ meta* ~ region* ~ ")" } signature = { "(" ~ "signature" ~ term ~ ")" } diff --git a/hugr-model/src/v0/text/parse.rs b/hugr-model/src/v0/text/parse.rs index 8527f1a00..de928affd 100644 --- a/hugr-model/src/v0/text/parse.rs +++ b/hugr-model/src/v0/text/parse.rs @@ -1,4 +1,5 @@ use bumpalo::{collections::String as BumpString, collections::Vec as BumpVec, Bump}; +use fxhash::FxHashMap; use pest::{ iterators::{Pair, Pairs}, Parser, RuleType, @@ -6,9 +7,10 @@ use pest::{ use thiserror::Error; use crate::v0::{ - AliasDecl, ConstructorDecl, ExtSetPart, FuncDecl, GlobalRef, LinkRef, ListPart, LocalRef, - MetaItem, Module, Node, NodeId, Operation, OperationDecl, Param, ParamSort, Region, RegionId, - RegionKind, Term, TermId, + scope::{LinkTable, SymbolTable, UnknownSymbolError, VarTable}, + AliasDecl, ConstructorDecl, ExtSetPart, FuncDecl, LinkIndex, ListPart, MetaItem, Module, Node, + NodeId, Operation, OperationDecl, Param, ParamSort, Region, RegionId, RegionKind, RegionScope, + Term, TermId, }; mod pest_parser { @@ -50,12 +52,20 @@ pub fn parse<'a>(input: &'a str, bump: &'a Bump) -> Result, Par struct ParseContext<'a> { module: Module<'a>, bump: &'a Bump, + vars: VarTable<'a>, + links: LinkTable<&'a str>, + symbols: SymbolTable<'a>, + implicit_imports: FxHashMap<&'a str, NodeId>, } impl<'a> ParseContext<'a> { fn new(bump: &'a Bump) -> Self { Self { module: Module::default(), + symbols: SymbolTable::default(), + links: LinkTable::default(), + vars: VarTable::default(), + implicit_imports: FxHashMap::default(), bump, } } @@ -63,20 +73,38 @@ impl<'a> ParseContext<'a> { fn parse_module(&mut self, pair: Pair<'a, Rule>) -> ParseResult<()> { debug_assert_eq!(pair.as_rule(), Rule::module); let mut inner = pair.into_inner(); + + self.module.root = self.module.insert_region(Region::default()); + self.symbols.enter(self.module.root); + self.links.enter(self.module.root); + + // TODO: What scope does the metadata live in? let meta = self.parse_meta(&mut inner)?; + let explicit_children = self.parse_nodes(&mut inner)?; - let children = self.parse_nodes(&mut inner)?; + let mut children = BumpVec::with_capacity_in( + explicit_children.len() + self.implicit_imports.len(), + self.bump, + ); + children.extend(explicit_children); + children.extend(self.implicit_imports.drain().map(|(_, node)| node)); + let children = children.into_bump_slice(); + + let (link_count, port_count) = self.links.exit(); + self.symbols.exit(); - let root_region = self.module.insert_region(Region { + self.module.regions[self.module.root.index()] = Region { kind: RegionKind::Module, sources: &[], targets: &[], children, meta, signature: None, - }); - - self.module.root = root_region; + scope: Some(RegionScope { + links: link_count, + ports: port_count, + }), + }; Ok(()) } @@ -87,143 +115,195 @@ impl<'a> ParseContext<'a> { let rule = pair.as_rule(); let mut inner = pair.into_inner(); - let term = match rule { - Rule::term_wildcard => Term::Wildcard, - Rule::term_type => Term::Type, - Rule::term_static => Term::StaticType, - Rule::term_constraint => Term::Constraint, - Rule::term_str_type => Term::StrType, - Rule::term_nat_type => Term::NatType, - Rule::term_ctrl_type => Term::ControlType, - Rule::term_ext_set_type => Term::ExtSetType, - - Rule::term_var => { - let name_token = inner.next().unwrap(); - let name = name_token.as_str(); - Term::Var(LocalRef::Named(name)) - } - - Rule::term_apply => { - let name = GlobalRef::Named(self.parse_symbol(&mut inner)?); - let mut args = Vec::new(); - - for token in inner { - args.push(self.parse_term(token)?); - } + let term = + match rule { + Rule::term_wildcard => Term::Wildcard, + Rule::term_type => Term::Type, + Rule::term_static => Term::StaticType, + Rule::term_constraint => Term::Constraint, + Rule::term_str_type => Term::StrType, + Rule::term_nat_type => Term::NatType, + Rule::term_ctrl_type => Term::ControlType, + Rule::term_ext_set_type => Term::ExtSetType, + + Rule::term_var => { + let name_token = inner.next().unwrap(); + let name = name_token.as_str(); + + let var = self.vars.resolve(name).map_err(|err| { + ParseError::custom(&err.to_string(), name_token.as_span()) + })?; - Term::Apply { - global: name, - args: self.bump.alloc_slice_copy(&args), + Term::Var(var) } - } - Rule::term_apply_full => { - let name = GlobalRef::Named(self.parse_symbol(&mut inner)?); - let mut args = Vec::new(); + Rule::term_apply => { + let symbol = self.parse_symbol_use(&mut inner)?; + let mut args = Vec::new(); - for token in inner { - args.push(self.parse_term(token)?); - } + for token in inner { + args.push(self.parse_term(token)?); + } - Term::ApplyFull { - global: name, - args: self.bump.alloc_slice_copy(&args), + Term::Apply { + symbol, + args: self.bump.alloc_slice_copy(&args), + } } - } - Rule::term_quote => { - let r#type = self.parse_term(inner.next().unwrap())?; - Term::Quote { r#type } - } + Rule::term_apply_full => { + let symbol = self.parse_symbol_use(&mut inner)?; + let mut args = Vec::new(); - Rule::term_list => { - let mut parts = BumpVec::with_capacity_in(inner.len(), self.bump); + for token in inner { + args.push(self.parse_term(token)?); + } - for token in inner { - match token.as_rule() { - Rule::term => parts.push(ListPart::Item(self.parse_term(token)?)), - Rule::spliced_term => { - let term_token = token.into_inner().next().unwrap(); - parts.push(ListPart::Splice(self.parse_term(term_token)?)) - } - _ => unreachable!(), + Term::ApplyFull { + symbol, + args: self.bump.alloc_slice_copy(&args), } } - Term::List { - parts: parts.into_bump_slice(), + Rule::term_quote => { + let r#type = self.parse_term(inner.next().unwrap())?; + Term::Quote { r#type } } - } - Rule::term_list_type => { - let item_type = self.parse_term(inner.next().unwrap())?; - Term::ListType { item_type } - } + Rule::term_list => { + let mut parts = BumpVec::with_capacity_in(inner.len(), self.bump); + + for token in inner { + match token.as_rule() { + Rule::term => parts.push(ListPart::Item(self.parse_term(token)?)), + Rule::spliced_term => { + let term_token = token.into_inner().next().unwrap(); + parts.push(ListPart::Splice(self.parse_term(term_token)?)) + } + _ => unreachable!(), + } + } - Rule::term_str => { - let value = self.parse_string(inner.next().unwrap())?; - Term::Str(value) - } + Term::List { + parts: parts.into_bump_slice(), + } + } - Rule::term_nat => { - let value = inner.next().unwrap().as_str().parse().unwrap(); - Term::Nat(value) - } + Rule::term_list_type => { + let item_type = self.parse_term(inner.next().unwrap())?; + Term::ListType { item_type } + } + + Rule::term_str => { + let value = self.parse_string(inner.next().unwrap())?; + Term::Str(value) + } - Rule::term_ext_set => { - let mut parts = BumpVec::with_capacity_in(inner.len(), self.bump); + Rule::term_nat => { + let value = inner.next().unwrap().as_str().parse().unwrap(); + Term::Nat(value) + } - for token in inner { - match token.as_rule() { - Rule::ext_name => { - parts.push(ExtSetPart::Extension(self.bump.alloc_str(token.as_str()))) - } - Rule::spliced_term => { - let term_token = token.into_inner().next().unwrap(); - parts.push(ExtSetPart::Splice(self.parse_term(term_token)?)) + Rule::term_ext_set => { + let mut parts = BumpVec::with_capacity_in(inner.len(), self.bump); + + for token in inner { + match token.as_rule() { + Rule::ext_name => parts + .push(ExtSetPart::Extension(self.bump.alloc_str(token.as_str()))), + Rule::spliced_term => { + let term_token = token.into_inner().next().unwrap(); + parts.push(ExtSetPart::Splice(self.parse_term(term_token)?)) + } + _ => unreachable!(), } - _ => unreachable!(), + } + + Term::ExtSet { + parts: parts.into_bump_slice(), } } - Term::ExtSet { - parts: parts.into_bump_slice(), + Rule::term_adt => { + let variants = self.parse_term(inner.next().unwrap())?; + Term::Adt { variants } } - } - Rule::term_adt => { - let variants = self.parse_term(inner.next().unwrap())?; - Term::Adt { variants } - } + Rule::term_func_type => { + let inputs = self.parse_term(inner.next().unwrap())?; + let outputs = self.parse_term(inner.next().unwrap())?; + let extensions = self.parse_term(inner.next().unwrap())?; + Term::FuncType { + inputs, + outputs, + extensions, + } + } - Rule::term_func_type => { - let inputs = self.parse_term(inner.next().unwrap())?; - let outputs = self.parse_term(inner.next().unwrap())?; - let extensions = self.parse_term(inner.next().unwrap())?; - Term::FuncType { - inputs, - outputs, - extensions, + Rule::term_ctrl => { + let values = self.parse_term(inner.next().unwrap())?; + Term::Control { values } } - } - Rule::term_ctrl => { - let values = self.parse_term(inner.next().unwrap())?; - Term::Control { values } - } + Rule::term_non_linear => { + let term = self.parse_term(inner.next().unwrap())?; + Term::NonLinearConstraint { term } + } - Rule::term_non_linear => { - let term = self.parse_term(inner.next().unwrap())?; - Term::NonLinearConstraint { term } - } + r => unreachable!("term: {:?}", r), + }; + + Ok(self.module.insert_term(term)) + } + + fn parse_node_shallow(&mut self, pair: Pair<'a, Rule>) -> ParseResult { + debug_assert_eq!(pair.as_rule(), Rule::node); + let pair = pair.into_inner().next().unwrap(); + let span = pair.as_span(); + let rule = pair.as_rule(); + let mut inner = pair.into_inner(); - r => unreachable!("term: {:?}", r), + let symbol = match rule { + Rule::node_define_func => { + let mut func_header = inner.next().unwrap().into_inner(); + Some(self.parse_symbol(&mut func_header)?) + } + Rule::node_declare_func => { + let mut func_header = inner.next().unwrap().into_inner(); + Some(self.parse_symbol(&mut func_header)?) + } + Rule::node_define_alias => { + let mut alias_header = inner.next().unwrap().into_inner(); + Some(self.parse_symbol(&mut alias_header)?) + } + Rule::node_declare_alias => { + let mut alias_header = inner.next().unwrap().into_inner(); + Some(self.parse_symbol(&mut alias_header)?) + } + Rule::node_declare_ctr => { + let mut ctr_header = inner.next().unwrap().into_inner(); + Some(self.parse_symbol(&mut ctr_header)?) + } + Rule::node_declare_operation => { + let mut op_header = inner.next().unwrap().into_inner(); + Some(self.parse_symbol(&mut op_header)?) + } + Rule::node_import => Some(self.parse_symbol(&mut inner)?), + _ => None, }; - Ok(self.module.insert_term(term)) + let node = self.module.insert_node(Node::default()); + + if let Some(symbol) = symbol { + self.symbols + .insert(symbol, node) + .map_err(|err| ParseError::custom(&err.to_string(), span))?; + } + + Ok(node) } - fn parse_node(&mut self, pair: Pair<'a, Rule>) -> ParseResult { + fn parse_node_deep(&mut self, pair: Pair<'a, Rule>, node: NodeId) -> ParseResult> { debug_assert_eq!(pair.as_rule(), Rule::node); let pair = pair.into_inner().next().unwrap(); let rule = pair.as_rule(); @@ -236,7 +316,7 @@ impl<'a> ParseContext<'a> { let outputs = self.parse_port_list(&mut inner)?; let signature = self.parse_signature(&mut inner)?; let meta = self.parse_meta(&mut inner)?; - let regions = self.parse_regions(&mut inner)?; + let regions = self.parse_regions(&mut inner, false)?; Node { operation: Operation::Dfg, inputs, @@ -253,7 +333,7 @@ impl<'a> ParseContext<'a> { let outputs = self.parse_port_list(&mut inner)?; let signature = self.parse_signature(&mut inner)?; let meta = self.parse_meta(&mut inner)?; - let regions = self.parse_regions(&mut inner)?; + let regions = self.parse_regions(&mut inner, false)?; Node { operation: Operation::Cfg, inputs, @@ -270,7 +350,7 @@ impl<'a> ParseContext<'a> { let outputs = self.parse_port_list(&mut inner)?; let signature = self.parse_signature(&mut inner)?; let meta = self.parse_meta(&mut inner)?; - let regions = self.parse_regions(&mut inner)?; + let regions = self.parse_regions(&mut inner, false)?; Node { operation: Operation::Block, inputs, @@ -283,9 +363,11 @@ impl<'a> ParseContext<'a> { } Rule::node_define_func => { + self.vars.enter(node); let decl = self.parse_func_header(inner.next().unwrap())?; let meta = self.parse_meta(&mut inner)?; - let regions = self.parse_regions(&mut inner)?; + let regions = self.parse_regions(&mut inner, true)?; + self.vars.exit(); Node { operation: Operation::DefineFunc { decl }, inputs: &[], @@ -298,8 +380,10 @@ impl<'a> ParseContext<'a> { } Rule::node_declare_func => { + self.vars.enter(node); let decl = self.parse_func_header(inner.next().unwrap())?; let meta = self.parse_meta(&mut inner)?; + self.vars.exit(); Node { operation: Operation::DeclareFunc { decl }, inputs: &[], @@ -346,9 +430,11 @@ impl<'a> ParseContext<'a> { } Rule::node_define_alias => { + self.vars.enter(node); let decl = self.parse_alias_header(inner.next().unwrap())?; let value = self.parse_term(inner.next().unwrap())?; let meta = self.parse_meta(&mut inner)?; + self.vars.exit(); Node { operation: Operation::DefineAlias { decl, value }, inputs: &[], @@ -361,8 +447,10 @@ impl<'a> ParseContext<'a> { } Rule::node_declare_alias => { + self.vars.enter(node); let decl = self.parse_alias_header(inner.next().unwrap())?; let meta = self.parse_meta(&mut inner)?; + self.vars.exit(); Node { operation: Operation::DeclareAlias { decl }, inputs: &[], @@ -383,7 +471,7 @@ impl<'a> ParseContext<'a> { let op_rule = op.as_rule(); let mut op_inner = op.into_inner(); - let name = GlobalRef::Named(self.parse_symbol(&mut op_inner)?); + let operation = self.parse_symbol_use(&mut op_inner)?; let mut params = Vec::new(); @@ -392,8 +480,8 @@ impl<'a> ParseContext<'a> { } let operation = match op_rule { - Rule::term_apply_full => Operation::CustomFull { operation: name }, - Rule::term_apply => Operation::Custom { operation: name }, + Rule::term_apply_full => Operation::CustomFull { operation }, + Rule::term_apply => Operation::Custom { operation }, _ => unreachable!(), }; @@ -401,7 +489,7 @@ impl<'a> ParseContext<'a> { let outputs = self.parse_port_list(&mut inner)?; let signature = self.parse_signature(&mut inner)?; let meta = self.parse_meta(&mut inner)?; - let regions = self.parse_regions(&mut inner)?; + let regions = self.parse_regions(&mut inner, true)?; Node { operation, inputs, @@ -418,7 +506,7 @@ impl<'a> ParseContext<'a> { let outputs = self.parse_port_list(&mut inner)?; let signature = self.parse_signature(&mut inner)?; let meta = self.parse_meta(&mut inner)?; - let regions = self.parse_regions(&mut inner)?; + let regions = self.parse_regions(&mut inner, true)?; Node { operation: Operation::TailLoop, inputs, @@ -435,7 +523,7 @@ impl<'a> ParseContext<'a> { let outputs = self.parse_port_list(&mut inner)?; let signature = self.parse_signature(&mut inner)?; let meta = self.parse_meta(&mut inner)?; - let regions = self.parse_regions(&mut inner)?; + let regions = self.parse_regions(&mut inner, false)?; Node { operation: Operation::Conditional, inputs, @@ -465,8 +553,10 @@ impl<'a> ParseContext<'a> { } Rule::node_declare_ctr => { + self.vars.enter(node); let decl = self.parse_ctr_header(inner.next().unwrap())?; let meta = self.parse_meta(&mut inner)?; + self.vars.exit(); Node { operation: Operation::DeclareConstructor { decl }, inputs: &[], @@ -479,8 +569,10 @@ impl<'a> ParseContext<'a> { } Rule::node_declare_operation => { + self.vars.enter(node); let decl = self.parse_op_header(inner.next().unwrap())?; let meta = self.parse_meta(&mut inner)?; + self.vars.exit(); Node { operation: Operation::DeclareOperation { decl }, inputs: &[], @@ -495,25 +587,34 @@ impl<'a> ParseContext<'a> { _ => unreachable!(), }; - let node_id = self.module.insert_node(node); - - Ok(node_id) + Ok(node) } - fn parse_regions(&mut self, pairs: &mut Pairs<'a, Rule>) -> ParseResult<&'a [RegionId]> { + fn parse_regions( + &mut self, + pairs: &mut Pairs<'a, Rule>, + closed: bool, + ) -> ParseResult<&'a [RegionId]> { let mut regions = Vec::new(); for pair in filter_rule(pairs, Rule::region) { - regions.push(self.parse_region(pair)?); + regions.push(self.parse_region(pair, closed)?); } Ok(self.bump.alloc_slice_copy(®ions)) } - fn parse_region(&mut self, pair: Pair<'a, Rule>) -> ParseResult { + fn parse_region(&mut self, pair: Pair<'a, Rule>, closed: bool) -> ParseResult { debug_assert_eq!(pair.as_rule(), Rule::region); let pair = pair.into_inner().next().unwrap(); let rule = pair.as_rule(); let mut inner = pair.into_inner(); + let region = self.module.insert_region(Region::default()); + self.symbols.enter(region); + + if closed { + self.links.enter(region); + } + let kind = match rule { Rule::region_cfg => RegionKind::ControlFlow, Rule::region_dfg => RegionKind::DataFlow, @@ -526,24 +627,47 @@ impl<'a> ParseContext<'a> { let meta = self.parse_meta(&mut inner)?; let children = self.parse_nodes(&mut inner)?; - Ok(self.module.insert_region(Region { + let scope = if closed { + let (links, ports) = self.links.exit(); + Some(RegionScope { links, ports }) + } else { + None + }; + + self.symbols.exit(); + + self.module.regions[region.index()] = Region { kind, sources, targets, children, meta, signature, - })) + scope, + }; + + Ok(region) } fn parse_nodes(&mut self, pairs: &mut Pairs<'a, Rule>) -> ParseResult<&'a [NodeId]> { - let mut nodes = Vec::new(); + let nodes = { + let mut pairs = pairs.clone(); + let mut nodes = BumpVec::with_capacity_in(pairs.len(), self.bump); + + for pair in filter_rule(&mut pairs, Rule::node) { + nodes.push(self.parse_node_shallow(pair)?); + } - for pair in filter_rule(pairs, Rule::node) { - nodes.push(self.parse_node(pair)?); + nodes.into_bump_slice() + }; + + for (i, pair) in filter_rule(pairs, Rule::node).enumerate() { + let node = nodes[i]; + let node_data = self.parse_node_deep(pair, node)?; + self.module.nodes[node.index()] = node_data; } - Ok(self.bump.alloc_slice_copy(&nodes)) + Ok(nodes) } fn parse_func_header(&mut self, pair: Pair<'a, Rule>) -> ParseResult<&'a FuncDecl<'a>> { @@ -627,6 +751,7 @@ impl<'a> ParseContext<'a> { for pair in filter_rule(pairs, Rule::param) { let param = pair.into_inner().next().unwrap(); + let param_span = param.as_span(); let param = match param.as_rule() { Rule::param_implicit => { @@ -652,6 +777,10 @@ impl<'a> ParseContext<'a> { _ => unreachable!(), }; + self.vars + .insert(param.name) + .map_err(|err| ParseError::custom(&err.to_string(), param_span))?; + params.push(param); } @@ -679,27 +808,27 @@ impl<'a> ParseContext<'a> { Ok(Some(signature)) } - fn parse_port_list(&mut self, pairs: &mut Pairs<'a, Rule>) -> ParseResult<&'a [LinkRef<'a>]> { + fn parse_port_list(&mut self, pairs: &mut Pairs<'a, Rule>) -> ParseResult<&'a [LinkIndex]> { let Some(Rule::port_list) = pairs.peek().map(|p| p.as_rule()) else { return Ok(&[]); }; let pair = pairs.next().unwrap(); let inner = pair.into_inner(); - let mut links = Vec::new(); + let mut links = BumpVec::with_capacity_in(inner.len(), self.bump); for token in inner { links.push(self.parse_port(token)?); } - Ok(self.bump.alloc_slice_copy(&links)) + Ok(links.into_bump_slice()) } - fn parse_port(&mut self, pair: Pair<'a, Rule>) -> ParseResult> { + fn parse_port(&mut self, pair: Pair<'a, Rule>) -> ParseResult { debug_assert_eq!(pair.as_rule(), Rule::port); let mut inner = pair.into_inner(); - let link = LinkRef::Named(&inner.next().unwrap().as_str()[1..]); - Ok(link) + let name = &inner.next().unwrap().as_str()[1..]; + Ok(self.links.use_link(name)) } fn parse_meta(&mut self, pairs: &mut Pairs<'a, Rule>) -> ParseResult<&'a [MetaItem<'a>]> { @@ -715,6 +844,21 @@ impl<'a> ParseContext<'a> { Ok(self.bump.alloc_slice_copy(&items)) } + fn parse_symbol_use(&mut self, pairs: &mut Pairs<'a, Rule>) -> ParseResult { + let name = self.parse_symbol(pairs)?; + let resolved = self.symbols.resolve(name); + + Ok(match resolved { + Ok(node) => node, + Err(UnknownSymbolError(_)) => *self.implicit_imports.entry(name).or_insert_with(|| { + self.module.insert_node(Node { + operation: Operation::Import { name }, + ..Node::default() + }) + }), + }) + } + fn parse_symbol(&mut self, pairs: &mut Pairs<'a, Rule>) -> ParseResult<&'a str> { let pair = pairs.next().unwrap(); if let Rule::symbol = pair.as_rule() { diff --git a/hugr-model/src/v0/text/print.rs b/hugr-model/src/v0/text/print.rs index ac35b4cd4..5581ff9fd 100644 --- a/hugr-model/src/v0/text/print.rs +++ b/hugr-model/src/v0/text/print.rs @@ -2,8 +2,8 @@ use pretty::{Arena, DocAllocator, RefDoc}; use std::borrow::Cow; use crate::v0::{ - ExtSetPart, GlobalRef, LinkRef, ListPart, LocalRef, MetaItem, ModelError, Module, NodeId, - Operation, Param, ParamSort, RegionId, RegionKind, Term, TermId, + ExtSetPart, LinkIndex, ListPart, MetaItem, ModelError, Module, NodeId, Operation, Param, + ParamSort, RegionId, RegionKind, Term, TermId, VarId, }; type PrintError = ModelError; @@ -247,10 +247,10 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { Operation::Custom { operation } => { this.print_group(|this| { if node_data.params.is_empty() { - this.print_global_ref(*operation)?; + this.print_symbol(*operation)?; } else { this.print_parens(|this| { - this.print_global_ref(*operation)?; + this.print_symbol(*operation)?; for param in node_data.params { this.print_term(*param)?; @@ -271,7 +271,7 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { this.print_group(|this| { this.print_parens(|this| { this.print_text("@"); - this.print_global_ref(*operation)?; + this.print_symbol(*operation)?; for param in node_data.params { this.print_term(*param)?; @@ -364,6 +364,12 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { this.print_signature(node_data.signature)?; this.print_meta(node_data.meta) } + + Operation::Import { name } => { + this.print_text("import"); + this.print_text(*name); + this.print_meta(node_data.meta) + } }) } @@ -413,8 +419,8 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { fn print_port_lists( &mut self, - first: &'a [LinkRef<'a>], - second: &'a [LinkRef<'a>], + first: &'a [LinkIndex], + second: &'a [LinkIndex], ) -> PrintResult<()> { if !first.is_empty() && !second.is_empty() { self.print_group(|this| { @@ -426,20 +432,17 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { } } - fn print_port_list(&mut self, links: &'a [LinkRef<'a>]) -> PrintResult<()> { + fn print_port_list(&mut self, links: &'a [LinkIndex]) -> PrintResult<()> { self.print_brackets(|this| { for link in links { - this.print_link_ref(*link); + this.print_link_index(*link); } Ok(()) }) } - fn print_link_ref(&mut self, link_ref: LinkRef<'a>) { - match link_ref { - LinkRef::Id(link_id) => self.print_text(format!("%{}", link_id.0)), - LinkRef::Named(name) => self.print_text(format!("%{}", name)), - } + fn print_link_index(&mut self, link_index: LinkIndex) { + self.print_text(format!("%{}", link_index.0)); } fn print_params(&mut self, params: &'a [Param<'a>]) -> PrintResult<()> { @@ -492,13 +495,13 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { self.print_text("constraint"); Ok(()) } - Term::Var(local_ref) => self.print_local_ref(*local_ref), - Term::Apply { global: name, args } => { + Term::Var(var) => self.print_var(*var), + Term::Apply { symbol, args } => { if args.is_empty() { - self.print_global_ref(*name)?; + self.print_symbol(*symbol)?; } else { self.print_parens(|this| { - this.print_global_ref(*name)?; + this.print_symbol(*symbol)?; for arg in args.iter() { this.print_term(*arg)?; } @@ -508,9 +511,9 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { Ok(()) } - Term::ApplyFull { global: name, args } => self.print_parens(|this| { + Term::ApplyFull { symbol, args } => self.print_parens(|this| { this.print_text("@"); - this.print_global_ref(*name)?; + this.print_symbol(*symbol)?; for arg in args.iter() { this.print_term(*arg)?; } @@ -628,44 +631,27 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { Ok(()) } - fn print_local_ref(&mut self, local_ref: LocalRef<'a>) -> PrintResult<()> { - let name = match local_ref { - LocalRef::Index(_, i) => { - let Some(name) = self.locals.get(i as usize) else { - return Err(PrintError::InvalidLocal(local_ref.to_string())); - }; - - name - } - LocalRef::Named(name) => name, + fn print_var(&mut self, var: VarId) -> PrintResult<()> { + let Some(name) = self.locals.get(var.1 as usize) else { + return Err(PrintError::InvalidVar(var)); }; self.print_text(format!("?{}", name)); Ok(()) } - fn print_global_ref(&mut self, global_ref: GlobalRef<'a>) -> PrintResult<()> { - match global_ref { - GlobalRef::Direct(node_id) => { - let node_data = self - .module - .get_node(node_id) - .ok_or(PrintError::NodeNotFound(node_id))?; - - let name = match &node_data.operation { - Operation::DefineFunc { decl } => decl.name, - Operation::DeclareFunc { decl } => decl.name, - Operation::DefineAlias { decl, .. } => decl.name, - Operation::DeclareAlias { decl } => decl.name, - _ => return Err(PrintError::UnexpectedOperation(node_id)), - }; - - self.print_text(name) - } + fn print_symbol(&mut self, node_id: NodeId) -> PrintResult<()> { + let node_data = self + .module + .get_node(node_id) + .ok_or(PrintError::NodeNotFound(node_id))?; - GlobalRef::Named(symbol) => self.print_text(symbol), - } + let name = node_data + .operation + .symbol() + .ok_or(PrintError::UnexpectedOperation(node_id))?; + self.print_text(name); Ok(()) }