diff --git a/.github/workflows/ci-py.yml b/.github/workflows/ci-py.yml index 6f9e58e78..f203ff7eb 100644 --- a/.github/workflows/ci-py.yml +++ b/.github/workflows/ci-py.yml @@ -29,11 +29,11 @@ jobs: outputs: python: ${{ github.ref_name == github.event.repository.default_branch || steps.filter.outputs.python }} steps: - - uses: actions/checkout@v4 - - uses: dorny/paths-filter@v3 - id: filter - with: - filters: .github/change-filters.yml + - uses: actions/checkout@v4 + - uses: dorny/paths-filter@v3 + id: filter + with: + filters: .github/change-filters.yml check: needs: changes @@ -44,7 +44,7 @@ jobs: strategy: matrix: - python-version: ['3.10', '3.12'] + python-version: ["3.10", "3.12"] steps: - uses: actions/checkout@v4 @@ -86,6 +86,8 @@ jobs: - uses: mozilla-actions/sccache-action@v0.0.6 - name: Install stable toolchain uses: dtolnay/rust-toolchain@stable + - name: Install CapnProto + run: sudo apt-get install -y capnproto - name: Build HUGR binary run: cargo build -p hugr-cli - name: Upload the binary to the artifacts @@ -102,8 +104,8 @@ jobs: strategy: matrix: python-version: - - { py: '3.10', coverage: false } - - { py: '3.12', coverage: true } + - { py: "3.10", coverage: false } + - { py: "3.12", coverage: true } steps: - uses: actions/checkout@v4 @@ -183,7 +185,6 @@ jobs: exit 1 fi - # This is a meta job to mark successful completion of the required checks, # even if they are skipped due to no changes in the relevant files. 
required-checks: diff --git a/.github/workflows/ci-rs.yml b/.github/workflows/ci-rs.yml index 67bceaaec..de947d13e 100644 --- a/.github/workflows/ci-rs.yml +++ b/.github/workflows/ci-rs.yml @@ -15,7 +15,7 @@ env: CARGO_TERM_COLOR: always CARGO_INCREMENTAL: 0 RUSTFLAGS: "--cfg=ci_run" - MIRIFLAGS: '-Zmiri-permissive-provenance' # Required due to warnings in bitvec 1.0.1 + MIRIFLAGS: "-Zmiri-permissive-provenance" # Required due to warnings in bitvec 1.0.1 CI: true # insta snapshots behave differently on ci SCCACHE_GHA_ENABLED: "true" RUSTC_WRAPPER: "sccache" @@ -34,11 +34,11 @@ jobs: outputs: rust: ${{ github.ref_name == github.event.repository.default_branch || steps.filter.outputs.rust }} steps: - - uses: actions/checkout@v4 - - uses: dorny/paths-filter@v3 - id: filter - with: - filters: .github/change-filters.yml + - uses: actions/checkout@v4 + - uses: dorny/paths-filter@v3 + id: filter + with: + filters: .github/change-filters.yml check: needs: changes @@ -51,6 +51,8 @@ jobs: uses: dtolnay/rust-toolchain@stable with: components: rustfmt, clippy + - name: Install CapnProto + run: sudo apt-get install -y capnproto - name: Check formatting run: cargo fmt -- --check - name: Run clippy @@ -70,6 +72,8 @@ jobs: - uses: mozilla-actions/sccache-action@v0.0.6 - name: Install stable toolchain uses: dtolnay/rust-toolchain@stable + - name: Install CapnProto + run: sudo apt-get install -y capnproto - name: Build benchmarks with no features run: cargo bench --verbose --no-run --workspace --no-default-features - name: Build benchmarks with all features @@ -87,9 +91,11 @@ jobs: - id: toolchain uses: dtolnay/rust-toolchain@master with: - toolchain: 'stable' + toolchain: "stable" - name: Configure default rust toolchain run: rustup override set ${{steps.toolchain.outputs.name}} + - name: Install CapnProto + run: sudo apt-get install -y capnproto - name: Build with no features run: cargo test --verbose --workspace --no-default-features --no-run - name: Tests with no features @@ 
-107,9 +113,11 @@ jobs: - id: toolchain uses: dtolnay/rust-toolchain@master with: - toolchain: 'stable' + toolchain: "stable" - name: Configure default rust toolchain run: rustup override set ${{steps.toolchain.outputs.name}} + - name: Install CapnProto + run: sudo apt-get install -y capnproto - name: Build with all features run: cargo test --verbose --workspace --all-features --no-run - name: Tests with all features @@ -132,7 +140,7 @@ jobs: matrix: # Stable is covered by `tests-stable-no-features` and `tests-stable-all-features` # Nightly is covered by `tests-nightly-coverage` - rust: ['1.75', beta] + rust: ["1.75", beta] name: tests (Rust ${{ matrix.rust }}) steps: - uses: actions/checkout@v4 @@ -143,6 +151,8 @@ jobs: toolchain: ${{ matrix.rust }} - name: Configure default rust toolchain run: rustup override set ${{steps.toolchain.outputs.name}} + - name: Install CapnProto + run: sudo apt-get install -y capnproto - name: Build with no features run: cargo test --verbose --workspace --no-default-features --no-run - name: Tests with no features @@ -190,15 +200,17 @@ jobs: - uses: dtolnay/rust-toolchain@master with: # Nightly is required to count doctests coverage - toolchain: 'nightly' + toolchain: "nightly" components: llvm-tools-preview + - name: Install CapnProto + run: sudo apt-get install -y capnproto - name: Install cargo-llvm-cov uses: taiki-e/install-action@cargo-llvm-cov - name: Run tests with coverage instrumentation run: | - cargo llvm-cov clean --workspace - cargo llvm-cov --no-report --workspace --no-default-features --doctests - cargo llvm-cov --no-report --workspace --all-features --doctests + cargo llvm-cov clean --workspace + cargo llvm-cov --no-report --workspace --no-default-features --doctests + cargo llvm-cov --no-report --workspace --all-features --doctests - name: Generate coverage report run: cargo llvm-cov --all-features report --codecov --output-path coverage.json - name: Upload coverage to codecov.io @@ -213,7 +225,14 @@ jobs: # even if 
they are skipped due to no changes in the relevant files. required-checks: name: Required checks 🦀 - needs: [changes, check, tests-stable-no-features, tests-stable-all-features, std-extensions] + needs: + [ + changes, + check, + tests-stable-no-features, + tests-stable-all-features, + std-extensions, + ] if: ${{ !cancelled() }} runs-on: ubuntu-latest steps: diff --git a/devenv.nix b/devenv.nix index 92c31d948..f02fc386e 100644 --- a/devenv.nix +++ b/devenv.nix @@ -22,6 +22,7 @@ in pkgs-stable.cargo-llvm-cov pkgs.graphviz pkgs.cargo-insta + pkgs.capnproto ] ++ lib.optionals pkgs.stdenv.isDarwin (with pkgs.darwin.apple_sdk; [ diff --git a/hugr-core/src/export.rs b/hugr-core/src/export.rs index 5fd6932eb..0d266db39 100644 --- a/hugr-core/src/export.rs +++ b/hugr-core/src/export.rs @@ -1,18 +1,23 @@ //! Exporting HUGR graphs to their `hugr-model` representation. use crate::{ extension::ExtensionSet, - ops::OpType, + hugr::IdentList, + ops::{DataflowBlock, OpTrait, OpType}, types::{ type_param::{TypeArgVariable, TypeParam}, type_row::TypeRowBase, - CustomType, EdgeKind, FuncTypeBase, MaybeRV, PolyFuncTypeBase, RowVariable, SumType, - TypeArg, TypeBase, TypeEnum, + CustomType, FuncTypeBase, MaybeRV, PolyFuncTypeBase, RowVariable, SumType, TypeArg, + TypeBase, TypeEnum, }, - Direction, Hugr, HugrView, IncomingPort, Node, Port, PortIndex, + Direction, Hugr, HugrView, IncomingPort, Node, Port, }; -use bumpalo::{collections::Vec as BumpVec, Bump}; +use bumpalo::{collections::String as BumpString, collections::Vec as BumpVec, Bump}; +use fxhash::FxHashMap; use hugr_model::v0::{self as model}; use indexmap::IndexSet; +use std::fmt::Write; + +type FxIndexSet = IndexSet; pub(crate) const OP_FUNC_CALL_INDIRECT: &str = "func.call-indirect"; const TERM_PARAM_TUPLE: &str = "param.tuple"; @@ -32,8 +37,12 @@ struct Context<'a> { module: model::Module<'a>, /// Mapping from ports to link indices. /// This only includes the minimum port among groups of linked ports. 
- links: IndexSet<(Node, Port)>, + links: FxIndexSet<(Node, Port)>, bump: &'a Bump, + /// Stores the terms that we have already seen to avoid duplicates. + term_map: FxHashMap, model::TermId>, + /// The current scope for local variables. + local_scope: Option, } impl<'a> Context<'a> { @@ -45,14 +54,14 @@ impl<'a> Context<'a> { hugr, module, bump, - links: IndexSet::new(), + links: IndexSet::default(), + term_map: FxHashMap::default(), + local_scope: None, } } /// Exports the root module of the HUGR graph. pub fn export_root(&mut self) { - let signature = self.module.insert_term(model::Term::Wildcard); - let hugr_children = self.hugr.children(self.hugr.root()); let mut children = BumpVec::with_capacity_in(hugr_children.len(), self.bump); @@ -66,7 +75,7 @@ impl<'a> Context<'a> { targets: &[], children: children.into_bump_slice(), meta: &[], // TODO: Export metadata - signature, + signature: None, }); self.module.root = root; @@ -75,9 +84,10 @@ impl<'a> Context<'a> { /// Returns the edge id for a given port, creating a new edge if necessary. /// /// Any two ports that are linked will be represented by the same link. - fn get_link_id(&mut self, node: Node, port: Port) -> model::LinkId { + fn get_link_id(&mut self, node: Node, port: impl Into) -> model::LinkId { // To ensure that linked ports are represented by the same edge, we take the minimum port // among all the linked ports, including the one we started with. 
+ let port = port.into(); let linked_ports = self.hugr.linked_ports(node, port); let all_ports = std::iter::once((node, port)).chain(linked_ports); let repr = all_ports.min().unwrap(); @@ -85,68 +95,43 @@ impl<'a> Context<'a> { model::LinkId(edge) } - pub fn make_ports(&mut self, node: Node, direction: Direction) -> &'a [model::Port<'a>] { + pub fn make_ports( + &mut self, + node: Node, + direction: Direction, + num_ports: usize, + ) -> &'a [model::LinkRef<'a>] { let ports = self.hugr.node_ports(node, direction); - let mut model_ports = BumpVec::with_capacity_in(ports.len(), self.bump); + let mut links = BumpVec::with_capacity_in(ports.len(), self.bump); - for port in ports { - if let Some(model_port) = self.make_port(node, port) { - model_ports.push(model_port); - } + for port in ports.take(num_ports) { + links.push(model::LinkRef::Id(self.get_link_id(node, port))); } - model_ports.into_bump_slice() + links.into_bump_slice() } - pub fn make_port(&mut self, node: Node, port: impl Into) -> Option> { - let port: Port = port.into(); - let op_type = self.hugr.get_optype(node); - - let r#type = match op_type.port_kind(port)? 
{ - EdgeKind::ControlFlow => { - // TODO: This should ideally be reported by the op itself - let types: Vec<_> = match (op_type, port.direction()) { - (OpType::DataflowBlock(block), Direction::Incoming) => { - block.inputs.iter().map(|t| self.export_type(t)).collect() - } - (OpType::DataflowBlock(block), Direction::Outgoing) => { - let mut types = Vec::new(); - types.extend( - block.sum_rows[port.index()] - .iter() - .map(|t| self.export_type(t)), - ); - types.extend(block.other_outputs.iter().map(|t| self.export_type(t))); - types - } - (OpType::ExitBlock(block), Direction::Incoming) => block - .cfg_outputs - .iter() - .map(|t| self.export_type(t)) - .collect(), - (OpType::ExitBlock(_), Direction::Outgoing) => vec![], - _ => unreachable!("unexpected control flow port on non-control-flow op"), - }; - - let types = self.bump.alloc_slice_copy(&types); - let values = self.module.insert_term(model::Term::List { - items: types, - tail: None, - }); - self.module.insert_term(model::Term::Control { values }) - } - EdgeKind::Value(r#type) => self.export_type(&r#type), - EdgeKind::Const(_) => return None, - EdgeKind::Function(_) => return None, - EdgeKind::StateOrder => return None, - }; + pub fn make_term(&mut self, term: model::Term<'a>) -> model::TermId { + // Wildcard terms do not all represent the same term, so we should not deduplicate them. 
+ if term == model::Term::Wildcard { + return self.module.insert_term(term); + } - let link = model::LinkRef::Id(self.get_link_id(node, port)); + *self + .term_map + .entry(term.clone()) + .or_insert_with(|| self.module.insert_term(term)) + } - Some(model::Port { - r#type: Some(r#type), - link, - }) + pub fn make_named_global_ref( + &mut self, + extension: &IdentList, + name: impl AsRef, + ) -> model::GlobalRef<'a> { + let capacity = extension.len() + name.as_ref().len() + 1; + let mut output = BumpString::with_capacity_in(capacity, self.bump); + let _ = write!(&mut output, "{}.{}", extension, name.as_ref()); + model::GlobalRef::Named(output.into_bump_str()) } /// Get the node that declares or defines the function associated with the given @@ -171,19 +156,25 @@ impl<'a> Context<'a> { } } + fn with_local_scope(&mut self, node: model::NodeId, f: impl FnOnce(&mut Self) -> T) -> T { + let old_scope = self.local_scope.replace(node); + let result = f(self); + self.local_scope = old_scope; + result + } + pub fn export_node(&mut self, node: Node) -> model::NodeId { - let inputs = self.make_ports(node, Direction::Incoming); - let outputs = self.make_ports(node, Direction::Outgoing); + // We insert a dummy node with the invalid operation at this point to reserve + // the node id. This is necessary to establish the correct node id for the + // local scope introduced by some operations. We will overwrite this node later. 
+ let node_id = self.module.insert_node(model::Node::default()); + let mut params: &[_] = &[]; let mut regions: &[_] = &[]; - fn make_custom(name: &'static str) -> model::Operation { - model::Operation::Custom { - operation: model::GlobalRef::Named(name), - } - } + let optype = self.hugr.get_optype(node); - let operation = match self.hugr.get_optype(node) { + let operation = match optype { OpType::Module(_) => todo!("this should be an error"), OpType::Input(_) => { @@ -194,8 +185,11 @@ impl<'a> Context<'a> { panic!("output nodes should have been handled by the region export") } - OpType::DFG(_) => { - regions = self.bump.alloc_slice_copy(&[self.export_dfg(node)]); + OpType::DFG(dfg) => { + let extensions = self.export_ext_set(&dfg.signature.extension_reqs); + regions = self + .bump + .alloc_slice_copy(&[self.export_dfg(node, extensions)]); model::Operation::Dfg } @@ -212,56 +206,62 @@ impl<'a> Context<'a> { todo!("case nodes should have been handled by the region export") } - OpType::DataflowBlock(_) => { - regions = self.bump.alloc_slice_copy(&[self.export_dfg(node)]); + OpType::DataflowBlock(block) => { + let extensions = self.export_ext_set(&block.extension_delta); + regions = self + .bump + .alloc_slice_copy(&[self.export_dfg(node, extensions)]); model::Operation::Block } - OpType::FuncDefn(func) => { - let name = self.get_func_name(node).unwrap(); - let (params, func) = self.export_poly_func_type(&func.signature); - let decl = self.bump.alloc(model::FuncDecl { + OpType::FuncDefn(func) => self.with_local_scope(node_id, |this| { + let name = this.get_func_name(node).unwrap(); + let (params, signature) = this.export_poly_func_type(&func.signature); + let decl = this.bump.alloc(model::FuncDecl { name, params, - signature: func, + signature, }); - regions = self.bump.alloc_slice_copy(&[self.export_dfg(node)]); + let extensions = this.export_ext_set(&func.signature.body().extension_reqs); + regions = this + .bump + .alloc_slice_copy(&[this.export_dfg(node, 
extensions)]); model::Operation::DefineFunc { decl } - } + }), - OpType::FuncDecl(func) => { - let name = self.get_func_name(node).unwrap(); - let (params, func) = self.export_poly_func_type(&func.signature); - let decl = self.bump.alloc(model::FuncDecl { + OpType::FuncDecl(func) => self.with_local_scope(node_id, |this| { + let name = this.get_func_name(node).unwrap(); + let (params, func) = this.export_poly_func_type(&func.signature); + let decl = this.bump.alloc(model::FuncDecl { name, params, signature: func, }); model::Operation::DeclareFunc { decl } - } + }), - OpType::AliasDecl(alias) => { + OpType::AliasDecl(alias) => self.with_local_scope(node_id, |this| { // TODO: We should support aliases with different types and with parameters - let r#type = self.module.insert_term(model::Term::Type); - let decl = self.bump.alloc(model::AliasDecl { + let r#type = this.make_term(model::Term::Type); + let decl = this.bump.alloc(model::AliasDecl { name: &alias.name, params: &[], r#type, }); model::Operation::DeclareAlias { decl } - } + }), - OpType::AliasDefn(alias) => { - let value = self.export_type(&alias.definition); + OpType::AliasDefn(alias) => self.with_local_scope(node_id, |this| { + let value = this.export_type(&alias.definition); // TODO: We should support aliases with different types and with parameters - let r#type = self.module.insert_term(model::Term::Type); - let decl = self.bump.alloc(model::AliasDecl { + let r#type = this.make_term(model::Term::Type); + let decl = this.bump.alloc(model::AliasDecl { name: &alias.name, params: &[], r#type, }); model::Operation::DefineAlias { decl, value } - } + }), OpType::Call(call) => { // TODO: If the node is not connected to a function, we should do better than panic. 
@@ -272,9 +272,7 @@ impl<'a> Context<'a> { args.extend(call.type_args.iter().map(|arg| self.export_type_arg(arg))); let args = args.into_bump_slice(); - let func = self - .module - .insert_term(model::Term::ApplyFull { global: name, args }); + let func = self.make_term(model::Term::ApplyFull { global: name, args }); model::Operation::CallFunc { func } } @@ -287,21 +285,24 @@ impl<'a> Context<'a> { args.extend(load.type_args.iter().map(|arg| self.export_type_arg(arg))); let args = args.into_bump_slice(); - let func = self - .module - .insert_term(model::Term::ApplyFull { global: name, args }); + let func = self.make_term(model::Term::ApplyFull { global: name, args }); model::Operation::LoadFunc { func } } OpType::Const(_) => todo!("Export const nodes?"), OpType::LoadConstant(_) => todo!("Export load constant?"), - OpType::CallIndirect(_) => make_custom(OP_FUNC_CALL_INDIRECT), + OpType::CallIndirect(_) => model::Operation::CustomFull { + operation: model::GlobalRef::Named(OP_FUNC_CALL_INDIRECT), + }, OpType::Tag(tag) => model::Operation::Tag { tag: tag.tag as _ }, - OpType::TailLoop(_) => { - regions = self.bump.alloc_slice_copy(&[self.export_dfg(node)]); + OpType::TailLoop(tail_loop) => { + let extensions = self.export_ext_set(&tail_loop.extension_delta); + regions = self + .bump + .alloc_slice_copy(&[self.export_dfg(node, extensions)]); model::Operation::TailLoop } @@ -314,43 +315,71 @@ impl<'a> Context<'a> { // regions of potentially different kinds. At the moment, we check if the node has any // children, in which case we create a dataflow region with those children. 
OpType::ExtensionOp(op) => { - let name = - self.bump - .alloc_str(&format!("{}.{}", op.def().extension(), op.def().name())); - let operation = model::GlobalRef::Named(name); + let operation = self.make_named_global_ref(op.def().extension(), op.def().name()); params = self .bump .alloc_slice_fill_iter(op.args().iter().map(|arg| self.export_type_arg(arg))); - if let Some(region) = self.export_dfg_if_present(node) { + // PERFORMANCE: Currently the API does not appear to allow to get the extension + // set without copying it. + // NOTE: We assume here that the extension set of the dfg region must be the same + // as that of the node. This might change in the future. + let extensions = self.export_ext_set(&op.extension_delta()); + + if let Some(region) = self.export_dfg_if_present(node, extensions) { regions = self.bump.alloc_slice_copy(&[region]); } - model::Operation::Custom { operation } + model::Operation::CustomFull { operation } } OpType::OpaqueOp(op) => { - let name = self - .bump - .alloc_str(&format!("{}.{}", op.extension(), op.op_name())); - let operation = model::GlobalRef::Named(name); + let operation = self.make_named_global_ref(op.extension(), op.op_name()); params = self .bump .alloc_slice_fill_iter(op.args().iter().map(|arg| self.export_type_arg(arg))); - if let Some(region) = self.export_dfg_if_present(node) { + // PERFORMANCE: Currently the API does not appear to allow to get the extension + // set without copying it. + // NOTE: We assume here that the extension set of the dfg region must be the same + // as that of the node. This might change in the future. 
+ let extensions = self.export_ext_set(&op.extension_delta()); + + if let Some(region) = self.export_dfg_if_present(node, extensions) { regions = self.bump.alloc_slice_copy(&[region]); } - model::Operation::Custom { operation } + model::Operation::CustomFull { operation } + } + }; + + let (signature, num_inputs, num_outputs) = match optype { + OpType::DataflowBlock(block) => { + let signature = self.export_block_signature(block); + (Some(signature), 1, block.sum_rows.len()) } + + // PERFORMANCE: As it stands, `OpType::dataflow_signature` copies and/or allocates. + // That might not seem like a big deal, but it's a significant portion of the time spent + // when exporting. However it is not trivial to change this at the moment. + _ => match &optype.dataflow_signature() { + Some(signature) => { + let num_inputs = signature.input_types().len(); + let num_outputs = signature.output_types().len(); + let signature = self.export_func_type(signature); + (Some(signature), num_inputs, num_outputs) + } + None => (None, 0, 0), + }, }; - let signature = self.module.insert_term(model::Term::Wildcard); + let inputs = self.make_ports(node, Direction::Incoming, num_inputs); + let outputs = self.make_ports(node, Direction::Outgoing, num_outputs); - self.module.insert_node(model::Node { + // Replace the placeholder node with the actual node. + *self.module.get_node_mut(node_id).unwrap() = model::Node { operation, inputs, outputs, @@ -358,38 +387,102 @@ impl<'a> Context<'a> { regions, meta: &[], // TODO: Export metadata signature, + }; + + node_id + } + + /// Export the signature of a `DataflowBlock`. Here we can't use `OpType::dataflow_signature` + /// like for the other nodes since the ports are control flow ports. 
+ pub fn export_block_signature(&mut self, block: &DataflowBlock) -> model::TermId { + let inputs = { + let mut inputs = BumpVec::with_capacity_in(block.inputs.len(), self.bump); + for input in block.inputs.iter() { + inputs.push(self.export_type(input)); + } + let inputs = self.make_term(model::Term::List { + items: inputs.into_bump_slice(), + tail: None, + }); + let inputs = self.make_term(model::Term::Control { values: inputs }); + self.make_term(model::Term::List { + items: self.bump.alloc_slice_copy(&[inputs]), + tail: None, + }) + }; + + let tail = { + let mut tail = BumpVec::with_capacity_in(block.other_outputs.len(), self.bump); + for other_output in block.other_outputs.iter() { + tail.push(self.export_type(other_output)); + } + self.make_term(model::Term::List { + items: tail.into_bump_slice(), + tail: None, + }) + }; + + let outputs = { + let mut outputs = BumpVec::with_capacity_in(block.sum_rows.len(), self.bump); + for sum_row in block.sum_rows.iter() { + let mut variant = BumpVec::with_capacity_in(sum_row.len(), self.bump); + for typ in sum_row.iter() { + variant.push(self.export_type(typ)); + } + let variant = self.make_term(model::Term::List { + items: variant.into_bump_slice(), + tail: Some(tail), + }); + outputs.push(self.make_term(model::Term::Control { values: variant })); + } + self.make_term(model::Term::List { + items: outputs.into_bump_slice(), + tail: None, + }) + }; + + let extensions = self.export_ext_set(&block.extension_delta); + self.make_term(model::Term::FuncType { + inputs, + outputs, + extensions, }) } /// Create a region from the given node's children, if it has any. /// /// See [`Self::export_dfg`]. 
- pub fn export_dfg_if_present(&mut self, node: Node) -> Option { + pub fn export_dfg_if_present( + &mut self, + node: Node, + extensions: model::TermId, + ) -> Option { if self.hugr.children(node).next().is_none() { None } else { - Some(self.export_dfg(node)) + Some(self.export_dfg(node, extensions)) } } /// Creates a data flow region from the given node's children. /// /// `Input` and `Output` nodes are used to determine the source and target ports of the region. - pub fn export_dfg(&mut self, node: Node) -> model::RegionId { + pub fn export_dfg(&mut self, node: Node, extensions: model::TermId) -> model::RegionId { let mut children = self.hugr.children(node); // The first child is an `Input` node, which we use to determine the region's sources. let input_node = children.next().unwrap(); - assert!(matches!(self.hugr.get_optype(input_node), OpType::Input(_))); - let sources = self.make_ports(input_node, Direction::Outgoing); + let OpType::Input(input_op) = self.hugr.get_optype(input_node) else { + panic!("expected an `Input` node as the first child node"); + }; + let sources = self.make_ports(input_node, Direction::Outgoing, input_op.types.len()); // The second child is an `Output` node, which we use to determine the region's targets. let output_node = children.next().unwrap(); - assert!(matches!( - self.hugr.get_optype(output_node), - OpType::Output(_) - )); - let targets = self.make_ports(output_node, Direction::Incoming); + let OpType::Output(output_op) = self.hugr.get_optype(output_node) else { + panic!("expected an `Output` node as the second child node"); + }; + let targets = self.make_ports(output_node, Direction::Incoming, output_op.types.len()); // Export the remaining children of the node. 
let mut region_children = BumpVec::with_capacity_in(children.len(), self.bump); @@ -398,8 +491,16 @@ impl<'a> Context<'a> { region_children.push(self.export_node(child)); } - // TODO: We can determine the type of the region - let signature = self.module.insert_term(model::Term::Wildcard); + let signature = { + let inputs = self.export_type_row(&input_op.types); + let outputs = self.export_type_row(&output_op.types); + + Some(self.make_term(model::Term::FuncType { + inputs, + outputs, + extensions, + })) + }; self.module.insert_region(model::Region { kind: model::RegionKind::DataFlow, @@ -421,12 +522,11 @@ impl<'a> Context<'a> { // first input port of the exported entry block. let entry_block = children.next().unwrap(); - assert!(matches!( - self.hugr.get_optype(entry_block), - OpType::DataflowBlock(_) - )); + let OpType::DataflowBlock(_) = self.hugr.get_optype(entry_block) else { + panic!("expected a `DataflowBlock` node as the first child node"); + }; - let source = self.make_port(entry_block, IncomingPort::from(0)).unwrap(); + let source = model::LinkRef::Id(self.get_link_id(entry_block, IncomingPort::from(0))); region_children.push(self.export_node(entry_block)); // Export the remaining children of the node, except for the last one. @@ -440,15 +540,15 @@ impl<'a> Context<'a> { // as the target ports of the control flow region. let exit_block = children.next().unwrap(); - assert!(matches!( - self.hugr.get_optype(exit_block), - OpType::ExitBlock(_) - )); + let OpType::ExitBlock(_) = self.hugr.get_optype(exit_block) else { + panic!("expected an `ExitBlock` node as the last child node"); + }; - let targets = self.make_ports(exit_block, Direction::Incoming); + let targets = self.make_ports(exit_block, Direction::Incoming, 1); - // TODO: We can determine the type of the region - let signature = self.module.insert_term(model::Term::Wildcard); + // Get the signature of the control flow region. + // This is the same as the signature of the parent node. 
+ let signature = Some(self.export_func_type(&self.hugr.signature(node).unwrap())); self.module.insert_region(model::Region { kind: model::RegionKind::ControlFlow, @@ -466,8 +566,12 @@ impl<'a> Context<'a> { let mut regions = BumpVec::with_capacity_in(children.len(), self.bump); for child in children { - assert!(matches!(self.hugr.get_optype(child), OpType::Case(_))); - regions.push(self.export_dfg(child)); + let OpType::Case(case_op) = self.hugr.get_optype(child) else { + panic!("expected a `Case` node as a child of a `Conditional` node"); + }; + + let extensions = self.export_ext_set(&case_op.signature.extension_reqs); + regions.push(self.export_dfg(child, extensions)); } regions.into_bump_slice() @@ -501,14 +605,13 @@ impl<'a> Context<'a> { TypeEnum::Alias(alias) => { let name = model::GlobalRef::Named(self.bump.alloc_str(alias.name())); let args = &[]; - self.module - .insert_term(model::Term::ApplyFull { global: name, args }) + self.make_term(model::Term::ApplyFull { global: name, args }) } TypeEnum::Function(func) => self.export_func_type(func), TypeEnum::Variable(index, _) => { // This ignores the type bound for now - self.module - .insert_term(model::Term::Var(model::LocalRef::Index(*index as _))) + let node = self.local_scope.expect("local variable out of scope"); + self.make_term(model::Term::Var(model::LocalRef::Index(node, *index as _))) } TypeEnum::RowVar(rv) => self.export_row_var(rv.as_rv()), TypeEnum::Sum(sum) => self.export_sum_type(sum), @@ -519,7 +622,7 @@ impl<'a> Context<'a> { let inputs = self.export_type_row(t.input()); let outputs = self.export_type_row(t.output()); let extensions = self.export_ext_set(&t.extension_reqs); - self.module.insert_term(model::Term::FuncType { + self.make_term(model::Term::FuncType { inputs, outputs, extensions, @@ -527,30 +630,26 @@ impl<'a> Context<'a> { } pub fn export_custom_type(&mut self, t: &CustomType) -> model::TermId { - let name = format!("{}.{}", t.extension(), t.name()); - let name = 
model::GlobalRef::Named(self.bump.alloc_str(&name)); + let global = self.make_named_global_ref(t.extension(), t.name()); let args = self .bump .alloc_slice_fill_iter(t.args().iter().map(|p| self.export_type_arg(p))); - let term = model::Term::ApplyFull { global: name, args }; - self.module.insert_term(term) + let term = model::Term::ApplyFull { global, args }; + self.make_term(term) } pub fn export_type_arg(&mut self, t: &TypeArg) -> model::TermId { match t { TypeArg::Type { ty } => self.export_type(ty), - TypeArg::BoundedNat { n } => self.module.insert_term(model::Term::Nat(*n)), - TypeArg::String { arg } => self - .module - .insert_term(model::Term::Str(self.bump.alloc_str(arg))), + TypeArg::BoundedNat { n } => self.make_term(model::Term::Nat(*n)), + TypeArg::String { arg } => self.make_term(model::Term::Str(self.bump.alloc_str(arg))), TypeArg::Sequence { elems } => { // For now we assume that the sequence is meant to be a list. let items = self .bump .alloc_slice_fill_iter(elems.iter().map(|elem| self.export_type_arg(elem))); - self.module - .insert_term(model::Term::List { items, tail: None }) + self.make_term(model::Term::List { items, tail: None }) } TypeArg::Extensions { es } => self.export_ext_set(es), TypeArg::Variable { v } => self.export_type_arg_var(v), @@ -558,35 +657,38 @@ impl<'a> Context<'a> { } pub fn export_type_arg_var(&mut self, var: &TypeArgVariable) -> model::TermId { - self.module - .insert_term(model::Term::Var(model::LocalRef::Index(var.index() as _))) + let node = self.local_scope.expect("local variable out of scope"); + self.make_term(model::Term::Var(model::LocalRef::Index( + node, + var.index() as _, + ))) } pub fn export_row_var(&mut self, t: &RowVariable) -> model::TermId { - self.module - .insert_term(model::Term::Var(model::LocalRef::Index(t.0 as _))) + let node = self.local_scope.expect("local variable out of scope"); + self.make_term(model::Term::Var(model::LocalRef::Index(node, t.0 as _))) } pub fn export_sum_type(&mut self, t: 
&SumType) -> model::TermId { match t { SumType::Unit { size } => { let items = self.bump.alloc_slice_fill_iter((0..*size).map(|_| { - self.module.insert_term(model::Term::List { + self.make_term(model::Term::List { items: &[], tail: None, }) })); let list = model::Term::List { items, tail: None }; - let variants = self.module.insert_term(list); - self.module.insert_term(model::Term::Adt { variants }) + let variants = self.make_term(list); + self.make_term(model::Term::Adt { variants }) } SumType::General { rows } => { let items = self .bump .alloc_slice_fill_iter(rows.iter().map(|row| self.export_type_row(row))); let list = model::Term::List { items, tail: None }; - let variants = { self.module.insert_term(list) }; - self.module.insert_term(model::Term::Adt { variants }) + let variants = { self.make_term(list) }; + self.make_term(model::Term::Adt { variants }) } } } @@ -595,36 +697,33 @@ impl<'a> Context<'a> { let mut items = BumpVec::with_capacity_in(t.len(), self.bump); items.extend(t.iter().map(|row| self.export_type(row))); let items = items.into_bump_slice(); - self.module - .insert_term(model::Term::List { items, tail: None }) + self.make_term(model::Term::List { items, tail: None }) } pub fn export_type_param(&mut self, t: &TypeParam) -> model::TermId { match t { // This ignores the type bound for now. - TypeParam::Type { .. } => self.module.insert_term(model::Term::Type), + TypeParam::Type { .. } => self.make_term(model::Term::Type), // This ignores the type bound for now. - TypeParam::BoundedNat { .. } => self.module.insert_term(model::Term::NatType), - TypeParam::String => self.module.insert_term(model::Term::StrType), + TypeParam::BoundedNat { .. 
} => self.make_term(model::Term::NatType), + TypeParam::String => self.make_term(model::Term::StrType), TypeParam::List { param } => { let item_type = self.export_type_param(param); - self.module.insert_term(model::Term::ListType { item_type }) + self.make_term(model::Term::ListType { item_type }) } TypeParam::Tuple { params } => { let items = self.bump.alloc_slice_fill_iter( params.iter().map(|param| self.export_type_param(param)), ); - let types = self - .module - .insert_term(model::Term::List { items, tail: None }); - self.module.insert_term(model::Term::ApplyFull { + let types = self.make_term(model::Term::List { items, tail: None }); + self.make_term(model::Term::ApplyFull { global: model::GlobalRef::Named(TERM_PARAM_TUPLE), args: self.bump.alloc_slice_copy(&[types]), }) } TypeParam::Extensions => { let term = model::Term::ExtSetType; - self.module.insert_term(term) + self.make_term(term) } } } @@ -664,9 +763,10 @@ impl<'a> Context<'a> { panic!("Extension set with multiple variables") } + let node = self.local_scope.expect("local variable out of scope"); rest = Some( self.module - .insert_term(model::Term::Var(model::LocalRef::Index(index as _))), + .insert_term(model::Term::Var(model::LocalRef::Index(node, index as _))), ); } else { extensions.push(self.bump.alloc_str(ext) as &str); @@ -675,8 +775,7 @@ impl<'a> Context<'a> { let extensions = extensions.into_bump_slice(); - self.module - .insert_term(model::Term::ExtSet { extensions, rest }) + self.make_term(model::Term::ExtSet { extensions, rest }) } } diff --git a/hugr-core/src/import.rs b/hugr-core/src/import.rs index 6836eb1f5..0c7239592 100644 --- a/hugr-core/src/import.rs +++ b/hugr-core/src/import.rs @@ -26,6 +26,8 @@ use itertools::Either; use smol_str::{SmolStr, ToSmolStr}; use thiserror::Error; +type FxIndexMap = IndexMap; + /// Error during import. 
#[derive(Debug, Clone, Error)] pub enum ImportError { @@ -76,6 +78,7 @@ pub fn import_hugr( extensions, nodes: FxHashMap::default(), local_variables: IndexMap::default(), + custom_name_cache: FxHashMap::default(), }; ctx.import_root()?; @@ -111,31 +114,19 @@ struct Context<'a> { nodes: FxHashMap, /// The types of the local variables that are currently in scope. - local_variables: IndexMap<&'a str, model::TermId>, + local_variables: FxIndexMap<&'a str, model::TermId>, + + custom_name_cache: FxHashMap<&'a str, (ExtensionId, SmolStr)>, } impl<'a> Context<'a> { - /// Get the types of the given ports and assemble them into a `TypeRow`. - fn get_port_types(&mut self, ports: &[model::Port]) -> Result { - let types = ports - .iter() - .map(|port| match port.r#type { - Some(r#type) => self.import_type(r#type), - None => Err(error_uninferred!("port type")), - }) - .collect::, _>>()?; - - Ok(types.into()) - } - - /// Get the signature of the node with the given `NodeId`, using the type information - /// attached to the node's ports in the module. + /// Get the signature of the node with the given `NodeId`. fn get_node_signature(&mut self, node: model::NodeId) -> Result { - let node = self.get_node(node)?; - let inputs = self.get_port_types(node.inputs)?; - let outputs = self.get_port_types(node.outputs)?; - // This creates a signature with empty extension set. - Ok(Signature::new(inputs, outputs)) + let node_data = self.get_node(node)?; + let signature = node_data + .signature + .ok_or_else(|| error_uninferred!("node signature"))?; + self.import_func_type(signature) } /// Get the node with the given `NodeId`, or return an error if it does not exist. 
@@ -168,7 +159,7 @@ impl<'a> Context<'a> { local_ref: &model::LocalRef, ) -> Result<(usize, model::TermId), ImportError> { let term = match local_ref { - model::LocalRef::Index(index) => self + model::LocalRef::Index(_, index) => self .local_variables .get_index(*index as usize) .map(|(_, term)| (*index as usize, *term)), @@ -197,46 +188,16 @@ impl<'a> Context<'a> { } /// Associate links with the ports of the given node in the given direction. - fn record_links(&mut self, node: Node, direction: Direction, ports: &'a [model::Port]) { + fn record_links(&mut self, node: Node, direction: Direction, links: &'a [model::LinkRef<'a>]) { let optype = self.hugr.get_optype(node); - // NOTE: `OpType::port_count` copies the signature, which significantly slows down the import. - debug_assert!(ports.len() <= optype.port_count(direction)); + debug_assert!(links.len() <= optype.port_count(direction)); - for (model_port, port) in ports.iter().zip(self.hugr.node_ports(node, direction)) { - self.link_ports - .entry(model_port.link) - .or_default() - .push((node, port)); + for (link, port) in links.iter().zip(self.hugr.node_ports(node, direction)) { + self.link_ports.entry(*link).or_default().push((node, port)); } } - fn make_input_node( - &mut self, - parent: Node, - ports: &'a [model::Port], - ) -> Result { - let types = self.get_port_types(ports)?; - let node = self - .hugr - .add_node_with_parent(parent, OpType::Input(Input { types })); - self.record_links(node, Direction::Outgoing, ports); - Ok(node) - } - - fn make_output_node( - &mut self, - parent: Node, - ports: &'a [model::Port], - ) -> Result { - let types = self.get_port_types(ports)?; - let node = self - .hugr - .add_node_with_parent(parent, OpType::Output(Output { types })); - self.record_links(node, Direction::Incoming, ports); - Ok(node) - } - /// Link up the ports in the hugr graph, according to the connectivity information that /// has been gathered in the `link_ports` map. 
fn link_ports(&mut self) -> Result<(), ImportError> { @@ -377,6 +338,7 @@ impl<'a> Context<'a> { let node_data = self.get_node(node_id)?; match node_data.operation { + model::Operation::Invalid => Err(model::ModelError::InvalidOperation(node_id).into()), model::Operation::Dfg => { let signature = self.get_node_signature(node_id)?; let optype = OpType::DFG(DFG { signature }); @@ -569,7 +531,11 @@ impl<'a> Context<'a> { }), model::Operation::Tag { tag } => { - let (variants, _) = self.import_adt_and_rest(node_id, node_data.outputs)?; + let signature = node_data + .signature + .ok_or_else(|| error_uninferred!("node signature"))?; + let (_, outputs, _) = self.get_func_type(signature)?; + let (variants, _) = self.import_adt_and_rest(node_id, outputs)?; self.make_node( node_id, OpType::Tag(Tag { @@ -594,8 +560,29 @@ impl<'a> Context<'a> { return Err(model::ModelError::InvalidRegions(node_id).into()); } - self.make_input_node(node, region_data.sources)?; - self.make_output_node(node, region_data.targets)?; + let signature = self.import_func_type( + region_data + .signature + .ok_or_else(|| error_uninferred!("region signature"))?, + )?; + + // Create the input and output nodes + let input = self.hugr.add_node_with_parent( + node, + OpType::Input(Input { + types: signature.input, + }), + ); + let output = self.hugr.add_node_with_parent( + node, + OpType::Output(Output { + types: signature.output, + }), + ); + + // Make sure that the ports of the input/output nodes are connected correctly + self.record_links(input, Direction::Outgoing, region_data.sources); + self.record_links(output, Direction::Incoming, region_data.targets); for child in region_data.children { self.import_node(*child, node)?; @@ -607,25 +594,27 @@ impl<'a> Context<'a> { fn import_adt_and_rest( &mut self, node_id: model::NodeId, - ports: &'a [model::Port<'a>], + list: model::TermId, ) -> Result<(Vec, TypeRow), ImportError> { - let Some((first, rest)) = ports.split_first() else { + let items = 
self.import_closed_list(list)?; + + let Some((first, rest)) = items.split_first() else { return Err(model::ModelError::InvalidRegions(node_id).into()); }; let sum_rows: Vec<_> = { - let Some(term) = first.r#type else { - return Err(error_uninferred!("port type")); - }; - - let model::Term::Adt { variants } = self.get_term(term)? else { - return Err(model::ModelError::TypeError(term).into()); + let model::Term::Adt { variants } = self.get_term(*first)? else { + return Err(model::ModelError::TypeError(*first).into()); }; self.import_type_rows(*variants)? }; - let rest = self.get_port_types(rest)?; + let rest = rest + .iter() + .map(|term| self.import_type(*term)) + .collect::, _>>()? + .into(); Ok((sum_rows, rest)) } @@ -643,20 +632,22 @@ impl<'a> Context<'a> { }; let region_data = self.get_region(*region)?; - let (sum_rows, rest) = self.import_adt_and_rest(node_id, region_data.targets)?; + let (_, region_outputs, _) = self.get_func_type( + region_data + .signature + .ok_or_else(|| error_uninferred!("region signature"))?, + )?; + let (sum_rows, rest) = self.import_adt_and_rest(node_id, region_outputs)?; let (just_inputs, just_outputs) = { let mut sum_rows = sum_rows.into_iter(); - // NOTE: This can not fail since else `import_adt_and_rest` would have failed before. 
- let term = region_data.targets[0].r#type.unwrap(); - let Some(just_inputs) = sum_rows.next() else { - return Err(model::ModelError::TypeError(term).into()); + return Err(model::ModelError::TypeError(region_outputs).into()); }; let Some(just_outputs) = sum_rows.next() else { - return Err(model::ModelError::TypeError(term).into()); + return Err(model::ModelError::TypeError(region_outputs).into()); }; (just_inputs, just_outputs) @@ -682,9 +673,13 @@ impl<'a> Context<'a> { ) -> Result { let node_data = self.get_node(node_id)?; debug_assert_eq!(node_data.operation, model::Operation::Conditional); - - let (sum_rows, other_inputs) = self.import_adt_and_rest(node_id, node_data.inputs)?; - let outputs = self.get_port_types(node_data.outputs)?; + let (inputs, outputs, _) = self.get_func_type( + node_data + .signature + .ok_or_else(|| error_uninferred!("node signature"))?, + )?; + let (sum_rows, other_inputs) = self.import_adt_and_rest(node_id, inputs)?; + let outputs = self.import_type_row(outputs)?; let optype = OpType::Conditional(Conditional { sum_rows, @@ -697,10 +692,11 @@ impl<'a> Context<'a> { for region in node_data.regions { let region_data = self.get_region(*region)?; - - let source_types = self.get_port_types(region_data.sources)?; - let target_types = self.get_port_types(region_data.targets)?; - let signature = FuncTypeBase::new(source_types, target_types); + let signature = self.import_func_type( + region_data + .signature + .ok_or_else(|| error_uninferred!("region signature"))?, + )?; let case_node = self .hugr @@ -712,114 +708,6 @@ impl<'a> Context<'a> { Ok(node) } - /// Create the entry block for a control flow region. - /// - /// Since the core hugr does not have explicit entry blocks yet, we create a dataflow block - /// that simply forwards its inputs to its outputs. 
- fn make_entry_node( - &mut self, - parent: Node, - parent_id: model::NodeId, - ports: &'a [model::Port<'a>], - ) -> Result { - let types = { - let [port] = ports else { - return Err(model::ModelError::InvalidRegions(parent_id).into()); - }; - - let Some(port_type) = port.r#type else { - return Err(error_uninferred!("port type")); - }; - - let model::Term::Control { values: types } = self.get_term(port_type)? else { - return Err(model::ModelError::TypeError(port_type).into()); - }; - - self.import_type_row(*types)? - }; - - let node = self.hugr.add_node_with_parent( - parent, - OpType::DataflowBlock(DataflowBlock { - inputs: types.clone(), - other_outputs: TypeRow::default(), - sum_rows: vec![types.clone()], - extension_delta: ExtensionSet::default(), - }), - ); - - self.record_links(node, Direction::Outgoing, ports); - - let node_input = self.hugr.add_node_with_parent( - node, - OpType::Input(Input { - types: types.clone(), - }), - ); - - let node_output = self.hugr.add_node_with_parent( - node, - OpType::Output(Output { - types: vec![Type::new_sum([types.clone()])].into(), - }), - ); - - let node_tag = self.hugr.add_node_with_parent( - node, - OpType::Tag(Tag { - tag: 0, - variants: vec![types], - }), - ); - - // Connect the input node to the tag node - let input_outputs = self.hugr.node_outputs(node_input); - let tag_inputs = self.hugr.node_inputs(node_tag); - - for (a, b) in input_outputs.zip(tag_inputs) { - self.hugr.connect(node_input, a, node_tag, b); - } - - // Connect the tag node to the output node - let tag_outputs = self.hugr.node_outputs(node_tag); - let output_inputs = self.hugr.node_inputs(node_output); - - for (a, b) in tag_outputs.zip(output_inputs) { - self.hugr.connect(node_tag, a, node_output, b); - } - - Ok(node) - } - - fn make_exit_node( - &mut self, - parent: Node, - parent_id: model::NodeId, - ports: &'a [model::Port<'a>], - ) -> Result { - let cfg_outputs = { - let [port] = ports else { - return 
Err(model::ModelError::InvalidRegions(parent_id).into()); - }; - - let Some(port_type) = port.r#type else { - return Err(error_uninferred!("port type")); - }; - - let model::Term::Control { values: types } = self.get_term(port_type)? else { - return Err(model::ModelError::TypeError(port_type).into()); - }; - - self.import_type_row(*types)? - }; - - let node = self - .hugr - .add_node_with_parent(parent, OpType::ExitBlock(ExitBlock { cfg_outputs })); - self.record_links(node, Direction::Incoming, ports); - Ok(node) - } - fn import_cfg_region( &mut self, node_id: model::NodeId, @@ -832,13 +720,105 @@ impl<'a> Context<'a> { return Err(model::ModelError::InvalidRegions(node_id).into()); } - self.make_entry_node(node, node_id, region_data.sources)?; + let (region_source, region_targets, _) = self.get_func_type( + region_data + .signature + .ok_or_else(|| error_uninferred!("region signature"))?, + )?; + + let region_source_types = self.import_closed_list(region_source)?; + let region_target_types = self.import_closed_list(region_targets)?; + + // Create the entry node for the control flow region. + // Since the core hugr does not have explicit entry blocks yet, we create a dataflow block + // that simply forwards its inputs to its outputs. + { + let types = { + let [ctrl_type] = region_source_types.as_slice() else { + return Err(model::ModelError::TypeError(region_source).into()); + }; + + let model::Term::Control { values: types } = self.get_term(*ctrl_type)? else { + return Err(model::ModelError::TypeError(*ctrl_type).into()); + }; + + self.import_type_row(*types)? 
+ }; + + let entry = self.hugr.add_node_with_parent( + node, + OpType::DataflowBlock(DataflowBlock { + inputs: types.clone(), + other_outputs: TypeRow::default(), + sum_rows: vec![types.clone()], + extension_delta: ExtensionSet::default(), + }), + ); + + self.record_links(entry, Direction::Outgoing, region_data.sources); + + let node_input = self.hugr.add_node_with_parent( + entry, + OpType::Input(Input { + types: types.clone(), + }), + ); + + let node_output = self.hugr.add_node_with_parent( + entry, + OpType::Output(Output { + types: vec![Type::new_sum([types.clone()])].into(), + }), + ); + + let node_tag = self.hugr.add_node_with_parent( + entry, + OpType::Tag(Tag { + tag: 0, + variants: vec![types], + }), + ); + + // Connect the input node to the tag node + let input_outputs = self.hugr.node_outputs(node_input); + let tag_inputs = self.hugr.node_inputs(node_tag); + + for (a, b) in input_outputs.zip(tag_inputs) { + self.hugr.connect(node_input, a, node_tag, b); + } + + // Connect the tag node to the output node + let tag_outputs = self.hugr.node_outputs(node_tag); + let output_inputs = self.hugr.node_inputs(node_output); + + for (a, b) in tag_outputs.zip(output_inputs) { + self.hugr.connect(node_tag, a, node_output, b); + } + } for child in region_data.children { self.import_node(*child, node)?; } - self.make_exit_node(node, node_id, region_data.targets)?; + // Create the exit node for the control flow region. + { + let cfg_outputs = { + let [ctrl_type] = region_target_types.as_slice() else { + return Err(model::ModelError::TypeError(region_targets).into()); + }; + + let model::Term::Control { values: types } = self.get_term(*ctrl_type)? else { + return Err(model::ModelError::TypeError(*ctrl_type).into()); + }; + + self.import_type_row(*types)? 
+ }; + + let exit = self + .hugr + .add_node_with_parent(node, OpType::ExitBlock(ExitBlock { cfg_outputs })); + self.record_links(exit, Direction::Incoming, region_data.targets); + } Ok(()) } @@ -855,14 +835,20 @@ impl<'a> Context<'a> { return Err(model::ModelError::InvalidRegions(node_id).into()); }; let region_data = self.get_region(*region)?; - let inputs = self.get_port_types(region_data.sources)?; - let (sum_rows, other_outputs) = self.import_adt_and_rest(node_id, region_data.targets)?; + let (inputs, outputs, extensions) = self.get_func_type( + region_data + .signature + .ok_or_else(|| error_uninferred!("region signature"))?, + )?; + let inputs = self.import_type_row(inputs)?; + let (sum_rows, other_outputs) = self.import_adt_and_rest(node_id, outputs)?; + let extension_delta = self.import_extension_set(extensions)?; let optype = OpType::DataflowBlock(DataflowBlock { inputs, other_outputs, sum_rows, - extension_delta: ExtensionSet::new(), + extension_delta, }); let node = self.make_node(node_id, optype, parent)?; @@ -1099,24 +1085,28 @@ impl<'a> Context<'a> { } } + fn get_func_type( + &mut self, + term_id: model::TermId, + ) -> Result<(model::TermId, model::TermId, model::TermId), ImportError> { + match self.get_term(term_id)? 
{ + model::Term::FuncType { + inputs, + outputs, + extensions, + } => Ok((*inputs, *outputs, *extensions)), + _ => Err(model::ModelError::TypeError(term_id).into()), + } + } + fn import_func_type( &mut self, term_id: model::TermId, ) -> Result, ImportError> { - let term = self.get_term(term_id)?; - - let model::Term::FuncType { - inputs, - outputs, - extensions, - } = term - else { - return Err(model::ModelError::TypeError(term_id).into()); - }; - - let inputs = self.import_type_row::(*inputs)?; - let outputs = self.import_type_row::(*outputs)?; - let extensions = self.import_extension_set(*extensions)?; + let (inputs, outputs, extensions) = self.get_func_type(term_id)?; + let inputs = self.import_type_row::(inputs)?; + let outputs = self.import_type_row::(outputs)?; + let extensions = self.import_extension_set(extensions)?; Ok(FuncTypeBase::new(inputs, outputs).with_extension_delta(extensions)) } @@ -1124,6 +1114,9 @@ impl<'a> Context<'a> { &mut self, mut term_id: model::TermId, ) -> Result, ImportError> { + // PERFORMANCE: We currently allocate a Vec here to collect list items + // into, in order to handle the case where the tail of the list is another + // list. We should avoid this. let mut list_items = Vec::new(); loop { @@ -1146,6 +1139,16 @@ impl<'a> Context<'a> { Ok(list_items) } + fn import_type_rows( + &mut self, + term_id: model::TermId, + ) -> Result>, ImportError> { + self.import_closed_list(term_id)? + .iter() + .map(|row| self.import_type_row::(*row)) + .collect() + } + fn import_type_row( &mut self, term_id: model::TermId, @@ -1159,27 +1162,25 @@ impl<'a> Context<'a> { Ok(items.into()) } - fn import_type_rows( + fn import_custom_name( &mut self, - term_id: model::TermId, - ) -> Result>, ImportError> { - let items = self - .import_closed_list(term_id)? 
- .iter() - .map(|item| self.import_type_row(*item)) - .collect::, _>>()?; - Ok(items) - } - - fn import_custom_name(&self, symbol: &'a str) -> Result<(ExtensionId, SmolStr), ImportError> { - let qualified_name = ExtensionId::new(symbol) - .map_err(|_| model::ModelError::MalformedName(symbol.to_smolstr()))?; - - let (extension, id) = qualified_name - .split_last() - .ok_or_else(|| model::ModelError::MalformedName(symbol.to_smolstr()))?; - - Ok((extension, id)) + symbol: &'a str, + ) -> Result<(ExtensionId, SmolStr), ImportError> { + use std::collections::hash_map::Entry; + match self.custom_name_cache.entry(symbol) { + Entry::Occupied(occupied_entry) => Ok(occupied_entry.get().clone()), + Entry::Vacant(vacant_entry) => { + let qualified_name = ExtensionId::new(symbol) + .map_err(|_| model::ModelError::MalformedName(symbol.to_smolstr()))?; + + let (extension, id) = qualified_name + .split_last() + .ok_or_else(|| model::ModelError::MalformedName(symbol.to_smolstr()))?; + + vacant_entry.insert((extension.clone(), id.clone())); + Ok((extension, id)) + } + } } } diff --git a/hugr-core/tests/fixtures/model-add.edn b/hugr-core/tests/fixtures/model-add.edn index 4749dc0c7..f7783cb41 100644 --- a/hugr-core/tests/fixtures/model-add.edn +++ b/hugr-core/tests/fixtures/model-add.edn @@ -5,6 +5,9 @@ [(@ arithmetic.int.types.int)] (ext) (dfg - [(%0 (@ arithmetic.int.types.int)) (%1 (@ arithmetic.int.types.int))] - [(%2 (@ arithmetic.int.types.int))] - ((@ arithmetic.int.iadd) [(%0 (@ arithmetic.int.types.int)) (%1 (@ arithmetic.int.types.int))] [(%2 (@ arithmetic.int.types.int))]))) + [%0 %1] + [%2] + (signature (fn [(@ arithmetic.int.types.int) (@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext))) + ((@ arithmetic.int.iadd) [%0 %1] [%2] + (signature (fn [(@ arithmetic.int.types.int) (@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext))) + ))) diff --git a/hugr-core/tests/fixtures/model-call.edn b/hugr-core/tests/fixtures/model-call.edn index 
5918543bf..d463e391a 100644 --- a/hugr-core/tests/fixtures/model-call.edn +++ b/hugr-core/tests/fixtures/model-call.edn @@ -10,14 +10,15 @@ [(@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext arithmetic.int) (meta doc.title "Caller") (meta doc.description "This defines a function that calls the function which we declared earlier.") - (dfg - [(%3 (@ arithmetic.int.types.int))] - [(%4 (@ arithmetic.int.types.int))] - (call (@ example.callee (ext)) [(%3 (@ arithmetic.int.types.int))] [(%4 (@ arithmetic.int.types.int))]))) + (dfg [%3] [%4] + (signature (fn [(@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext))) + (call (@ example.callee (ext)) [%3] [%4] + (signature (fn [(@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext)))))) (define-func example.load - [] [(fn [(@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext))] (ext) + [] [(fn [(@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext arithmetic.int))] (ext) (dfg [] - [(%5 (fn [(@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext)))] - (load-func (@ example.caller) [] [(%5 (fn [(@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext arithmetic.int)))]))) + [%5] + (signature (fn [] [(fn [(@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext arithmetic.int))] (ext))) + (load-func (@ example.caller) [] [%5]))) diff --git a/hugr-core/tests/fixtures/model-cfg.edn b/hugr-core/tests/fixtures/model-cfg.edn index 54e895a6b..92ae19441 100644 --- a/hugr-core/tests/fixtures/model-cfg.edn +++ b/hugr-core/tests/fixtures/model-cfg.edn @@ -3,9 +3,15 @@ (define-func example.cfg (forall ?a type) [?a] [?a] (ext) - (dfg [(%0 ?a)] [(%1 ?a)] - (cfg [(%0 ?a)] [(%1 ?a)] - (cfg [(%2 (ctrl [?a]))] [(%4 (ctrl [?a]))] - (block [(%2 (ctrl [?a]))] [(%4 (ctrl [?a]))] - (dfg [(%5 ?a)] [(%6 (adt [[?a]]))] - (tag 0 [(%5 ?a)] [(%6 (adt [[?a]]))]))))))) + (dfg [%0] [%1] + (signature (fn [?a] [?a] (ext))) + (cfg [%0] [%1] + (signature (fn [?a] [?a] 
(ext))) + (cfg [%2] [%4] + (signature (fn [(ctrl [?a])] [(ctrl [?a])] (ext))) + (block [%2] [%4] + (signature (fn [(ctrl [?a])] [(ctrl [?a])] (ext))) + (dfg [%5] [%6] + (signature (fn [?a] [(adt [[?a]])] (ext))) + (tag 0 [%5] [%6] + (signature (fn [?a] [(adt [[?a]])] (ext)))))))))) diff --git a/hugr-core/tests/fixtures/model-cond.edn b/hugr-core/tests/fixtures/model-cond.edn index 04304d108..aa1ecef7d 100644 --- a/hugr-core/tests/fixtures/model-cond.edn +++ b/hugr-core/tests/fixtures/model-cond.edn @@ -1,21 +1,15 @@ (hugr 0) - (define-func example.cond [(adt [[] []]) (@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext) - (dfg - [(%0 (adt [[] []])) (%1 (@ arithmetic.int.types.int))] - [(%2 (@ arithmetic.int.types.int))] - (cond - [(%0 (adt [[] []])) (%1 (@ arithmetic.int.types.int))] - [(%2 (@ arithmetic.int.types.int))] - (dfg - [(%3 (@ arithmetic.int.types.int))] - [(%3 (@ arithmetic.int.types.int))]) - (dfg - [(%4 (@ arithmetic.int.types.int))] - [(%5 (@ arithmetic.int.types.int))] - ((@ arithmetic.int.ineg) - [(%4 (@ arithmetic.int.types.int))] - [(%5 (@ arithmetic.int.types.int))]))))) + (dfg [%0 %1] [%2] + (signature (fn [(adt [[] []]) (@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext))) + (cond [%0 %1] [%2] + (signature (fn [(adt [[] []]) (@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext))) + (dfg [%3] [%3] + (signature (fn [(@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext)))) + (dfg [%4] [%5] + (signature (fn [(@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext))) + ((@ arithmetic.int.ineg) [%4] [%5] + (signature (fn [(@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext)))))))) diff --git a/hugr-core/tests/fixtures/model-loop.edn b/hugr-core/tests/fixtures/model-loop.edn index d35e60578..5df4b2a87 100644 --- a/hugr-core/tests/fixtures/model-loop.edn +++ b/hugr-core/tests/fixtures/model-loop.edn @@ -3,7 +3,11 @@ (define-func example.loop (forall ?a type) [?a] [?a] (ext) 
- (dfg [(%0 ?a)] [(%1 ?a)] - (tail-loop [(%0 ?a)] [(%1 ?a)] - (dfg [(%2 ?a)] [(%3 (adt [[?a] [?a]]))] - (tag 0 [(%2 ?a)] [(%3 (adt [[?a] [?a]]))]))))) + (dfg [%0] [%1] + (signature (fn [?a] [?a] (ext))) + (tail-loop [%0] [%1] + (signature (fn [?a] [?a] (ext))) + (dfg [%2] [%3] + (signature (fn [?a] [(adt [[?a] [?a]])] (ext))) + (tag 0 [%2] [%3] + (signature (fn [?a] [(adt [[?a] [?a]])] (ext)))))))) diff --git a/hugr-core/tests/fixtures/model-params.edn b/hugr-core/tests/fixtures/model-params.edn index c89dd158f..171860cae 100644 --- a/hugr-core/tests/fixtures/model-params.edn +++ b/hugr-core/tests/fixtures/model-params.edn @@ -5,4 +5,5 @@ (forall ?a type) (forall ?b type) [?a ?b] [?b ?a] (ext) - (dfg [(%a ?a) (%b ?b)] [(%b ?b) (%a ?a)])) + (dfg [%a %b] [%b %a] + (signature (fn [?a ?b] [?b ?a] (ext))))) diff --git a/hugr-core/tests/snapshots/model__roundtrip_add.snap b/hugr-core/tests/snapshots/model__roundtrip_add.snap index 262891119..b7de139fe 100644 --- a/hugr-core/tests/snapshots/model__roundtrip_add.snap +++ b/hugr-core/tests/snapshots/model__roundtrip_add.snap @@ -9,8 +9,15 @@ expression: "roundtrip(include_str!(\"fixtures/model-add.edn\"))" [(@ arithmetic.int.types.int)] (ext) (dfg - [(%0 (@ arithmetic.int.types.int)) (%1 (@ arithmetic.int.types.int))] - [(%2 (@ arithmetic.int.types.int))] - (arithmetic.int.iadd - [(%0 (@ arithmetic.int.types.int)) (%1 (@ arithmetic.int.types.int))] - [(%2 (@ arithmetic.int.types.int))]))) + [%0 %1] [%2] + (signature + (fn + [(@ arithmetic.int.types.int) (@ arithmetic.int.types.int)] + [(@ arithmetic.int.types.int)] + (ext))) + ((@ arithmetic.int.iadd) [%0 %1] [%2] + (signature + (fn + [(@ arithmetic.int.types.int) (@ arithmetic.int.types.int)] + [(@ arithmetic.int.types.int)] + (ext arithmetic.int)))))) diff --git a/hugr-core/tests/snapshots/model__roundtrip_call.snap b/hugr-core/tests/snapshots/model__roundtrip_call.snap index 40a8ac10b..a799a0944 100644 --- a/hugr-core/tests/snapshots/model__roundtrip_call.snap +++ 
b/hugr-core/tests/snapshots/model__roundtrip_call.snap @@ -15,29 +15,41 @@ expression: "roundtrip(include_str!(\"fixtures/model-call.edn\"))" [(@ arithmetic.int.types.int)] (ext arithmetic.int) (dfg - [(%0 (@ arithmetic.int.types.int))] - [(%1 (@ arithmetic.int.types.int))] - (call - (@ example.callee (ext)) - [(%0 (@ arithmetic.int.types.int))] - [(%1 (@ arithmetic.int.types.int))]))) + [%0] [%1] + (signature + (fn + [(@ arithmetic.int.types.int)] + [(@ arithmetic.int.types.int)] + (ext arithmetic.int))) + (call (@ example.callee (ext)) [%0] [%1] + (signature + (fn + [(@ arithmetic.int.types.int)] + [(@ arithmetic.int.types.int)] + (ext arithmetic.int)))))) (define-func example.load [] - [(fn [(@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext))] + [(fn + [(@ arithmetic.int.types.int)] + [(@ arithmetic.int.types.int)] + (ext arithmetic.int))] (ext) (dfg - [] - [(%2 - (fn - [(@ arithmetic.int.types.int)] - [(@ arithmetic.int.types.int)] - (ext)))] - (load-func - (@ example.caller) - [] - [(%2 - (fn + (signature + (fn + [] + [(fn [(@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] - (ext arithmetic.int)))]))) + (ext arithmetic.int))] + (ext))) + (load-func (@ example.caller) + (signature + (fn + [] + [(fn + [(@ arithmetic.int.types.int)] + [(@ arithmetic.int.types.int)] + (ext arithmetic.int))] + (ext)))))) diff --git a/hugr-core/tests/snapshots/model__roundtrip_cfg.snap b/hugr-core/tests/snapshots/model__roundtrip_cfg.snap index cbcbe3643..f3c0f0acc 100644 --- a/hugr-core/tests/snapshots/model__roundtrip_cfg.snap +++ b/hugr-core/tests/snapshots/model__roundtrip_cfg.snap @@ -8,19 +8,22 @@ expression: "roundtrip(include_str!(\"fixtures/model-cfg.edn\"))" (forall ?0 type) [?0] [?0] (ext) (dfg - [(%0 ?0)] - [(%1 ?0)] - (cfg [(%0 ?0)] [(%1 ?0)] + [%0] [%1] + (signature (fn [?0] [?0] (ext))) + (cfg [%0] [%1] + (signature (fn [?0] [?0] (ext))) (cfg - [(%2 (ctrl [?0]))] - [(%6 (ctrl [?0]))] - (block [(%2 (ctrl [?0]))] [(%3 (ctrl [?0]))] + [%2] [%8] 
+ (signature (fn [?0] [?0] (ext))) + (block [%2] [%5] + (signature (fn [(ctrl [?0])] [(ctrl [?0 . []])] (ext))) (dfg - [(%4 ?0)] - [(%5 (adt [[?0]]))] - (tag 0 [(%4 ?0)] [(%5 (adt [[?0]]))]))) - (block [(%3 (ctrl [?0]))] [(%6 (ctrl [?0]))] + [%3] [%4] + (signature (fn [?0] [(adt [[?0]])] (ext))) + (tag 0 [%3] [%4] (signature (fn [?0] [(adt [[?0]])] (ext)))))) + (block [%5] [%8] + (signature (fn [(ctrl [?0])] [(ctrl [?0 . []])] (ext))) (dfg - [(%7 ?0)] - [(%8 (adt [[?0]]))] - (tag 0 [(%7 ?0)] [(%8 (adt [[?0]]))]))))))) + [%6] [%7] + (signature (fn [?0] [(adt [[?0]])] (ext))) + (tag 0 [%6] [%7] (signature (fn [?0] [(adt [[?0]])] (ext)))))))))) diff --git a/hugr-core/tests/snapshots/model__roundtrip_cond.snap b/hugr-core/tests/snapshots/model__roundtrip_cond.snap index 0fdbc2f91..fe55e965f 100644 --- a/hugr-core/tests/snapshots/model__roundtrip_cond.snap +++ b/hugr-core/tests/snapshots/model__roundtrip_cond.snap @@ -9,17 +9,36 @@ expression: "roundtrip(include_str!(\"fixtures/model-cond.edn\"))" [(@ arithmetic.int.types.int)] (ext) (dfg - [(%0 (adt [[] []])) (%1 (@ arithmetic.int.types.int))] - [(%2 (@ arithmetic.int.types.int))] + [%0 %1] [%2] + (signature + (fn + [(adt [[] []]) (@ arithmetic.int.types.int)] + [(@ arithmetic.int.types.int)] + (ext))) (cond - [(%0 (adt [[] []])) (%1 (@ arithmetic.int.types.int))] - [(%2 (@ arithmetic.int.types.int))] + [%0 %1] [%2] + (signature + (fn + [(adt [[] []]) (@ arithmetic.int.types.int)] + [(@ arithmetic.int.types.int)] + (ext))) (dfg - [(%3 (@ arithmetic.int.types.int))] - [(%3 (@ arithmetic.int.types.int))]) + [%3] [%3] + (signature + (fn + [(@ arithmetic.int.types.int)] + [(@ arithmetic.int.types.int)] + (ext)))) (dfg - [(%4 (@ arithmetic.int.types.int))] - [(%5 (@ arithmetic.int.types.int))] - (arithmetic.int.ineg - [(%4 (@ arithmetic.int.types.int))] - [(%5 (@ arithmetic.int.types.int))]))))) + [%4] [%5] + (signature + (fn + [(@ arithmetic.int.types.int)] + [(@ arithmetic.int.types.int)] + (ext))) + ((@ 
arithmetic.int.ineg) [%4] [%5] + (signature + (fn + [(@ arithmetic.int.types.int)] + [(@ arithmetic.int.types.int)] + (ext arithmetic.int)))))))) diff --git a/hugr-core/tests/snapshots/model__roundtrip_loop.snap b/hugr-core/tests/snapshots/model__roundtrip_loop.snap index eb1debbfd..a513318ae 100644 --- a/hugr-core/tests/snapshots/model__roundtrip_loop.snap +++ b/hugr-core/tests/snapshots/model__roundtrip_loop.snap @@ -8,12 +8,12 @@ expression: "roundtrip(include_str!(\"fixtures/model-loop.edn\"))" (forall ?0 type) [?0] [?0] (ext) (dfg - [(%0 ?0)] - [(%1 ?0)] + [%0] [%1] + (signature (fn [?0] [?0] (ext))) (tail-loop - [(%0 ?0)] - [(%1 ?0)] + [%0] [%1] + (signature (fn [?0] [?0] (ext))) (dfg - [(%2 ?0)] - [(%3 (adt [[?0] [?0]]))] - (tag 0 [(%2 ?0)] [(%3 (adt [[?0] [?0]]))]))))) + [%2] [%3] + (signature (fn [?0] [(adt [[?0] [?0]])] (ext))) + (tag 0 [%2] [%3] (signature (fn [?0] [(adt [[?0] [?0]])] (ext)))))))) diff --git a/hugr-core/tests/snapshots/model__roundtrip_params.snap b/hugr-core/tests/snapshots/model__roundtrip_params.snap index b146bc706..ab2b98d8c 100644 --- a/hugr-core/tests/snapshots/model__roundtrip_params.snap +++ b/hugr-core/tests/snapshots/model__roundtrip_params.snap @@ -8,4 +8,4 @@ expression: "roundtrip(include_str!(\"fixtures/model-params.edn\"))" (forall ?0 type) (forall ?1 type) [?0 ?1] [?1 ?0] (ext) - (dfg [(%0 ?0) (%1 ?1)] [(%1 ?1) (%0 ?0)])) + (dfg [%0 %1] [%1 %0] (signature (fn [?0 ?1] [?1 ?0] (ext))))) diff --git a/hugr-model/Cargo.toml b/hugr-model/Cargo.toml index 7f0396a24..afd67cdde 100644 --- a/hugr-model/Cargo.toml +++ b/hugr-model/Cargo.toml @@ -14,6 +14,7 @@ license.workspace = true [dependencies] bumpalo = { workspace = true, features = ["collections"] } +capnp = "0.20.1" fxhash.workspace = true indexmap.workspace = true pest = "2.7.12" @@ -24,3 +25,9 @@ thiserror.workspace = true [lints] workspace = true + +[build-dependencies] +capnpc = "0.20.0" + +[dev-dependencies] +pretty_assertions = "1.4.1" diff --git a/hugr-model/build.rs 
b/hugr-model/build.rs new file mode 100644 index 000000000..d4eec4afd --- /dev/null +++ b/hugr-model/build.rs @@ -0,0 +1,12 @@ +//! Build scripts for `hugr-model`. + +/// Build the capnp schema files. +fn main() { + capnpc::CompilerCommand::new() + .src_prefix("capnp") + .file("capnp/hugr-v0.capnp") + .run() + .expect("compiling schema"); + + println!("cargo:rerun-if-changed=capnp/hugr-v0.capnp"); +} diff --git a/hugr-model/capnp/hugr-v0.capnp b/hugr-model/capnp/hugr-v0.capnp new file mode 100644 index 000000000..296468e8e --- /dev/null +++ b/hugr-model/capnp/hugr-v0.capnp @@ -0,0 +1,190 @@ +@0xe02b32c528509601; + +# The id of a `Term`. +using TermId = UInt32; + +# Either `0` or the id of a `Term` incremented by one. +using OptionalTermId = UInt32; + +# The id of a `Region`. +using RegionId = UInt32; + +# The id of a `Node`. +using NodeId = UInt32; + +# The id of a `Link`. +using LinkId = UInt32; + +struct Module { + root @0 :RegionId; + nodes @1 :List(Node); + regions @2 :List(Region); + terms @3 :List(Term); +} + +struct Node { + operation @0 :Operation; + inputs @1 :List(LinkRef); + outputs @2 :List(LinkRef); + params @3 :List(TermId); + regions @4 :List(RegionId); + meta @5 :List(MetaItem); + signature @6 :OptionalTermId; +} + +struct Operation { + union { + invalid @0 :Void; + dfg @1 :Void; + cfg @2 :Void; + block @3 :Void; + funcDefn @4 :FuncDecl; + funcDecl @5 :FuncDecl; + aliasDefn @6 :AliasDefn; + aliasDecl @7 :AliasDecl; + custom @8 :GlobalRef; + customFull @9 :GlobalRef; + tag @10 :UInt16; + tailLoop @11 :Void; + conditional @12 :Void; + callFunc @13 :TermId; + loadFunc @14 :TermId; + } + + struct FuncDefn { + name @0 :Text; + params @1 :List(Param); + signature @2 :TermId; + } + + struct FuncDecl { + name @0 :Text; + params @1 :List(Param); + signature @2 :TermId; + } + + struct AliasDefn { + name @0 :Text; + params @1 :List(Param); + type @2 :TermId; + value @3 :TermId; + } + + struct AliasDecl { + name @0 :Text; + params @1 :List(Param); + type @2 
:TermId; + } +} + +struct Region { + kind @0 :RegionKind; + sources @1 :List(LinkRef); + targets @2 :List(LinkRef); + children @3 :List(NodeId); + meta @4 :List(MetaItem); + signature @5 :OptionalTermId; +} + +enum RegionKind { + dataFlow @0; + controlFlow @1; +} + +struct MetaItem { + name @0 :Text; + value @1 :UInt32; +} + +struct LinkRef { + union { + id @0 :LinkId; + named @1 :Text; + } +} + +struct GlobalRef { + union { + node @0 :NodeId; + named @1 :Text; + } +} + +struct LocalRef { + union { + direct :group { + index @0 :UInt16; + node @1 :NodeId; + } + named @2 :Text; + } +} + +struct Term { + union { + wildcard @0 :Void; + runtimeType @1 :Void; + staticType @2 :Void; + constraint @3 :Void; + variable @4 :LocalRef; + apply @5 :Apply; + applyFull @6 :ApplyFull; + quote @7 :TermId; + list @8 :ListTerm; + listType @9 :TermId; + string @10 :Text; + stringType @11 :Void; + nat @12 :UInt64; + natType @13 :Void; + extSet @14 :ExtSet; + extSetType @15 :Void; + adt @16 :TermId; + funcType @17 :FuncType; + control @18 :TermId; + controlType @19 :Void; + } + + struct Apply { + global @0 :GlobalRef; + args @1 :List(TermId); + } + + struct ApplyFull { + global @0 :GlobalRef; + args @1 :List(TermId); + } + + struct ListTerm { + items @0 :List(TermId); + tail @1 :OptionalTermId; + } + + struct ExtSet { + extensions @0 :List(Text); + rest @1 :OptionalTermId; + } + + struct FuncType { + inputs @0 :TermId; + outputs @1 :TermId; + extensions @2 :TermId; + } +} + +struct Param { + union { + implicit @0 :Implicit; + explicit @1 :Explicit; + constraint @2 :TermId; + } + + struct Implicit { + name @0 :Text; + type @1 :TermId; + } + + struct Explicit { + name @0 :Text; + type @1 :TermId; + } +} diff --git a/hugr-model/src/lib.rs b/hugr-model/src/lib.rs index 6c161ca37..c0bf1536f 100644 --- a/hugr-model/src/lib.rs +++ b/hugr-model/src/lib.rs @@ -3,3 +3,7 @@ //! all its associated information in a form that can be stored on disk. The data structures //! 
are not designed for efficient traversal or modification, but for simplicity and serialization. pub mod v0; + +pub(crate) mod hugr_v0_capnp { + include!(concat!(env!("OUT_DIR"), "/hugr_v0_capnp.rs")); +} diff --git a/hugr-model/src/v0/binary/mod.rs b/hugr-model/src/v0/binary/mod.rs new file mode 100644 index 000000000..816c64fb2 --- /dev/null +++ b/hugr-model/src/v0/binary/mod.rs @@ -0,0 +1,6 @@ +//! The HUGR binary representation. +mod read; +mod write; + +pub use read::read_from_slice; +pub use write::write_to_vec; diff --git a/hugr-model/src/v0/binary/read.rs b/hugr-model/src/v0/binary/read.rs new file mode 100644 index 000000000..082d7b3db --- /dev/null +++ b/hugr-model/src/v0/binary/read.rs @@ -0,0 +1,345 @@ +use crate::hugr_v0_capnp as hugr_capnp; +use crate::v0 as model; +use bumpalo::collections::Vec as BumpVec; +use bumpalo::Bump; + +type ReadResult = Result; + +/// Read a hugr module from a byte slice. +pub fn read_from_slice<'a>(slice: &[u8], bump: &'a Bump) -> ReadResult> { + let reader = + capnp::serialize_packed::read_message(slice, capnp::message::ReaderOptions::new())?; + let root = reader.get_root::()?; + read_module(bump, root) +} + +/// Read a list of structs from a reader into a slice allocated through the bump allocator. +macro_rules! read_list { + ($bump:expr, $reader:expr, $get:ident, $read:expr) => {{ + let mut __list_reader = $reader.$get()?; + let mut __list = BumpVec::with_capacity_in(__list_reader.len() as _, $bump); + for __item_reader in __list_reader.iter() { + __list.push($read($bump, __item_reader)?); + } + __list.into_bump_slice() + }}; +} + +/// Read a list of scalars from a reader into a slice allocated through the bump allocator. +macro_rules! 
read_scalar_list { + ($bump:expr, $reader:expr, $get:ident, $wrap:path) => {{ + let mut __list_reader = $reader.$get()?; + let mut __list = BumpVec::with_capacity_in(__list_reader.len() as _, $bump); + for __item in __list_reader.iter() { + __list.push($wrap(__item)); + } + __list.into_bump_slice() + }}; +} + +fn read_module<'a>( + bump: &'a Bump, + reader: hugr_capnp::module::Reader, +) -> ReadResult> { + let root = model::RegionId(reader.get_root()); + + let nodes = reader + .get_nodes()? + .iter() + .map(|r| read_node(bump, r)) + .collect::>()?; + + let regions = reader + .get_regions()? + .iter() + .map(|r| read_region(bump, r)) + .collect::>()?; + + let terms = reader + .get_terms()? + .iter() + .map(|r| read_term(bump, r)) + .collect::>()?; + + Ok(model::Module { + root, + nodes, + regions, + terms, + }) +} + +fn read_node<'a>(bump: &'a Bump, reader: hugr_capnp::node::Reader) -> ReadResult> { + let operation = read_operation(bump, reader.get_operation()?)?; + let inputs = read_list!(bump, reader, get_inputs, read_link_ref); + let outputs = read_list!(bump, reader, get_outputs, read_link_ref); + let params = read_scalar_list!(bump, reader, get_params, model::TermId); + let regions = read_scalar_list!(bump, reader, get_regions, model::RegionId); + let meta = read_list!(bump, reader, get_meta, read_meta_item); + let signature = reader.get_signature().checked_sub(1).map(model::TermId); + + Ok(model::Node { + operation, + inputs, + outputs, + params, + regions, + meta, + signature, + }) +} + +fn read_local_ref<'a>( + bump: &'a Bump, + reader: hugr_capnp::local_ref::Reader, +) -> ReadResult> { + use hugr_capnp::local_ref::Which; + Ok(match reader.which()? 
{ + Which::Direct(reader) => { + let index = reader.get_index(); + let node = model::NodeId(reader.get_node()); + model::LocalRef::Index(node, index) + } + Which::Named(name) => model::LocalRef::Named(bump.alloc_str(name?.to_str()?)), + }) +} + +fn read_global_ref<'a>( + bump: &'a Bump, + reader: hugr_capnp::global_ref::Reader, +) -> ReadResult> { + use hugr_capnp::global_ref::Which; + Ok(match reader.which()? { + Which::Node(node) => model::GlobalRef::Direct(model::NodeId(node)), + Which::Named(name) => model::GlobalRef::Named(bump.alloc_str(name?.to_str()?)), + }) +} + +fn read_link_ref<'a>( + bump: &'a Bump, + reader: hugr_capnp::link_ref::Reader, +) -> ReadResult> { + use hugr_capnp::link_ref::Which; + Ok(match reader.which()? { + Which::Id(id) => model::LinkRef::Id(model::LinkId(id)), + Which::Named(name) => model::LinkRef::Named(bump.alloc_str(name?.to_str()?)), + }) +} + +fn read_operation<'a>( + bump: &'a Bump, + reader: hugr_capnp::operation::Reader, +) -> ReadResult> { + use hugr_capnp::operation::Which; + Ok(match reader.which()? 
{ + Which::Invalid(()) => model::Operation::Invalid, + Which::Dfg(()) => model::Operation::Dfg, + Which::Cfg(()) => model::Operation::Cfg, + Which::Block(()) => model::Operation::Block, + Which::FuncDefn(reader) => { + let reader = reader?; + let name = bump.alloc_str(reader.get_name()?.to_str()?); + let params = read_list!(bump, reader, get_params, read_param); + let signature = model::TermId(reader.get_signature()); + let decl = bump.alloc(model::FuncDecl { + name, + params, + signature, + }); + model::Operation::DefineFunc { decl } + } + Which::FuncDecl(reader) => { + let reader = reader?; + let name = bump.alloc_str(reader.get_name()?.to_str()?); + let params = read_list!(bump, reader, get_params, read_param); + let signature = model::TermId(reader.get_signature()); + let decl = bump.alloc(model::FuncDecl { + name, + params, + signature, + }); + model::Operation::DeclareFunc { decl } + } + Which::AliasDefn(reader) => { + let reader = reader?; + let name = bump.alloc_str(reader.get_name()?.to_str()?); + let params = read_list!(bump, reader, get_params, read_param); + let r#type = model::TermId(reader.get_type()); + let value = model::TermId(reader.get_value()); + let decl = bump.alloc(model::AliasDecl { + name, + params, + r#type, + }); + model::Operation::DefineAlias { decl, value } + } + Which::AliasDecl(reader) => { + let reader = reader?; + let name = bump.alloc_str(reader.get_name()?.to_str()?); + let params = read_list!(bump, reader, get_params, read_param); + let r#type = model::TermId(reader.get_type()); + let decl = bump.alloc(model::AliasDecl { + name, + params, + r#type, + }); + model::Operation::DeclareAlias { decl } + } + Which::Custom(name) => model::Operation::Custom { + operation: read_global_ref(bump, name?)?, + }, + Which::CustomFull(name) => model::Operation::CustomFull { + operation: read_global_ref(bump, name?)?, + }, + Which::Tag(tag) => model::Operation::Tag { tag }, + Which::TailLoop(()) => model::Operation::TailLoop, + 
Which::Conditional(()) => model::Operation::Conditional, + Which::CallFunc(func) => model::Operation::CallFunc { + func: model::TermId(func), + }, + Which::LoadFunc(func) => model::Operation::LoadFunc { + func: model::TermId(func), + }, + }) +} + +fn read_region<'a>( + bump: &'a Bump, + reader: hugr_capnp::region::Reader, +) -> ReadResult> { + let kind = match reader.get_kind()? { + hugr_capnp::RegionKind::DataFlow => model::RegionKind::DataFlow, + hugr_capnp::RegionKind::ControlFlow => model::RegionKind::ControlFlow, + }; + + let sources = read_list!(bump, reader, get_sources, read_link_ref); + let targets = read_list!(bump, reader, get_targets, read_link_ref); + let children = read_scalar_list!(bump, reader, get_children, model::NodeId); + let meta = read_list!(bump, reader, get_meta, read_meta_item); + let signature = reader.get_signature().checked_sub(1).map(model::TermId); + + Ok(model::Region { + kind, + sources, + targets, + children, + meta, + signature, + }) +} + +fn read_term<'a>(bump: &'a Bump, reader: hugr_capnp::term::Reader) -> ReadResult> { + use hugr_capnp::term::Which; + Ok(match reader.which()? 
{ + Which::Wildcard(()) => model::Term::Wildcard, + Which::RuntimeType(()) => model::Term::Type, + Which::StaticType(()) => model::Term::StaticType, + Which::Constraint(()) => model::Term::Constraint, + Which::String(value) => model::Term::Str(bump.alloc_str(value?.to_str()?)), + Which::StringType(()) => model::Term::StrType, + Which::Nat(value) => model::Term::Nat(value), + Which::NatType(()) => model::Term::NatType, + Which::ExtSetType(()) => model::Term::ExtSetType, + Which::ControlType(()) => model::Term::ControlType, + Which::Variable(local_ref) => model::Term::Var(read_local_ref(bump, local_ref?)?), + + Which::Apply(reader) => { + let reader = reader?; + let global = read_global_ref(bump, reader.get_global()?)?; + let args = read_scalar_list!(bump, reader, get_args, model::TermId); + model::Term::Apply { global, args } + } + + Which::ApplyFull(reader) => { + let reader = reader?; + let global = read_global_ref(bump, reader.get_global()?)?; + let args = read_scalar_list!(bump, reader, get_args, model::TermId); + model::Term::ApplyFull { global, args } + } + + Which::Quote(r#type) => model::Term::Quote { + r#type: model::TermId(r#type), + }, + + Which::List(reader) => { + let reader = reader?; + let items = read_scalar_list!(bump, reader, get_items, model::TermId); + let tail = reader.get_tail().checked_sub(1).map(model::TermId); + model::Term::List { items, tail } + } + + Which::ListType(item_type) => model::Term::ListType { + item_type: model::TermId(item_type), + }, + + Which::ExtSet(reader) => { + let reader = reader?; + + let extensions = { + let extensions_reader = reader.get_extensions()?; + let mut extensions = BumpVec::with_capacity_in(extensions_reader.len() as _, bump); + for extension_reader in extensions_reader.iter() { + extensions.push(bump.alloc_str(extension_reader?.to_str()?) 
as &str); + } + extensions.into_bump_slice() + }; + + let rest = reader.get_rest().checked_sub(1).map(model::TermId); + model::Term::ExtSet { extensions, rest } + } + + Which::Adt(variants) => model::Term::Adt { + variants: model::TermId(variants), + }, + + Which::FuncType(reader) => { + let reader = reader?; + let inputs = model::TermId(reader.get_inputs()); + let outputs = model::TermId(reader.get_outputs()); + let extensions = model::TermId(reader.get_extensions()); + model::Term::FuncType { + inputs, + outputs, + extensions, + } + } + + Which::Control(values) => model::Term::Control { + values: model::TermId(values), + }, + }) +} + +fn read_meta_item<'a>( + bump: &'a Bump, + reader: hugr_capnp::meta_item::Reader, +) -> ReadResult> { + let name = bump.alloc_str(reader.get_name()?.to_str()?); + let value = model::TermId(reader.get_value()); + Ok(model::MetaItem { name, value }) +} + +fn read_param<'a>( + bump: &'a Bump, + reader: hugr_capnp::param::Reader, +) -> ReadResult> { + use hugr_capnp::param::Which; + Ok(match reader.which()? { + Which::Implicit(reader) => { + let reader = reader?; + let name = bump.alloc_str(reader.get_name()?.to_str()?); + let r#type = model::TermId(reader.get_type()); + model::Param::Implicit { name, r#type } + } + Which::Explicit(reader) => { + let reader = reader?; + let name = bump.alloc_str(reader.get_name()?.to_str()?); + let r#type = model::TermId(reader.get_type()); + model::Param::Explicit { name, r#type } + } + Which::Constraint(constraint) => { + let constraint = model::TermId(constraint); + model::Param::Constraint { constraint } + } + }) +} diff --git a/hugr-model/src/v0/binary/write.rs b/hugr-model/src/v0/binary/write.rs new file mode 100644 index 000000000..4e7ccc0bc --- /dev/null +++ b/hugr-model/src/v0/binary/write.rs @@ -0,0 +1,201 @@ +use crate::hugr_v0_capnp as hugr_capnp; +use crate::v0 as model; + +/// Write a list of items into a list builder. +macro_rules! 
write_list { + ($builder:expr, $init:ident, $write:expr, $list:expr) => { + let mut __list_builder = $builder.reborrow().$init($list.len() as _); + for (index, item) in $list.iter().enumerate() { + $write(__list_builder.reborrow().get(index as _), item); + } + }; +} + +/// Writes a module to a byte vector. +pub fn write_to_vec(module: &model::Module) -> Vec { + let mut message = capnp::message::Builder::new_default(); + let builder = message.init_root(); + write_module(builder, module); + + let mut output = Vec::new(); + let _ = capnp::serialize_packed::write_message(&mut output, &message); + output +} + +fn write_module(mut builder: hugr_capnp::module::Builder, module: &model::Module) { + builder.set_root(module.root.0); + write_list!(builder, init_nodes, write_node, module.nodes); + write_list!(builder, init_regions, write_region, module.regions); + write_list!(builder, init_terms, write_term, module.terms); +} + +fn write_node(mut builder: hugr_capnp::node::Builder, node: &model::Node) { + write_operation(builder.reborrow().init_operation(), &node.operation); + write_list!(builder, init_inputs, write_link_ref, node.inputs); + write_list!(builder, init_outputs, write_link_ref, node.outputs); + write_list!(builder, init_meta, write_meta_item, node.meta); + let _ = builder.set_params(model::TermId::unwrap_slice(node.params)); + let _ = builder.set_regions(model::RegionId::unwrap_slice(node.regions)); + builder.set_signature(node.signature.map_or(0, |t| t.0 + 1)); +} + +fn write_operation(mut builder: hugr_capnp::operation::Builder, operation: &model::Operation) { + match operation { + model::Operation::Dfg => builder.set_dfg(()), + model::Operation::Cfg => builder.set_cfg(()), + model::Operation::Block => builder.set_block(()), + model::Operation::TailLoop => builder.set_tail_loop(()), + model::Operation::Conditional => builder.set_conditional(()), + model::Operation::Tag { tag } => builder.set_tag(*tag), + model::Operation::Custom { operation } => { + 
write_global_ref(builder.init_custom(), operation) + } + model::Operation::CustomFull { operation } => { + write_global_ref(builder.init_custom_full(), operation) + } + model::Operation::CallFunc { func } => builder.set_call_func(func.0), + model::Operation::LoadFunc { func } => builder.set_load_func(func.0), + + model::Operation::DefineFunc { decl } => { + let mut builder = builder.init_func_defn(); + builder.set_name(decl.name); + write_list!(builder, init_params, write_param, decl.params); + builder.set_signature(decl.signature.0); + } + model::Operation::DeclareFunc { decl } => { + let mut builder = builder.init_func_decl(); + builder.set_name(decl.name); + write_list!(builder, init_params, write_param, decl.params); + builder.set_signature(decl.signature.0); + } + + model::Operation::DefineAlias { decl, value } => { + let mut builder = builder.init_alias_defn(); + builder.set_name(decl.name); + write_list!(builder, init_params, write_param, decl.params); + builder.set_type(decl.r#type.0); + builder.set_value(value.0); + } + model::Operation::DeclareAlias { decl } => { + let mut builder = builder.init_alias_decl(); + builder.set_name(decl.name); + write_list!(builder, init_params, write_param, decl.params); + builder.set_type(decl.r#type.0); + } + model::Operation::Invalid => builder.set_invalid(()), + } +} + +fn write_param(mut builder: hugr_capnp::param::Builder, param: &model::Param) { + match param { + model::Param::Implicit { name, r#type } => { + let mut builder = builder.init_implicit(); + builder.set_name(name); + builder.set_type(r#type.0); + } + model::Param::Explicit { name, r#type } => { + let mut builder = builder.init_explicit(); + builder.set_name(name); + builder.set_type(r#type.0); + } + model::Param::Constraint { constraint } => builder.set_constraint(constraint.0), + } +} + +fn write_global_ref(mut builder: hugr_capnp::global_ref::Builder, global_ref: &model::GlobalRef) { + match global_ref { + model::GlobalRef::Direct(node) => 
builder.set_node(node.0), + model::GlobalRef::Named(name) => builder.set_named(name), + } +} + +fn write_link_ref(mut builder: hugr_capnp::link_ref::Builder, link_ref: &model::LinkRef) { + match link_ref { + model::LinkRef::Id(id) => builder.set_id(id.0), + model::LinkRef::Named(name) => builder.set_named(name), + } +} + +fn write_local_ref(mut builder: hugr_capnp::local_ref::Builder, local_ref: &model::LocalRef) { + match local_ref { + model::LocalRef::Index(node, index) => { + let mut builder = builder.init_direct(); + builder.set_node(node.0); + builder.set_index(*index); + } + model::LocalRef::Named(name) => builder.set_named(name), + } +} + +fn write_meta_item(mut builder: hugr_capnp::meta_item::Builder, meta_item: &model::MetaItem) { + builder.set_name(meta_item.name); + builder.set_value(meta_item.value.0) +} + +fn write_region(mut builder: hugr_capnp::region::Builder, region: &model::Region) { + builder.set_kind(match region.kind { + model::RegionKind::DataFlow => hugr_capnp::RegionKind::DataFlow, + model::RegionKind::ControlFlow => hugr_capnp::RegionKind::ControlFlow, + }); + + write_list!(builder, init_sources, write_link_ref, region.sources); + write_list!(builder, init_targets, write_link_ref, region.targets); + let _ = builder.set_children(model::NodeId::unwrap_slice(region.children)); + write_list!(builder, init_meta, write_meta_item, region.meta); + builder.set_signature(region.signature.map_or(0, |t| t.0 + 1)); +} + +fn write_term(mut builder: hugr_capnp::term::Builder, term: &model::Term) { + match term { + model::Term::Wildcard => builder.set_wildcard(()), + model::Term::Type => builder.set_runtime_type(()), + model::Term::StaticType => builder.set_static_type(()), + model::Term::Constraint => builder.set_constraint(()), + model::Term::Var(local_ref) => write_local_ref(builder.init_variable(), local_ref), + model::Term::ListType { item_type } => builder.set_list_type(item_type.0), + model::Term::Str(value) => builder.set_string(value), + 
model::Term::StrType => builder.set_string_type(()), + model::Term::Nat(value) => builder.set_nat(*value), + model::Term::NatType => builder.set_nat_type(()), + model::Term::ExtSetType => builder.set_ext_set_type(()), + model::Term::Adt { variants } => builder.set_adt(variants.0), + model::Term::Quote { r#type } => builder.set_quote(r#type.0), + model::Term::Control { values } => builder.set_control(values.0), + model::Term::ControlType => builder.set_control_type(()), + + model::Term::Apply { global, args } => { + let mut builder = builder.init_apply(); + write_global_ref(builder.reborrow().init_global(), global); + let _ = builder.set_args(model::TermId::unwrap_slice(args)); + } + + model::Term::ApplyFull { global, args } => { + let mut builder = builder.init_apply_full(); + write_global_ref(builder.reborrow().init_global(), global); + let _ = builder.set_args(model::TermId::unwrap_slice(args)); + } + + model::Term::List { items, tail } => { + let mut builder = builder.init_list(); + let _ = builder.set_items(model::TermId::unwrap_slice(items)); + builder.set_tail(tail.map_or(0, |t| t.0 + 1)); + } + + model::Term::ExtSet { extensions, rest } => { + let mut builder = builder.init_ext_set(); + let _ = builder.set_extensions(*extensions); + builder.set_rest(rest.map_or(0, |t| t.0 + 1)); + } + + model::Term::FuncType { + inputs, + outputs, + extensions, + } => { + let mut builder = builder.init_func_type(); + builder.set_inputs(inputs.0); + builder.set_outputs(outputs.0); + builder.set_extensions(extensions.0); + } + } +} diff --git a/hugr-model/src/v0/mod.rs b/hugr-model/src/v0/mod.rs index d732803ca..6a2ac6bf5 100644 --- a/hugr-model/src/v0/mod.rs +++ b/hugr-model/src/v0/mod.rs @@ -90,6 +90,7 @@ use smol_str::SmolStr; use thiserror::Error; +pub mod binary; pub mod text; macro_rules! define_index { @@ -162,7 +163,7 @@ pub struct Module<'a> { /// Table of [`Node`]s. pub nodes: Vec>, /// Table of [`Region`]s. 
- pub region: Vec>, + pub regions: Vec>, /// Table of [`Term`]s. pub terms: Vec>, } @@ -171,12 +172,18 @@ impl<'a> Module<'a> { /// Return the node data for a given node id. #[inline] pub fn get_node(&self, node_id: NodeId) -> Option<&Node<'a>> { - self.nodes.get(node_id.0 as usize) + self.nodes.get(node_id.index()) + } + + /// Return a mutable reference to the node data for a given node id. + #[inline] + pub fn get_node_mut(&mut self, node_id: NodeId) -> Option<&mut Node<'a>> { + self.nodes.get_mut(node_id.index()) } /// Insert a new node into the module and return its id. pub fn insert_node(&mut self, node: Node<'a>) -> NodeId { - let id = NodeId(self.nodes.len() as u32); + let id = NodeId::new(self.nodes.len()); self.nodes.push(node); id } @@ -184,12 +191,18 @@ impl<'a> Module<'a> { /// Return the term data for a given term id. #[inline] pub fn get_term(&self, term_id: TermId) -> Option<&Term<'a>> { - self.terms.get(term_id.0 as usize) + self.terms.get(term_id.index()) + } + + /// Return a mutable reference to the term data for a given term id. + #[inline] + pub fn get_term_mut(&mut self, term_id: TermId) -> Option<&mut Term<'a>> { + self.terms.get_mut(term_id.index()) } /// Insert a new term into the module and return its id. pub fn insert_term(&mut self, term: Term<'a>) -> TermId { - let id = TermId(self.terms.len() as u32); + let id = TermId::new(self.terms.len()); self.terms.push(term); id } @@ -197,26 +210,32 @@ impl<'a> Module<'a> { /// Return the region data for a given region id. #[inline] pub fn get_region(&self, region_id: RegionId) -> Option<&Region<'a>> { - self.region.get(region_id.0 as usize) + self.regions.get(region_id.index()) + } + + /// Return a mutable reference to the region data for a given region id. + #[inline] + pub fn get_region_mut(&mut self, region_id: RegionId) -> Option<&mut Region<'a>> { + self.regions.get_mut(region_id.index()) } /// Insert a new region into the module and return its id. 
pub fn insert_region(&mut self, region: Region<'a>) -> RegionId { - let id = RegionId(self.region.len() as u32); - self.region.push(region); + let id = RegionId::new(self.regions.len()); + self.regions.push(region); id } } /// Nodes in the hugr graph. -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, Hash, Default)] pub struct Node<'a> { /// The operation that the node performs. pub operation: Operation<'a>, /// The input ports of the node. - pub inputs: &'a [Port<'a>], + pub inputs: &'a [LinkRef<'a>], /// The output ports of the node. - pub outputs: &'a [Port<'a>], + pub outputs: &'a [LinkRef<'a>], /// The parameters of the node. pub params: &'a [TermId], /// The regions of the node. @@ -224,12 +243,20 @@ pub struct Node<'a> { /// The meta information attached to the node. pub meta: &'a [MetaItem<'a>], /// The signature of the node. - pub signature: TermId, + /// + /// Can be `None` to indicate that the node's signature should be inferred, + /// or for nodes with operations that do not have a signature. + pub signature: Option, } /// Operations that nodes can perform. -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, Hash, Default)] pub enum Operation<'a> { + /// Invalid operation to be used as a placeholder. + /// This is useful for modules that have non-contiguous node ids, or modules + /// that have not yet been fully constructed. + #[default] + Invalid, /// Data flow graphs. Dfg, /// Control flow graphs. @@ -323,15 +350,17 @@ pub struct Region<'a> { /// The kind of the region. See [`RegionKind`] for details. pub kind: RegionKind, /// The source ports of the region. - pub sources: &'a [Port<'a>], + pub sources: &'a [LinkRef<'a>], /// The target ports of the region. - pub targets: &'a [Port<'a>], + pub targets: &'a [LinkRef<'a>], /// The nodes in the region. The order of the nodes is not significant. pub children: &'a [NodeId], /// The metadata attached to the region. 
pub meta: &'a [MetaItem<'a>], /// The signature of the region. - pub signature: TermId, + /// + /// Can be `None` to indicate that the region signature should be inferred. + pub signature: Option, } /// The kind of a region. @@ -343,15 +372,6 @@ pub enum RegionKind { ControlFlow, } -/// A port attached to a [`Node`] or [`Region`]. -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub struct Port<'a> { - /// The link that the port is connected to. - pub link: LinkRef<'a>, - /// The type of the port. - pub r#type: Option, -} - /// A function declaration. #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub struct FuncDecl<'a> { @@ -408,8 +428,8 @@ impl std::fmt::Display for GlobalRef<'_> { /// Local variables are defined as parameters to nodes. #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] pub enum LocalRef<'a> { - /// Reference to the local variable by its parameter index. - Index(u16), + /// Reference to the local variable by its parameter index and its defining node. + Index(NodeId, u16), /// Reference to the local variable by its name. Named(&'a str), } @@ -417,7 +437,7 @@ pub enum LocalRef<'a> { impl std::fmt::Display for LocalRef<'_> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - LocalRef::Index(index) => write!(f, "?:{}", index), + LocalRef::Index(node, index) => write!(f, "?:{}:{}", node.index(), index), LocalRef::Named(name) => write!(f, "?{}", name), } } @@ -442,9 +462,10 @@ impl std::fmt::Display for LinkRef<'_> { } /// A term in the compile time meta language. -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, Hash, Default)] pub enum Term<'a> { /// Standin for any term. + #[default] Wildcard, /// The type of runtime types. @@ -603,12 +624,6 @@ pub enum Term<'a> { ControlType, } -impl<'a> Default for Term<'a> { - fn default() -> Self { - Self::Wildcard - } -} - /// A parameter to a function or alias. 
/// /// Parameter names must be unique within a parameter list. @@ -678,4 +693,7 @@ pub enum ModelError { /// defines two cases for the same tag. #[error("condition node is malformed: {0:?}")] MalformedCondition(NodeId), + /// There is a node that is not well-formed or has the invalid operation. + #[error("invalid operation on node: {0:?}")] + InvalidOperation(NodeId), } diff --git a/hugr-model/src/v0/text/hugr.pest b/hugr-model/src/v0/text/hugr.pest index 37e9dad06..5772efb51 100644 --- a/hugr-model/src/v0/text/hugr.pest +++ b/hugr-model/src/v0/text/hugr.pest @@ -13,7 +13,7 @@ module = { "(" ~ "hugr" ~ "0" ~ ")" ~ meta* ~ node* ~ EOI } meta = { "(" ~ "meta" ~ symbol ~ term ~ ")" } edge_name = @{ "%" ~ (ASCII_ALPHANUMERIC | "_" | "-")* } -port = { edge_name | ("(" ~ edge_name ~ term ~ meta* ~ ")") } +port = { edge_name } port_list = { "[" ~ port* ~ "]" } port_lists = _{ port_list ~ port_list } @@ -33,21 +33,21 @@ node = { | node_custom } -node_dfg = { "(" ~ "dfg" ~ port_lists? ~ type_hint? ~ meta* ~ region* ~ ")" } -node_cfg = { "(" ~ "cfg" ~ port_lists? ~ type_hint? ~ meta* ~ region* ~ ")" } -node_block = { "(" ~ "block" ~ port_lists? ~ type_hint? ~ meta* ~ region* ~ ")" } +node_dfg = { "(" ~ "dfg" ~ port_lists? ~ signature? ~ meta* ~ region* ~ ")" } +node_cfg = { "(" ~ "cfg" ~ port_lists? ~ signature? ~ meta* ~ region* ~ ")" } +node_block = { "(" ~ "block" ~ port_lists? ~ signature? ~ meta* ~ region* ~ ")" } node_define_func = { "(" ~ "define-func" ~ func_header ~ meta* ~ region* ~ ")" } node_declare_func = { "(" ~ "declare-func" ~ func_header ~ meta* ~ ")" } -node_call_func = { "(" ~ "call" ~ term ~ port_lists? ~ type_hint? ~ meta* ~ ")" } -node_load_func = { "(" ~ "load-func" ~ term ~ port_lists? ~ type_hint? ~ meta* ~ ")" } +node_call_func = { "(" ~ "call" ~ term ~ port_lists? ~ signature? ~ meta* ~ ")" } +node_load_func = { "(" ~ "load-func" ~ term ~ port_lists? ~ signature? 
~ meta* ~ ")" } node_define_alias = { "(" ~ "define-alias" ~ alias_header ~ term ~ meta* ~ ")" } node_declare_alias = { "(" ~ "declare-alias" ~ alias_header ~ meta* ~ ")" } -node_tail_loop = { "(" ~ "tail-loop" ~ port_lists? ~ type_hint? ~ meta* ~ region* ~ ")" } -node_cond = { "(" ~ "cond" ~ port_lists? ~ type_hint? ~ meta* ~ region* ~ ")" } -node_tag = { "(" ~ "tag" ~ tag ~ port_lists? ~ type_hint? ~ meta* ~ region* ~ ")" } -node_custom = { "(" ~ (term_apply | term_apply_full) ~ port_lists? ~ type_hint? ~ meta* ~ region* ~ ")" } +node_tail_loop = { "(" ~ "tail-loop" ~ port_lists? ~ signature? ~ meta* ~ region* ~ ")" } +node_cond = { "(" ~ "cond" ~ port_lists? ~ signature? ~ meta* ~ region* ~ ")" } +node_tag = { "(" ~ "tag" ~ tag ~ port_lists? ~ signature? ~ meta* ~ region* ~ ")" } +node_custom = { "(" ~ (term_apply | term_apply_full) ~ port_lists? ~ signature? ~ meta* ~ region* ~ ")" } -type_hint = { "(" ~ "type" ~ term ~ ")" } +signature = { "(" ~ "signature" ~ term ~ ")" } func_header = { symbol ~ param* ~ term ~ term ~ term } alias_header = { symbol ~ param* ~ term } @@ -58,8 +58,8 @@ param_explicit = { "(" ~ "param" ~ term_var ~ term ~ ")" } param_constraint = { "(" ~ "where" ~ term ~ ")" } region = { region_dfg | region_cfg } -region_dfg = { "(" ~ "dfg" ~ port_lists? ~ type_hint? ~ meta* ~ node* ~ ")" } -region_cfg = { "(" ~ "cfg" ~ port_lists? ~ type_hint? ~ meta* ~ node* ~ ")" } +region_dfg = { "(" ~ "dfg" ~ port_lists? ~ signature? ~ meta* ~ node* ~ ")" } +region_cfg = { "(" ~ "cfg" ~ port_lists? ~ signature? 
~ meta* ~ node* ~ ")" } term = { term_wildcard diff --git a/hugr-model/src/v0/text/parse.rs b/hugr-model/src/v0/text/parse.rs index fec67bee8..c139f3a3e 100644 --- a/hugr-model/src/v0/text/parse.rs +++ b/hugr-model/src/v0/text/parse.rs @@ -7,7 +7,7 @@ use thiserror::Error; use crate::v0::{ AliasDecl, FuncDecl, GlobalRef, LinkRef, LocalRef, MetaItem, Module, Node, NodeId, Operation, - Param, Port, Region, RegionId, RegionKind, Term, TermId, + Param, Region, RegionId, RegionKind, Term, TermId, }; mod pest_parser { @@ -62,7 +62,6 @@ impl<'a> ParseContext<'a> { fn parse_module(&mut self, pair: Pair<'a, Rule>) -> ParseResult<()> { debug_assert!(matches!(pair.as_rule(), Rule::module)); let mut inner = pair.into_inner(); - let r#type = self.module.insert_term(Term::Wildcard); let meta = self.parse_meta(&mut inner)?; let children = self.parse_nodes(&mut inner)?; @@ -73,7 +72,7 @@ impl<'a> ParseContext<'a> { targets: &[], children, meta, - signature: r#type, + signature: None, }); self.module.root = root_region; @@ -163,6 +162,7 @@ impl<'a> ParseContext<'a> { Rule::term_str => { // TODO: Escaping? 
let value = inner.next().unwrap().as_str(); + let value = &value[1..value.len() - 1]; Term::Str(value) } @@ -228,7 +228,7 @@ impl<'a> ParseContext<'a> { Rule::node_dfg => { let inputs = self.parse_port_list(&mut inner)?; let outputs = self.parse_port_list(&mut inner)?; - let r#type = self.parse_type_hint(&mut inner)?; + let signature = self.parse_signature(&mut inner)?; let meta = self.parse_meta(&mut inner)?; let regions = self.parse_regions(&mut inner)?; Node { @@ -238,14 +238,14 @@ impl<'a> ParseContext<'a> { params: &[], regions, meta, - signature: r#type, + signature, } } Rule::node_cfg => { let inputs = self.parse_port_list(&mut inner)?; let outputs = self.parse_port_list(&mut inner)?; - let r#type = self.parse_type_hint(&mut inner)?; + let signature = self.parse_signature(&mut inner)?; let meta = self.parse_meta(&mut inner)?; let regions = self.parse_regions(&mut inner)?; Node { @@ -255,14 +255,14 @@ impl<'a> ParseContext<'a> { params: &[], regions, meta, - signature: r#type, + signature, } } Rule::node_block => { let inputs = self.parse_port_list(&mut inner)?; let outputs = self.parse_port_list(&mut inner)?; - let r#type = self.parse_type_hint(&mut inner)?; + let signature = self.parse_signature(&mut inner)?; let meta = self.parse_meta(&mut inner)?; let regions = self.parse_regions(&mut inner)?; Node { @@ -272,13 +272,12 @@ impl<'a> ParseContext<'a> { params: &[], regions, meta, - signature: r#type, + signature, } } Rule::node_define_func => { let decl = self.parse_func_header(inner.next().unwrap())?; - let r#type = self.module.insert_term(Term::Wildcard); let meta = self.parse_meta(&mut inner)?; let regions = self.parse_regions(&mut inner)?; Node { @@ -288,13 +287,12 @@ impl<'a> ParseContext<'a> { params: &[], regions, meta, - signature: r#type, + signature: None, } } Rule::node_declare_func => { let decl = self.parse_func_header(inner.next().unwrap())?; - let r#type = self.module.insert_term(Term::Wildcard); let meta = self.parse_meta(&mut inner)?; Node { 
operation: Operation::DeclareFunc { decl }, @@ -303,7 +301,7 @@ impl<'a> ParseContext<'a> { params: &[], regions: &[], meta, - signature: r#type, + signature: None, } } @@ -311,7 +309,7 @@ impl<'a> ParseContext<'a> { let func = self.parse_term(inner.next().unwrap())?; let inputs = self.parse_port_list(&mut inner)?; let outputs = self.parse_port_list(&mut inner)?; - let r#type = self.parse_type_hint(&mut inner)?; + let signature = self.parse_signature(&mut inner)?; let meta = self.parse_meta(&mut inner)?; Node { operation: Operation::CallFunc { func }, @@ -320,7 +318,7 @@ impl<'a> ParseContext<'a> { params: &[], regions: &[], meta, - signature: r#type, + signature, } } @@ -328,7 +326,7 @@ impl<'a> ParseContext<'a> { let func = self.parse_term(inner.next().unwrap())?; let inputs = self.parse_port_list(&mut inner)?; let outputs = self.parse_port_list(&mut inner)?; - let r#type = self.parse_type_hint(&mut inner)?; + let signature = self.parse_signature(&mut inner)?; let meta = self.parse_meta(&mut inner)?; Node { operation: Operation::LoadFunc { func }, @@ -337,14 +335,13 @@ impl<'a> ParseContext<'a> { params: &[], regions: &[], meta, - signature: r#type, + signature, } } Rule::node_define_alias => { let decl = self.parse_alias_header(inner.next().unwrap())?; let value = self.parse_term(inner.next().unwrap())?; - let r#type = self.module.insert_term(Term::Wildcard); let meta = self.parse_meta(&mut inner)?; Node { operation: Operation::DefineAlias { decl, value }, @@ -353,13 +350,12 @@ impl<'a> ParseContext<'a> { params: &[], regions: &[], meta, - signature: r#type, + signature: None, } } Rule::node_declare_alias => { let decl = self.parse_alias_header(inner.next().unwrap())?; - let r#type = self.module.insert_term(Term::Wildcard); let meta = self.parse_meta(&mut inner)?; Node { operation: Operation::DeclareAlias { decl }, @@ -368,7 +364,7 @@ impl<'a> ParseContext<'a> { params: &[], regions: &[], meta, - signature: r#type, + signature: None, } } @@ -397,7 +393,7 @@ 
impl<'a> ParseContext<'a> { let inputs = self.parse_port_list(&mut inner)?; let outputs = self.parse_port_list(&mut inner)?; - let r#type = self.parse_type_hint(&mut inner)?; + let signature = self.parse_signature(&mut inner)?; let meta = self.parse_meta(&mut inner)?; let regions = self.parse_regions(&mut inner)?; Node { @@ -407,14 +403,14 @@ impl<'a> ParseContext<'a> { params: self.bump.alloc_slice_copy(&params), regions, meta, - signature: r#type, + signature, } } Rule::node_tail_loop => { let inputs = self.parse_port_list(&mut inner)?; let outputs = self.parse_port_list(&mut inner)?; - let r#type = self.parse_type_hint(&mut inner)?; + let signature = self.parse_signature(&mut inner)?; let meta = self.parse_meta(&mut inner)?; let regions = self.parse_regions(&mut inner)?; Node { @@ -424,14 +420,14 @@ impl<'a> ParseContext<'a> { params: &[], regions, meta, - signature: r#type, + signature, } } Rule::node_cond => { let inputs = self.parse_port_list(&mut inner)?; let outputs = self.parse_port_list(&mut inner)?; - let r#type = self.parse_type_hint(&mut inner)?; + let signature = self.parse_signature(&mut inner)?; let meta = self.parse_meta(&mut inner)?; let regions = self.parse_regions(&mut inner)?; Node { @@ -441,7 +437,7 @@ impl<'a> ParseContext<'a> { params: &[], regions, meta, - signature: r#type, + signature, } } @@ -449,7 +445,7 @@ impl<'a> ParseContext<'a> { let tag = inner.next().unwrap().as_str().parse::().unwrap(); let inputs = self.parse_port_list(&mut inner)?; let outputs = self.parse_port_list(&mut inner)?; - let r#type = self.parse_type_hint(&mut inner)?; + let signature = self.parse_signature(&mut inner)?; let meta = self.parse_meta(&mut inner)?; Node { operation: Operation::Tag { tag }, @@ -458,7 +454,7 @@ impl<'a> ParseContext<'a> { params: &[], regions: &[], meta, - signature: r#type, + signature, } } @@ -492,7 +488,7 @@ impl<'a> ParseContext<'a> { let sources = self.parse_port_list(&mut inner)?; let targets = self.parse_port_list(&mut inner)?; - let
r#type = self.parse_type_hint(&mut inner)?; + let signature = self.parse_signature(&mut inner)?; let meta = self.parse_meta(&mut inner)?; let children = self.parse_nodes(&mut inner)?; @@ -502,7 +498,7 @@ impl<'a> ParseContext<'a> { targets, children, meta, - signature: r#type, + signature, })) } @@ -589,46 +585,37 @@ impl<'a> ParseContext<'a> { Ok(self.bump.alloc_slice_copy(&params)) } - fn parse_type_hint(&mut self, pairs: &mut Pairs<'a, Rule>) -> ParseResult<TermId> { - let Some(Rule::type_hint) = pairs.peek().map(|p| p.as_rule()) else { - return Ok(self.module.insert_term(Term::Wildcard)); + fn parse_signature(&mut self, pairs: &mut Pairs<'a, Rule>) -> ParseResult<Option<TermId>> { + let Some(Rule::signature) = pairs.peek().map(|p| p.as_rule()) else { + return Ok(None); }; let pair = pairs.next().unwrap(); - self.parse_term(pair.into_inner().next().unwrap()) + let signature = self.parse_term(pair.into_inner().next().unwrap())?; + Ok(Some(signature)) } - fn parse_port_list(&mut self, pairs: &mut Pairs<'a, Rule>) -> ParseResult<&'a [Port<'a>]> { + fn parse_port_list(&mut self, pairs: &mut Pairs<'a, Rule>) -> ParseResult<&'a [LinkRef<'a>]> { let Some(Rule::port_list) = pairs.peek().map(|p| p.as_rule()) else { return Ok(&[]); }; let pair = pairs.next().unwrap(); let inner = pair.into_inner(); - let mut ports = Vec::new(); + let mut links = Vec::new(); for token in inner { - let port = self.parse_port(token)?; - ports.push(port); + links.push(self.parse_port(token)?); } - Ok(self.bump.alloc_slice_copy(&ports)) + Ok(self.bump.alloc_slice_copy(&links)) } - fn parse_port(&mut self, pair: Pair<'a, Rule>) -> ParseResult<Port<'a>> { + fn parse_port(&mut self, pair: Pair<'a, Rule>) -> ParseResult<LinkRef<'a>> { debug_assert!(matches!(pair.as_rule(), Rule::port)); let mut inner = pair.into_inner(); - - let link = LinkRef::Named(inner.next().unwrap().as_str()); - - let mut r#type = None; - - if inner.peek().is_some() { - r#type = Some(self.parse_term(inner.next().unwrap())?); - } - - Ok(Port { link, r#type }) + let link
= LinkRef::Named(&inner.next().unwrap().as_str()[1..]); + Ok(link) } fn parse_meta(&mut self, pairs: &mut Pairs<'a, Rule>) -> ParseResult<&'a [MetaItem<'a>]> { diff --git a/hugr-model/src/v0/text/print.rs b/hugr-model/src/v0/text/print.rs index fe585a976..c70dfd401 100644 --- a/hugr-model/src/v0/text/print.rs +++ b/hugr-model/src/v0/text/print.rs @@ -2,8 +2,8 @@ use pretty::{Arena, DocAllocator, RefDoc}; use std::borrow::Cow; use crate::v0::{ - GlobalRef, LinkRef, LocalRef, MetaItem, ModelError, Module, NodeId, Operation, Param, Port, - RegionId, RegionKind, Term, TermId, + GlobalRef, LinkRef, LocalRef, MetaItem, ModelError, Module, NodeId, Operation, Param, RegionId, + RegionKind, Term, TermId, }; type PrintError = ModelError; @@ -143,33 +143,31 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { .ok_or_else(|| PrintError::NodeNotFound(node_id))?; self.print_parens(|this| match &node_data.operation { + Operation::Invalid => Err(ModelError::InvalidOperation(node_id)), Operation::Dfg => { this.print_group(|this| { this.print_text("dfg"); - this.print_port_list(node_data.inputs)?; - this.print_port_list(node_data.outputs) + this.print_port_lists(node_data.inputs, node_data.outputs) })?; - this.print_type_hint(node_data.signature)?; + this.print_signature(node_data.signature)?; this.print_meta(node_data.meta)?; this.print_regions(node_data.regions) } Operation::Cfg => { this.print_group(|this| { this.print_text("cfg"); - this.print_port_list(node_data.inputs)?; - this.print_port_list(node_data.outputs) + this.print_port_lists(node_data.inputs, node_data.outputs) })?; - this.print_type_hint(node_data.signature)?; + this.print_signature(node_data.signature)?; this.print_meta(node_data.meta)?; this.print_regions(node_data.regions) } Operation::Block => { this.print_group(|this| { this.print_text("block"); - this.print_port_list(node_data.inputs)?; - this.print_port_list(node_data.outputs) + this.print_port_lists(node_data.inputs, node_data.outputs) })?; - 
this.print_type_hint(node_data.signature)?; + this.print_signature(node_data.signature)?; this.print_meta(node_data.meta)?; this.print_regions(node_data.regions) } @@ -238,10 +236,9 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { this.print_group(|this| { this.print_text("call"); this.print_term(*func)?; - this.print_port_list(node_data.inputs)?; - this.print_port_list(node_data.outputs) + this.print_port_lists(node_data.inputs, node_data.outputs) })?; - this.print_type_hint(node_data.signature)?; + this.print_signature(node_data.signature)?; this.print_meta(node_data.meta)?; Ok(()) } @@ -250,10 +247,9 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { this.print_group(|this| { this.print_text("load-func"); this.print_term(*func)?; - this.print_port_list(node_data.inputs)?; - this.print_port_list(node_data.outputs) + this.print_port_lists(node_data.inputs, node_data.outputs) })?; - this.print_type_hint(node_data.signature)?; + this.print_signature(node_data.signature)?; this.print_meta(node_data.meta)?; Ok(()) } @@ -274,10 +270,9 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { })?; } - this.print_port_list(node_data.inputs)?; - this.print_port_list(node_data.outputs) + this.print_port_lists(node_data.inputs, node_data.outputs) })?; - this.print_type_hint(node_data.signature)?; + this.print_signature(node_data.signature)?; this.print_meta(node_data.meta)?; this.print_regions(node_data.regions) } @@ -295,10 +290,9 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { Ok(()) })?; - this.print_port_list(node_data.inputs)?; - this.print_port_list(node_data.outputs) + this.print_port_lists(node_data.inputs, node_data.outputs) })?; - this.print_type_hint(node_data.signature)?; + this.print_signature(node_data.signature)?; this.print_meta(node_data.meta)?; this.print_regions(node_data.regions) } @@ -335,18 +329,16 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { Operation::TailLoop => { this.print_text("tail-loop"); - this.print_port_list(node_data.inputs)?; - this.print_port_list(node_data.outputs)?; - 
this.print_type_hint(node_data.signature)?; + this.print_port_lists(node_data.inputs, node_data.outputs)?; + this.print_signature(node_data.signature)?; this.print_meta(node_data.meta)?; this.print_regions(node_data.regions) } Operation::Conditional => { this.print_text("cond"); - this.print_port_list(node_data.inputs)?; - this.print_port_list(node_data.outputs)?; - this.print_type_hint(node_data.signature)?; + this.print_port_lists(node_data.inputs, node_data.outputs)?; + this.print_signature(node_data.signature)?; this.print_meta(node_data.meta)?; this.print_regions(node_data.regions) } @@ -354,9 +346,8 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { Operation::Tag { tag } => { this.print_text("tag"); this.print_text(format!("{}", tag)); - this.print_port_list(node_data.inputs)?; - this.print_port_list(node_data.outputs)?; - this.print_type_hint(node_data.signature)?; + this.print_port_lists(node_data.inputs, node_data.outputs)?; + this.print_signature(node_data.signature)?; this.print_meta(node_data.meta) } }) @@ -385,12 +376,8 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { } }; - if !region_data.sources.is_empty() || !region_data.targets.is_empty() { - this.print_port_list(region_data.sources)?; - this.print_port_list(region_data.targets)?; - } - - this.print_type_hint(region_data.signature)?; + this.print_port_lists(region_data.sources, region_data.targets)?; + this.print_signature(region_data.signature)?; this.print_meta(region_data.meta)?; this.print_nodes(region) }) @@ -409,25 +396,26 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { Ok(()) } - fn print_port_list(&mut self, ports: &'a [Port<'a>]) -> PrintResult<()> { - self.print_brackets(|this| { - for port in ports { - if port.r#type.is_some() { - this.print_parens(|this| { - this.print_link_ref(port.link); - - match port.r#type { - Some(r#type) => this.print_term(r#type)?, - None => this.print_text("_"), - }; + fn print_port_lists( + &mut self, + first: &'a [LinkRef<'a>], + second: &'a [LinkRef<'a>], + ) -> 
PrintResult<()> { + if !first.is_empty() || !second.is_empty() { + self.print_group(|this| { + this.print_port_list(first)?; + this.print_port_list(second) + }) + } else { + Ok(()) + } + } - Ok(()) - })?; - } else { - this.print_link_ref(port.link); - } + fn print_port_list(&mut self, links: &'a [LinkRef<'a>]) -> PrintResult<()> { + self.print_brackets(|this| { + for link in links { + this.print_link_ref(*link); } - Ok(()) }) } @@ -585,7 +573,7 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { fn print_local_ref(&mut self, local_ref: LocalRef<'a>) -> PrintResult<()> { let name = match local_ref { - LocalRef::Index(i) => { + LocalRef::Index(_, i) => { let Some(name) = self.locals.get(i as usize) else { return Err(PrintError::InvalidLocal(local_ref.to_string())); }; @@ -636,14 +624,14 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { Ok(()) } - fn print_type_hint(&mut self, term: TermId) -> PrintResult<()> { - if let Some(Term::Wildcard) = self.module.get_term(term) { - return Ok(()); + fn print_signature(&mut self, term: Option<TermId>) -> PrintResult<()> { + if let Some(term) = term { + self.print_parens(|this| { + this.print_text("signature"); + this.print_term(term) + })?; } - self.print_parens(|this| { - this.print_text("type-hint"); - this.print_term(term) - }) + Ok(()) } } diff --git a/hugr-model/tests/binary.rs b/hugr-model/tests/binary.rs new file mode 100644 index 000000000..57123aa4d --- /dev/null +++ b/hugr-model/tests/binary.rs @@ -0,0 +1,58 @@ +use bumpalo::Bump; +use hugr_model::v0 as model; +use pretty_assertions::assert_eq; + +/// Reads a module from a string, serializes it to binary, and then deserializes it back to a module. +/// The original and deserialized modules are compared for equality.
+pub fn binary_roundtrip(input: &str) { + let bump = Bump::new(); + let parsed_module = model::text::parse(input, &bump).unwrap(); + let bytes = model::binary::write_to_vec(&parsed_module.module); + let deserialized_module = model::binary::read_from_slice(&bytes, &bump).unwrap(); + assert_eq!(parsed_module.module, deserialized_module); +} + +#[test] +pub fn test_add() { + binary_roundtrip(include_str!("../../hugr-core/tests/fixtures/model-add.edn")); +} + +#[test] +pub fn test_alias() { + binary_roundtrip(include_str!( + "../../hugr-core/tests/fixtures/model-alias.edn" + )); +} + +#[test] +pub fn test_call() { + binary_roundtrip(include_str!( + "../../hugr-core/tests/fixtures/model-call.edn" + )); +} + +#[test] +pub fn test_cfg() { + binary_roundtrip(include_str!("../../hugr-core/tests/fixtures/model-cfg.edn")); +} + +#[test] +pub fn test_cond() { + binary_roundtrip(include_str!( + "../../hugr-core/tests/fixtures/model-cond.edn" + )); +} + +#[test] +pub fn test_loop() { + binary_roundtrip(include_str!( + "../../hugr-core/tests/fixtures/model-loop.edn" + )); +} + +#[test] +pub fn test_params() { + binary_roundtrip(include_str!( + "../../hugr-core/tests/fixtures/model-params.edn" + )); +} diff --git a/hugr/Cargo.toml b/hugr/Cargo.toml index 290cc58d5..dfe85287a 100644 --- a/hugr/Cargo.toml +++ b/hugr/Cargo.toml @@ -24,8 +24,10 @@ path = "src/lib.rs" [features] extension_inference = ["hugr-core/extension_inference"] declarative = ["hugr-core/declarative"] +model_unstable = ["hugr-core/model_unstable", "hugr-model"] [dependencies] +hugr-model = { path = "../hugr-model", optional = true, version = "0.1.0" } hugr-core = { path = "../hugr-core", version = "0.10.0" } hugr-passes = { path = "../hugr-passes", version = "0.8.2" } @@ -34,7 +36,11 @@ rstest = { workspace = true } lazy_static = { workspace = true } criterion = { workspace = true, features = ["html_reports"] } serde_json = { workspace = true } +bumpalo = { workspace = true, features = ["collections"] } [[bench]] 
name = "bench_main" harness = false + +[profile.bench] +debug = true diff --git a/hugr/benches/benchmarks/hugr.rs b/hugr/benches/benchmarks/hugr.rs index e26fd2d3d..4139e882c 100644 --- a/hugr/benches/benchmarks/hugr.rs +++ b/hugr/benches/benchmarks/hugr.rs @@ -9,9 +9,9 @@ use hugr::extension::prelude::{BOOL_T, QB_T, USIZE_T}; use hugr::extension::PRELUDE_REGISTRY; use hugr::ops::OpName; use hugr::std_extensions::arithmetic::float_ops::FLOAT_OPS_REGISTRY; -use hugr::std_extensions::arithmetic::float_types::{ConstF64, FLOAT64_TYPE}; +use hugr::std_extensions::arithmetic::float_types::FLOAT64_TYPE; use hugr::types::Signature; -use hugr::{type_row, CircuitUnit, Extension, Hugr}; +use hugr::{type_row, Extension, Hugr}; use lazy_static::lazy_static; pub fn simple_dfg_hugr() -> Hugr { let dfg_builder = @@ -67,6 +67,24 @@ impl Serializer for JsonSer { } } +#[cfg(feature = "model_unstable")] +struct CapnpSer; + +#[cfg(feature = "model_unstable")] +impl Serializer for CapnpSer { + fn serialize(&self, hugr: &Hugr) -> Vec<u8> { + let bump = bumpalo::Bump::new(); + let module = hugr_core::export::export_hugr(hugr, &bump); + hugr_model::v0::binary::write_to_vec(&module) + } + + fn deserialize(&self, bytes: &[u8]) -> Hugr { + let bump = bumpalo::Bump::new(); + let module = hugr_model::v0::binary::read_from_slice(bytes, &bump).unwrap(); + hugr_core::import::import_hugr(&module, &FLOAT_OPS_REGISTRY).unwrap() + } +} + fn roundtrip(hugr: &Hugr, serializer: impl Serializer) -> Hugr { let bytes = serializer.serialize(hugr); serializer.deserialize(&bytes) @@ -112,9 +130,9 @@ pub fn circuit(layers: usize) -> Hugr { let cx_gate = QUANTUM_EXT .instantiate_extension_op("CX", [], &PRELUDE_REGISTRY) .unwrap(); - let rz = QUANTUM_EXT - .instantiate_extension_op("Rz", [], &FLOAT_OPS_REGISTRY) - .unwrap(); + // let rz = QUANTUM_EXT + // .instantiate_extension_op("Rz", [], &FLOAT_OPS_REGISTRY) + // .unwrap(); let signature = Signature::new_endo(type_row![QB_T,
QB_T]).with_extension_delta(QUANTUM_EXT.name().clone()); let mut module_builder = ModuleBuilder::new(); @@ -133,13 +151,14 @@ pub fn circuit(layers: usize) -> Hugr { .append(cx_gate.clone(), [1, 0]) .unwrap(); - let angle = linear.add_constant(ConstF64::new(0.5)); - linear - .append_and_consume( - rz.clone(), - [CircuitUnit::Linear(0), CircuitUnit::Wire(angle)], - ) - .unwrap(); + // TODO: Currently left out because we can not represent constants in the model + // let angle = linear.add_constant(ConstF64::new(0.5)); + // linear + // .append_and_consume( + // rz.clone(), + // [CircuitUnit::Linear(0), CircuitUnit::Wire(angle)], + // ) + // .unwrap(); } let outs = linear.finish(); @@ -157,13 +176,14 @@ fn bench_builder(c: &mut Criterion) { } fn bench_serialization(c: &mut Criterion) { - c.bench_function("simple_cfg_serialize", |b| { + c.bench_function("simple_cfg_serialize/json", |b| { let h = simple_cfg_hugr(); b.iter(|| { black_box(roundtrip(&h, JsonSer)); }); }); - let mut group = c.benchmark_group("circuit_roundtrip"); + + let mut group = c.benchmark_group("circuit_roundtrip/json"); group.plot_config(PlotConfiguration::default().summary_scale(AxisScale::Logarithmic)); for size in [0, 1, 10, 100, 1000].iter() { group.bench_with_input(BenchmarkId::from_parameter(size), size, |b, &size| { @@ -175,7 +195,7 @@ fn bench_serialization(c: &mut Criterion) { } group.finish(); - let mut group = c.benchmark_group("circuit_serialize"); + let mut group = c.benchmark_group("circuit_serialize/json"); group.plot_config(PlotConfiguration::default().summary_scale(AxisScale::Logarithmic)); for size in [0, 1, 10, 100, 1000].iter() { group.bench_with_input(BenchmarkId::from_parameter(size), size, |b, &size| { @@ -186,6 +206,21 @@ fn bench_serialization(c: &mut Criterion) { }); } group.finish(); + + #[cfg(feature = "model_unstable")] + { + let mut group = c.benchmark_group("circuit_roundtrip/capnp"); + 
group.plot_config(PlotConfiguration::default().summary_scale(AxisScale::Logarithmic)); + for size in [0, 1, 10, 100, 1000].iter() { + group.bench_with_input(BenchmarkId::from_parameter(size), size, |b, &size| { + let h = circuit(size); + b.iter(|| { + black_box(roundtrip(&h, CapnpSer)); + }); + }); + } + group.finish(); + } } criterion_group! {