diff --git a/.github/workflows/ci-py.yml b/.github/workflows/ci-py.yml index fba1eae93..9cf8bc25b 100644 --- a/.github/workflows/ci-py.yml +++ b/.github/workflows/ci-py.yml @@ -141,7 +141,7 @@ jobs: - name: Upload python coverage to codecov.io if: github.event_name != 'merge_group' && matrix.python-version.coverage - uses: codecov/codecov-action@v4 + uses: codecov/codecov-action@v5 with: files: coverage.xml name: python diff --git a/.github/workflows/ci-rs.yml b/.github/workflows/ci-rs.yml index ddd42dd45..0993b6408 100644 --- a/.github/workflows/ci-rs.yml +++ b/.github/workflows/ci-rs.yml @@ -230,7 +230,7 @@ jobs: - name: Generate coverage report run: cargo llvm-cov --all-features report --codecov --output-path coverage.json - name: Upload coverage to codecov.io - uses: codecov/codecov-action@v4 + uses: codecov/codecov-action@v5 with: files: coverage.json name: rust diff --git a/.github/workflows/notify-coverage.yml b/.github/workflows/notify-coverage.yml index ad6070254..70b6ff187 100644 --- a/.github/workflows/notify-coverage.yml +++ b/.github/workflows/notify-coverage.yml @@ -22,9 +22,10 @@ jobs: if: needs.coverage-trend.outputs.should_notify == 'true' steps: - name: Send notification - uses: slackapi/slack-github-action@v1.27.0 + uses: slackapi/slack-github-action@v2.0.0 with: - channel-id: 'C04SHCL4FKP' - slack-message: ${{ needs.coverage-trend.outputs.msg }} - env: - SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }} + method: chat.postMessage + token: ${{ secrets.SLACK_BOT_TOKEN }} + payload: | + channel: 'C04SHCL4FKP' + text: ${{ needs.coverage-trend.outputs.msg }} diff --git a/.github/workflows/unsoundness.yml b/.github/workflows/unsoundness.yml index 569fc7183..0c2790f6f 100644 --- a/.github/workflows/unsoundness.yml +++ b/.github/workflows/unsoundness.yml @@ -1,9 +1,9 @@ name: Unsoundness checks on: - push: - branches: - - main + schedule: + # Weekly on Monday at 04:00 UTC + - cron: '0 4 * * 1' workflow_dispatch: {} concurrency: @@ -40,19 +40,17 @@ jobs: run: cargo miri test - notify-slack: - uses: CQCL/hugrverse-actions/.github/workflows/slack-notifier.yml@main + create-issue: + uses: CQCL/hugrverse-actions/.github/workflows/create-issue.yml@main needs: miri if: always() && needs.miri.result == 'failure' && github.event_name == 'push' - with: - channel-id: 'C04SHCL4FKP' - slack-message: | - 💥 The unsoundness check for `CQCL/hugr` failed. - . - # Rate-limit the message to once per day - timeout-minutes: 1440 - # A repository variable used to store the last message timestamp. - timeout-variable: "UNSOUNDNESS_MSG_SENT" secrets: - SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }} GITHUB_PAT: ${{ secrets.HUGRBOT_PAT }} + with: + title: "💥 Unsoundness check failed on main" + body: | + The unsoundness check for `CQCL/hugr` failed. + + [https://github.com/CQCL/hugr/actions/runs/${{ github.run_id }}](Please investigate). + unique-label: "unsoundness-checks" + other-labels: "bug" diff --git a/Cargo.toml b/Cargo.toml index e046c799f..36ef0526b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,7 +27,6 @@ portgraph = { version = "0.12.2" } insta = { version = "1.34.0" } bitvec = "1.0.1" cgmath = "0.18.0" -context-iterators = "0.2.0" cool_asserts = "2.0.3" criterion = "0.5.1" delegate = "0.13.0" diff --git a/hugr-core/Cargo.toml b/hugr-core/Cargo.toml index 3dab5e709..84b83677f 100644 --- a/hugr-core/Cargo.toml +++ b/hugr-core/Cargo.toml @@ -47,7 +47,6 @@ bitvec = { workspace = true, features = ["serde"] } enum_dispatch = { workspace = true } lazy_static = { workspace = true } petgraph = { workspace = true } -context-iterators = { workspace = true } serde_json = { workspace = true } delegate = { workspace = true } paste = { workspace = true } diff --git a/hugr-core/src/export.rs b/hugr-core/src/export.rs index 68f3a15c0..093368b60 100644 --- a/hugr-core/src/export.rs +++ b/hugr-core/src/export.rs @@ -7,7 +7,7 @@ use crate::{ type_param::{TypeArgVariable, TypeParam}, type_row::TypeRowBase, CustomType, FuncTypeBase, MaybeRV, PolyFuncTypeBase, RowVariable, SumType, TypeArg, - TypeBase, TypeEnum, + TypeBase, TypeBound, TypeEnum, }, Direction, Hugr, HugrView, IncomingPort, Node, Port, }; @@ -44,8 +44,21 @@ struct Context<'a> { bump: &'a Bump, /// Stores the terms that we have already seen to avoid duplicates. term_map: FxHashMap, model::TermId>, + /// The current scope for local variables. + /// + /// This is set to the id of the smallest enclosing node that defines a polymorphic type. + /// We use this when exporting local variables in terms. local_scope: Option, + + /// Constraints to be added to the local scope. + /// + /// When exporting a node that defines a polymorphic type, we use this field + /// to collect the constraints that need to be added to that polymorphic + /// type. Currently this is used to record `nonlinear` constraints on uses + /// of `TypeParam::Type` with a `TypeBound::Copyable` bound. + local_constraints: Vec, + /// Mapping from extension operations to their declarations. decl_operations: FxHashMap<(ExtensionId, OpName), model::NodeId>, } @@ -63,13 +76,14 @@ impl<'a> Context<'a> { term_map: FxHashMap::default(), local_scope: None, decl_operations: FxHashMap::default(), + local_constraints: Vec::new(), } } /// Exports the root module of the HUGR graph. pub fn export_root(&mut self) { let hugr_children = self.hugr.children(self.hugr.root()); - let mut children = Vec::with_capacity(hugr_children.len()); + let mut children = Vec::with_capacity(hugr_children.size_hint().0); for child in self.hugr.children(self.hugr.root()) { children.push(self.export_node(child)); @@ -110,7 +124,7 @@ impl<'a> Context<'a> { num_ports: usize, ) -> &'a [model::LinkRef<'a>] { let ports = self.hugr.node_ports(node, direction); - let mut links = BumpVec::with_capacity_in(ports.len(), self.bump); + let mut links = BumpVec::with_capacity_in(ports.size_hint().0, self.bump); for port in ports.take(num_ports) { links.push(model::LinkRef::Id(self.get_link_id(node, port))); @@ -173,9 +187,11 @@ impl<'a> Context<'a> { } fn with_local_scope(&mut self, node: model::NodeId, f: impl FnOnce(&mut Self) -> T) -> T { - let old_scope = self.local_scope.replace(node); + let prev_local_scope = self.local_scope.replace(node); + let prev_local_constraints = std::mem::take(&mut self.local_constraints); let result = f(self); - self.local_scope = old_scope; + self.local_scope = prev_local_scope; + self.local_constraints = prev_local_constraints; result } @@ -232,10 +248,11 @@ impl<'a> Context<'a> { OpType::FuncDefn(func) => self.with_local_scope(node_id, |this| { let name = this.get_func_name(node).unwrap(); - let (params, signature) = this.export_poly_func_type(&func.signature); + let (params, constraints, signature) = this.export_poly_func_type(&func.signature); let decl = this.bump.alloc(model::FuncDecl { name, params, + constraints, signature, }); let extensions = this.export_ext_set(&func.signature.body().extension_reqs); @@ -247,10 +264,11 @@ impl<'a> Context<'a> { OpType::FuncDecl(func) => self.with_local_scope(node_id, |this| { let name = this.get_func_name(node).unwrap(); - let (params, func) = this.export_poly_func_type(&func.signature); + let (params, constraints, func) = this.export_poly_func_type(&func.signature); let decl = this.bump.alloc(model::FuncDecl { name, params, + constraints, signature: func, }); model::Operation::DeclareFunc { decl } @@ -450,10 +468,11 @@ impl<'a> Context<'a> { let decl = self.with_local_scope(node, |this| { let name = this.make_qualified_name(opdef.extension(), opdef.name()); - let (params, r#type) = this.export_poly_func_type(poly_func_type); + let (params, constraints, r#type) = this.export_poly_func_type(poly_func_type); let decl = this.bump.alloc(model::OperationDecl { name, params, + constraints, r#type, }); decl @@ -579,7 +598,7 @@ impl<'a> Context<'a> { let targets = self.make_ports(output_node, Direction::Incoming, output_op.types.len()); // Export the remaining children of the node. - let mut region_children = BumpVec::with_capacity_in(children.len(), self.bump); + let mut region_children = BumpVec::with_capacity_in(children.size_hint().0, self.bump); for child in children { region_children.push(self.export_node(child)); @@ -609,7 +628,7 @@ impl<'a> Context<'a> { /// Creates a control flow region from the given node's children. pub fn export_cfg(&mut self, node: Node) -> model::RegionId { let mut children = self.hugr.children(node); - let mut region_children = BumpVec::with_capacity_in(children.len() + 1, self.bump); + let mut region_children = BumpVec::with_capacity_in(children.size_hint().0 + 1, self.bump); // The first child is the entry block. // We create a source port on the control flow region and connect it to the @@ -623,16 +642,16 @@ impl<'a> Context<'a> { let source = model::LinkRef::Id(self.get_link_id(entry_block, IncomingPort::from(0))); region_children.push(self.export_node(entry_block)); - // Export the remaining children of the node, except for the last one. - for _ in 0..children.len() - 1 { - region_children.push(self.export_node(children.next().unwrap())); - } - // The last child is the exit block. // Contrary to the entry block, the exit block does not have a dataflow subgraph. // We therefore do not export the block itself, but simply use its output ports // as the target ports of the control flow region. - let exit_block = children.next().unwrap(); + let exit_block = children.next_back().unwrap(); + + // Export the remaining children of the node, except for the last one. + for child in children { + region_children.push(self.export_node(child)); + } let OpType::ExitBlock(_) = self.hugr.get_optype(exit_block) else { panic!("expected an `ExitBlock` node as the last child node"); @@ -657,7 +676,7 @@ impl<'a> Context<'a> { /// Export the `Case` node children of a `Conditional` node as data flow regions. pub fn export_conditional_regions(&mut self, node: Node) -> &'a [model::RegionId] { let children = self.hugr.children(node); - let mut regions = BumpVec::with_capacity_in(children.len(), self.bump); + let mut regions = BumpVec::with_capacity_in(children.size_hint().0, self.bump); for child in children { let OpType::Case(case_op) = self.hugr.get_optype(child) else { @@ -671,22 +690,36 @@ impl<'a> Context<'a> { regions.into_bump_slice() } + /// Exports a polymorphic function type. + /// + /// The returned triple consists of: + /// - The static parameters of the polymorphic function type. + /// - The constraints of the polymorphic function type. + /// - The function type itself. pub fn export_poly_func_type( &mut self, t: &PolyFuncTypeBase, - ) -> (&'a [model::Param<'a>], model::TermId) { + ) -> (&'a [model::Param<'a>], &'a [model::TermId], model::TermId) { let mut params = BumpVec::with_capacity_in(t.params().len(), self.bump); + let scope = self + .local_scope + .expect("exporting poly func type outside of local scope"); for (i, param) in t.params().iter().enumerate() { let name = self.bump.alloc_str(&i.to_string()); - let r#type = self.export_type_param(param); - let param = model::Param::Implicit { name, r#type }; + let r#type = self.export_type_param(param, Some(model::LocalRef::Index(scope, i as _))); + let param = model::Param { + name, + r#type, + sort: model::ParamSort::Implicit, + }; params.push(param) } + let constraints = self.bump.alloc_slice_copy(&self.local_constraints); let body = self.export_func_type(t.body()); - (params.into_bump_slice(), body) + (params.into_bump_slice(), constraints, body) } pub fn export_type(&mut self, t: &TypeBase) -> model::TermId { @@ -703,7 +736,6 @@ impl<'a> Context<'a> { } TypeEnum::Function(func) => self.export_func_type(func), TypeEnum::Variable(index, _) => { - // This ignores the type bound for now let node = self.local_scope.expect("local variable out of scope"); self.make_term(model::Term::Var(model::LocalRef::Index(node, *index as _))) } @@ -794,20 +826,39 @@ impl<'a> Context<'a> { self.make_term(model::Term::List { items, tail: None }) } - pub fn export_type_param(&mut self, t: &TypeParam) -> model::TermId { + /// Exports a `TypeParam` to a term. + /// + /// The `var` argument is set when the type parameter being exported is the + /// type of a parameter to a polymorphic definition. In that case we can + /// generate a `nonlinear` constraint for the type of runtime types marked as + /// `TypeBound::Copyable`. + pub fn export_type_param( + &mut self, + t: &TypeParam, + var: Option>, + ) -> model::TermId { match t { - // This ignores the type bound for now. - TypeParam::Type { .. } => self.make_term(model::Term::Type), - // This ignores the type bound for now. + TypeParam::Type { b } => { + if let (Some(var), TypeBound::Copyable) = (var, b) { + let term = self.make_term(model::Term::Var(var)); + let non_linear = self.make_term(model::Term::NonLinearConstraint { term }); + self.local_constraints.push(non_linear); + } + + self.make_term(model::Term::Type) + } + // This ignores the bound on the natural for now. TypeParam::BoundedNat { .. } => self.make_term(model::Term::NatType), TypeParam::String => self.make_term(model::Term::StrType), TypeParam::List { param } => { - let item_type = self.export_type_param(param); + let item_type = self.export_type_param(param, None); self.make_term(model::Term::ListType { item_type }) } TypeParam::Tuple { params } => { let items = self.bump.alloc_slice_fill_iter( - params.iter().map(|param| self.export_type_param(param)), + params + .iter() + .map(|param| self.export_type_param(param, None)), ); let types = self.make_term(model::Term::List { items, tail: None }); self.make_term(model::Term::ApplyFull { diff --git a/hugr-core/src/extension/op_def.rs b/hugr-core/src/extension/op_def.rs index 110da4124..b975018c6 100644 --- a/hugr-core/src/extension/op_def.rs +++ b/hugr-core/src/extension/op_def.rs @@ -436,6 +436,7 @@ impl OpDef { } /// Iterate over all miscellaneous data in the [OpDef]. + #[allow(unused)] // Unused when no features are enabled pub(crate) fn iter_misc(&self) -> impl ExactSizeIterator { self.misc.iter().map(|(k, v)| (k.as_str(), v)) } diff --git a/hugr-core/src/hugr/rewrite/inline_dfg.rs b/hugr-core/src/hugr/rewrite/inline_dfg.rs index ca3f39cd3..ff400a1d3 100644 --- a/hugr-core/src/hugr/rewrite/inline_dfg.rs +++ b/hugr-core/src/hugr/rewrite/inline_dfg.rs @@ -208,7 +208,7 @@ mod test { // Sanity checks assert_eq!( - outer.children(inner.node()).len(), + outer.children(inner.node()).count(), if nonlocal { 3 } else { 6 } ); // Input, Output, add; + const, load_const, lift assert_eq!(find_dfgs(&outer), vec![outer.root(), inner.node()]); @@ -217,7 +217,7 @@ mod test { outer.get_parent(outer.get_parent(add).unwrap()), outer.get_parent(sub) ); - assert_eq!(outer.nodes().len(), 11); // 6 above + inner DFG + outer (DFG + Input + Output + sub) + assert_eq!(outer.nodes().count(), 11); // 6 above + inner DFG + outer (DFG + Input + Output + sub) { // Check we can't inline the outer DFG let mut h = outer.clone(); @@ -230,7 +230,7 @@ mod test { outer.apply_rewrite(InlineDFG(*inner.handle()))?; outer.validate(®)?; - assert_eq!(outer.nodes().len(), 8); + assert_eq!(outer.nodes().count(), 8); assert_eq!(find_dfgs(&outer), vec![outer.root()]); let [_lift, add, sub] = extension_ops(&outer).try_into().unwrap(); assert_eq!(outer.get_parent(add), Some(outer.root())); @@ -265,8 +265,8 @@ mod test { let mut h = h.finish_hugr_with_outputs(cx.outputs(), ®)?; assert_eq!(find_dfgs(&h), vec![h.root(), swap.node()]); - assert_eq!(h.nodes().len(), 8); // Dfg+I+O, H, CX, Dfg+I+O - // No permutation outside the swap DFG: + assert_eq!(h.nodes().count(), 8); // Dfg+I+O, H, CX, Dfg+I+O + // No permutation outside the swap DFG: assert_eq!( h.node_connections(p_h.node(), swap.node()) .collect::>(), @@ -292,7 +292,7 @@ mod test { h.apply_rewrite(InlineDFG(*swap.handle()))?; assert_eq!(find_dfgs(&h), vec![h.root()]); - assert_eq!(h.nodes().len(), 5); // Dfg+I+O + assert_eq!(h.nodes().count(), 5); // Dfg+I+O let mut ops = extension_ops(&h); ops.sort_by_key(|n| h.num_outputs(*n)); // Put H before CX let [h_gate, cx] = ops.try_into().unwrap(); diff --git a/hugr-core/src/hugr/rewrite/outline_cfg.rs b/hugr-core/src/hugr/rewrite/outline_cfg.rs index c15e3660d..7dd181f92 100644 --- a/hugr-core/src/hugr/rewrite/outline_cfg.rs +++ b/hugr-core/src/hugr/rewrite/outline_cfg.rs @@ -453,8 +453,8 @@ mod test { // `add_hugr_with_wires` does not return an InsertionResult, so recover the nodes manually: let cfg = cfg.node(); let exit_node = h.children(cfg).nth(1).unwrap(); - let tail = h.input_neighbours(exit_node).exactly_one().unwrap(); - let head = h.input_neighbours(tail).exactly_one().unwrap(); + let tail = h.input_neighbours(exit_node).exactly_one().ok().unwrap(); + let head = h.input_neighbours(tail).exactly_one().ok().unwrap(); // Just sanity-check we have the correct nodes assert!(h.get_optype(exit_node).is_exit_block()); assert_eq!( diff --git a/hugr-core/src/hugr/views.rs b/hugr-core/src/hugr/views.rs index d17eaf44f..7d744c150 100644 --- a/hugr-core/src/hugr/views.rs +++ b/hugr-core/src/hugr/views.rs @@ -10,8 +10,6 @@ pub mod sibling_subgraph; #[cfg(test)] mod tests; -use std::iter::Map; - pub use self::petgraph::PetgraphWrapper; use self::render::RenderConfig; pub use descendants::DescendantsGraph; @@ -19,10 +17,9 @@ pub use root_checked::RootChecked; pub use sibling::SiblingGraph; pub use sibling_subgraph::SiblingSubgraph; -use context_iterators::{ContextIterator, IntoContextIterator, MapWithCtx}; -use itertools::{Itertools, MapInto}; +use itertools::Itertools; use portgraph::render::{DotFormat, MermaidFormat}; -use portgraph::{multiportgraph, LinkView, PortView}; +use portgraph::{LinkView, PortView}; use super::internal::HugrInternals; use super::{ @@ -40,36 +37,6 @@ use itertools::Either; /// A trait for inspecting HUGRs. /// For end users we intend this to be superseded by region-specific APIs. pub trait HugrView: HugrInternals { - /// An Iterator over the nodes in a Hugr(View) - type Nodes<'a>: Iterator - where - Self: 'a; - - /// An Iterator over (some or all) ports of a node - type NodePorts<'a>: Iterator - where - Self: 'a; - - /// An Iterator over the children of a node - type Children<'a>: Iterator - where - Self: 'a; - - /// An Iterator over (some or all) the nodes neighbouring a node - type Neighbours<'a>: Iterator - where - Self: 'a; - - /// Iterator over the children of a node - type PortLinks<'a>: Iterator - where - Self: 'a; - - /// Iterator over the links between two nodes. - type NodeConnections<'a>: Iterator - where - Self: 'a; - /// Return the root node of this view. #[inline] fn root(&self) -> Node { @@ -147,16 +114,16 @@ pub trait HugrView: HugrInternals { fn edge_count(&self) -> usize; /// Iterates over the nodes in the port graph. - fn nodes(&self) -> Self::Nodes<'_>; + fn nodes(&self) -> impl Iterator + Clone; /// Iterator over ports of node in a given direction. - fn node_ports(&self, node: Node, dir: Direction) -> Self::NodePorts<'_>; + fn node_ports(&self, node: Node, dir: Direction) -> impl Iterator + Clone; /// Iterator over output ports of node. /// Like [`node_ports`][HugrView::node_ports]`(node, Direction::Outgoing)` /// but preserves knowledge that the ports are [OutgoingPort]s. #[inline] - fn node_outputs(&self, node: Node) -> OutgoingPorts> { + fn node_outputs(&self, node: Node) -> impl Iterator + Clone { self.node_ports(node, Direction::Outgoing) .map(|p| p.as_outgoing().unwrap()) } @@ -165,16 +132,20 @@ pub trait HugrView: HugrInternals { /// Like [`node_ports`][HugrView::node_ports]`(node, Direction::Incoming)` /// but preserves knowledge that the ports are [IncomingPort]s. #[inline] - fn node_inputs(&self, node: Node) -> IncomingPorts> { + fn node_inputs(&self, node: Node) -> impl Iterator + Clone { self.node_ports(node, Direction::Incoming) .map(|p| p.as_incoming().unwrap()) } /// Iterator over both the input and output ports of node. - fn all_node_ports(&self, node: Node) -> Self::NodePorts<'_>; + fn all_node_ports(&self, node: Node) -> impl Iterator + Clone; /// Iterator over the nodes and ports connected to a port. - fn linked_ports(&self, node: Node, port: impl Into) -> Self::PortLinks<'_>; + fn linked_ports( + &self, + node: Node, + port: impl Into, + ) -> impl Iterator + Clone; /// Iterator over all the nodes and ports connected to a node in a given direction. fn all_linked_ports( @@ -245,7 +216,7 @@ pub trait HugrView: HugrInternals { &self, node: Node, port: impl Into, - ) -> OutgoingNodePorts> { + ) -> impl Iterator { self.linked_ports(node, port.into()) .map(|(n, p)| (n, p.as_outgoing().unwrap())) } @@ -257,13 +228,13 @@ pub trait HugrView: HugrInternals { &self, node: Node, port: impl Into, - ) -> IncomingNodePorts> { + ) -> impl Iterator { self.linked_ports(node, port.into()) .map(|(n, p)| (n, p.as_incoming().unwrap())) } /// Iterator the links between two nodes. - fn node_connections(&self, node: Node, other: Node) -> Self::NodeConnections<'_>; + fn node_connections(&self, node: Node, other: Node) -> impl Iterator + Clone; /// Returns whether a port is connected. fn is_linked(&self, node: Node, port: impl Into) -> bool { @@ -288,28 +259,28 @@ pub trait HugrView: HugrInternals { } /// Return iterator over the direct children of node. - fn children(&self, node: Node) -> Self::Children<'_>; + fn children(&self, node: Node) -> impl DoubleEndedIterator + Clone; /// Iterates over neighbour nodes in the given direction. /// May contain duplicates if the graph has multiple links between nodes. - fn neighbours(&self, node: Node, dir: Direction) -> Self::Neighbours<'_>; + fn neighbours(&self, node: Node, dir: Direction) -> impl Iterator + Clone; /// Iterates over the input neighbours of the `node`. /// Shorthand for [`neighbours`][HugrView::neighbours]`(node, Direction::Incoming)`. #[inline] - fn input_neighbours(&self, node: Node) -> Self::Neighbours<'_> { + fn input_neighbours(&self, node: Node) -> impl Iterator + Clone { self.neighbours(node, Direction::Incoming) } /// Iterates over the output neighbours of the `node`. /// Shorthand for [`neighbours`][HugrView::neighbours]`(node, Direction::Outgoing)`. #[inline] - fn output_neighbours(&self, node: Node) -> Self::Neighbours<'_> { + fn output_neighbours(&self, node: Node) -> impl Iterator + Clone { self.neighbours(node, Direction::Outgoing) } /// Iterates over the input and output neighbours of the `node` in sequence. - fn all_neighbours(&self, node: Node) -> Self::Neighbours<'_>; + fn all_neighbours(&self, node: Node) -> impl Iterator + Clone; /// Get the input and output child nodes of a dataflow parent. /// If the node isn't a dataflow parent, then return None @@ -469,18 +440,6 @@ pub trait HugrView: HugrInternals { } } -/// Wraps an iterator over [Port]s that are known to be [OutgoingPort]s -pub type OutgoingPorts = Map OutgoingPort>; - -/// Wraps an iterator over [Port]s that are known to be [IncomingPort]s -pub type IncomingPorts = Map IncomingPort>; - -/// Wraps an iterator over `(`[`Node`],[`Port`]`)` when the ports are known to be [OutgoingPort]s -pub type OutgoingNodePorts = Map (Node, OutgoingPort)>; - -/// Wraps an iterator over `(`[`Node`],[`Port`]`)` when the ports are known to be [IncomingPort]s -pub type IncomingNodePorts = Map (Node, IncomingPort)>; - /// Trait for views that provides a guaranteed bound on the type of the root node. pub trait RootTagged: HugrView { /// The kind of handle that can be used to refer to the root node. @@ -555,25 +514,6 @@ impl ExtractHugr for &mut Hugr { } impl> HugrView for T { - /// An Iterator over the nodes in a Hugr(View) - type Nodes<'a> = MapInto, Node> where Self: 'a; - - /// An Iterator over (some or all) ports of a node - type NodePorts<'a> = MapInto where Self: 'a; - - /// An Iterator over the children of a node - type Children<'a> = MapInto, Node> where Self: 'a; - - /// An Iterator over (some or all) the nodes neighbouring a node - type Neighbours<'a> = MapInto, Node> where Self: 'a; - - /// Iterator over the children of a node - type PortLinks<'a> = MapWithCtx, &'a Hugr, (Node, Port)> - where - Self: 'a; - - type NodeConnections<'a> = MapWithCtx,&'a Hugr, [Port; 2]> where Self: 'a; - #[inline] fn contains_node(&self, node: Node) -> bool { self.as_ref().graph.contains_node(node.pg_index()) @@ -590,12 +530,12 @@ impl> HugrView for T { } #[inline] - fn nodes(&self) -> Self::Nodes<'_> { + fn nodes(&self) -> impl Iterator + Clone { self.as_ref().graph.nodes_iter().map_into() } #[inline] - fn node_ports(&self, node: Node, dir: Direction) -> Self::NodePorts<'_> { + fn node_ports(&self, node: Node, dir: Direction) -> impl Iterator + Clone { self.as_ref() .graph .port_offsets(node.pg_index(), dir) @@ -603,7 +543,7 @@ impl> HugrView for T { } #[inline] - fn all_node_ports(&self, node: Node) -> Self::NodePorts<'_> { + fn all_node_ports(&self, node: Node) -> impl Iterator + Clone { self.as_ref() .graph .all_port_offsets(node.pg_index()) @@ -611,36 +551,33 @@ impl> HugrView for T { } #[inline] - fn linked_ports(&self, node: Node, port: impl Into) -> Self::PortLinks<'_> { + fn linked_ports( + &self, + node: Node, + port: impl Into, + ) -> impl Iterator + Clone { let port = port.into(); let hugr = self.as_ref(); let port = hugr .graph .port_index(node.pg_index(), port.pg_offset()) .unwrap(); - hugr.graph - .port_links(port) - .with_context(hugr) - .map_with_context(|(_, link), hugr| { - let port = link.port(); - let node = hugr.graph.port_node(port).unwrap(); - let offset = hugr.graph.port_offset(port).unwrap(); - (node.into(), offset.into()) - }) + hugr.graph.port_links(port).map(|(_, link)| { + let port = link.port(); + let node = hugr.graph.port_node(port).unwrap(); + let offset = hugr.graph.port_offset(port).unwrap(); + (node.into(), offset.into()) + }) } #[inline] - fn node_connections(&self, node: Node, other: Node) -> Self::NodeConnections<'_> { + fn node_connections(&self, node: Node, other: Node) -> impl Iterator + Clone { let hugr = self.as_ref(); hugr.graph .get_connections(node.pg_index(), other.pg_index()) - .with_context(hugr) - .map_with_context(|(p1, p2), hugr| { - [p1, p2].map(|link| { - let offset = hugr.graph.port_offset(link.port()).unwrap(); - offset.into() - }) + .map(|(p1, p2)| { + [p1, p2].map(|link| hugr.graph.port_offset(link.port()).unwrap().into()) }) } @@ -650,12 +587,12 @@ impl> HugrView for T { } #[inline] - fn children(&self, node: Node) -> Self::Children<'_> { + fn children(&self, node: Node) -> impl DoubleEndedIterator + Clone { self.as_ref().hierarchy.children(node.pg_index()).map_into() } #[inline] - fn neighbours(&self, node: Node, dir: Direction) -> Self::Neighbours<'_> { + fn neighbours(&self, node: Node, dir: Direction) -> impl Iterator + Clone { self.as_ref() .graph .neighbours(node.pg_index(), dir) @@ -663,7 +600,7 @@ impl> HugrView for T { } #[inline] - fn all_neighbours(&self, node: Node) -> Self::Neighbours<'_> { + fn all_neighbours(&self, node: Node) -> impl Iterator + Clone { self.as_ref() .graph .all_neighbours(node.pg_index()) diff --git a/hugr-core/src/hugr/views/descendants.rs b/hugr-core/src/hugr/views/descendants.rs index 83a7d6687..61db536ef 100644 --- a/hugr-core/src/hugr/views/descendants.rs +++ b/hugr-core/src/hugr/views/descendants.rs @@ -1,8 +1,7 @@ //! DescendantsGraph: view onto the subgraph of the HUGR starting from a root //! (all descendants at all depths). -use context_iterators::{ContextIterator, IntoContextIterator, MapWithCtx}; -use itertools::{Itertools, MapInto}; +use itertools::Itertools; use portgraph::{LinkView, MultiPortGraph, PortIndex, PortView}; use crate::hugr::HugrError; @@ -40,36 +39,6 @@ pub struct DescendantsGraph<'g, Root = Node> { _phantom: std::marker::PhantomData, } impl<'g, Root: NodeHandle> HugrView for DescendantsGraph<'g, Root> { - type Nodes<'a> = MapInto< as PortView>::Nodes<'a>, Node> - where - Self: 'a; - - type NodePorts<'a> = MapInto< as PortView>::NodePortOffsets<'a>, Port> - where - Self: 'a; - - type Children<'a> = MapInto, Node> - where - Self: 'a; - - type Neighbours<'a> = MapInto< as LinkView>::Neighbours<'a>, Node> - where - Self: 'a; - - type PortLinks<'a> = MapWithCtx< - as LinkView>::PortLinks<'a>, - &'a Self, - (Node, Port), - > where - Self: 'a; - - type NodeConnections<'a> = MapWithCtx< - as LinkView>::NodeConnections<'a>, - &'a Self, - [Port; 2], - > where - Self: 'a; - #[inline] fn contains_node(&self, node: Node) -> bool { self.graph.contains_node(node.pg_index()) @@ -86,43 +55,43 @@ impl<'g, Root: NodeHandle> HugrView for DescendantsGraph<'g, Root> { } #[inline] - fn nodes(&self) -> Self::Nodes<'_> { + fn nodes(&self) -> impl Iterator + Clone { self.graph.nodes_iter().map_into() } #[inline] - fn node_ports(&self, node: Node, dir: Direction) -> Self::NodePorts<'_> { + fn node_ports(&self, node: Node, dir: Direction) -> impl Iterator + Clone { self.graph.port_offsets(node.pg_index(), dir).map_into() } #[inline] - fn all_node_ports(&self, node: Node) -> Self::NodePorts<'_> { + fn all_node_ports(&self, node: Node) -> impl Iterator + Clone { self.graph.all_port_offsets(node.pg_index()).map_into() } - fn linked_ports(&self, node: Node, port: impl Into) -> Self::PortLinks<'_> { + fn linked_ports( + &self, + node: Node, + port: impl Into, + ) -> impl Iterator + Clone { let port = self .graph .port_index(node.pg_index(), port.into().pg_offset()) .unwrap(); - self.graph - .port_links(port) - .with_context(self) - .map_with_context(|(_, link), region| { - let port: PortIndex = link.into(); - let node = region.graph.port_node(port).unwrap(); - let offset = region.graph.port_offset(port).unwrap(); - (node.into(), offset.into()) - }) + self.graph.port_links(port).map(|(_, link)| { + let port: PortIndex = link.into(); + let node = self.graph.port_node(port).unwrap(); + let offset = self.graph.port_offset(port).unwrap(); + (node.into(), offset.into()) + }) } - fn node_connections(&self, node: Node, other: Node) -> Self::NodeConnections<'_> { + fn node_connections(&self, node: Node, other: Node) -> impl Iterator + Clone { self.graph .get_connections(node.pg_index(), other.pg_index()) - .with_context(self) - .map_with_context(|(p1, p2), hugr| { + .map(|(p1, p2)| { [p1, p2].map(|link| { - let offset = hugr.graph.port_offset(link).unwrap(); + let offset = self.graph.port_offset(link).unwrap(); offset.into() }) }) @@ -134,7 +103,7 @@ impl<'g, Root: NodeHandle> HugrView for DescendantsGraph<'g, Root> { } #[inline] - fn children(&self, node: Node) -> Self::Children<'_> { + fn children(&self, node: Node) -> impl DoubleEndedIterator + Clone { match self.graph.contains_node(node.pg_index()) { true => self .base_hugr() @@ -146,12 +115,12 @@ impl<'g, Root: NodeHandle> HugrView for DescendantsGraph<'g, Root> { } #[inline] - fn neighbours(&self, node: Node, dir: Direction) -> Self::Neighbours<'_> { + fn neighbours(&self, node: Node, dir: Direction) -> impl Iterator + Clone { self.graph.neighbours(node.pg_index(), dir).map_into() } #[inline] - fn all_neighbours(&self, node: Node) -> Self::Neighbours<'_> { + fn all_neighbours(&self, node: Node) -> impl Iterator + Clone { self.graph.all_neighbours(node.pg_index()).map_into() } } @@ -204,6 +173,7 @@ pub(super) mod test { use rstest::rstest; use crate::extension::PRELUDE_REGISTRY; + use crate::IncomingPort; use crate::{ builder::{Container, Dataflow, DataflowSubContainer, HugrBuilder, ModuleBuilder}, type_row, @@ -253,6 +223,7 @@ pub(super) mod test { let (hugr, def, inner) = make_module_hgr()?; let region: DescendantsGraph = DescendantsGraph::try_new(&hugr, def)?; + let def_io = region.get_io(def).unwrap(); assert_eq!(region.node_count(), 7); assert!(region.nodes().all(|n| n == def @@ -268,11 +239,48 @@ pub(super) mod test { .into() ) ); + let inner_region: DescendantsGraph = DescendantsGraph::try_new(&hugr, inner)?; assert_eq!( inner_region.inner_function_type(), Some(Signature::new(type_row![NAT], type_row![NAT])) ); + assert_eq!(inner_region.node_count(), 3); + assert_eq!(inner_region.edge_count(), 2); + assert_eq!(inner_region.children(inner).count(), 2); + assert_eq!(inner_region.children(hugr.root()).count(), 0); + assert_eq!( + inner_region.num_ports(inner, Direction::Outgoing), + inner_region.node_ports(inner, Direction::Outgoing).count() + ); + assert_eq!( + inner_region.num_ports(inner, Direction::Incoming) + + inner_region.num_ports(inner, Direction::Outgoing), + inner_region.all_node_ports(inner).count() + ); + + // The inner region filters out the connections to the main function I/O nodes, + // while the outer region includes them. + assert_eq!(inner_region.node_connections(inner, def_io[1]).count(), 0); + assert_eq!(region.node_connections(inner, def_io[1]).count(), 1); + assert_eq!( + inner_region + .linked_ports(inner, IncomingPort::from(0)) + .count(), + 0 + ); + assert_eq!(region.linked_ports(inner, IncomingPort::from(0)).count(), 1); + assert_eq!( + inner_region.neighbours(inner, Direction::Outgoing).count(), + 0 + ); + assert_eq!(inner_region.all_neighbours(inner).count(), 0); + assert_eq!( + inner_region + .linked_ports(inner, IncomingPort::from(0)) + .count(), + 0 + ); Ok(()) } diff --git a/hugr-core/src/hugr/views/petgraph.rs b/hugr-core/src/hugr/views/petgraph.rs index 9ae2a2331..0f909b332 100644 --- a/hugr-core/src/hugr/views/petgraph.rs +++ b/hugr-core/src/hugr/views/petgraph.rs @@ -6,7 +6,6 @@ use crate::types::EdgeKind; use crate::NodeIndex; use crate::{Node, Port}; -use context_iterators::{ContextIterator, IntoContextIterator, MapWithCtx}; use petgraph::visit as pv; /// Wrapper for a HugrView that implements petgraph's traits. @@ -99,13 +98,14 @@ where T: HugrView, { type NodeRef = HugrNodeRef<'a>; - type NodeReferences = MapWithCtx<::Nodes<'a>, Self, HugrNodeRef<'a>>; + type NodeReferences = Box> + 'a>; fn node_references(self) -> Self::NodeReferences { - self.hugr - .nodes() - .with_context(self) - .map_with_context(|n, &wrapper| HugrNodeRef::from_node(n, wrapper.hugr)) + Box::new( + self.hugr + .nodes() + .map(|n| HugrNodeRef::from_node(n, self.hugr)), + ) } } @@ -113,10 +113,10 @@ impl<'a, T> pv::IntoNodeIdentifiers for PetgraphWrapper<'a, T> where T: HugrView, { - type NodeIdentifiers = ::Nodes<'a>; + type NodeIdentifiers = Box + 'a>; fn node_identifiers(self) -> Self::NodeIdentifiers { - self.hugr.nodes() + Box::new(self.hugr.nodes()) } } @@ -124,10 +124,10 @@ impl<'a, T> pv::IntoNeighbors for PetgraphWrapper<'a, T> where T: HugrView, { - type Neighbors = ::Neighbours<'a>; + type Neighbors = Box + 'a>; fn neighbors(self, n: Self::NodeId) -> Self::Neighbors { - self.hugr.output_neighbours(n) + Box::new(self.hugr.output_neighbours(n)) } } @@ -135,14 +135,14 @@ impl<'a, T> pv::IntoNeighborsDirected for PetgraphWrapper<'a, T> where T: HugrView, { - type NeighborsDirected = ::Neighbours<'a>; + type NeighborsDirected = Box + 'a>; fn neighbors_directed( self, n: Self::NodeId, d: petgraph::Direction, ) -> Self::NeighborsDirected { - self.hugr.neighbours(n, d.into()) + Box::new(self.hugr.neighbours(n, d.into())) } } @@ -211,3 +211,39 @@ impl pv::NodeRef for HugrNodeRef<'_> { self.op } } + +#[cfg(test)] +mod test { + use petgraph::visit::{ + EdgeCount, GetAdjacencyMatrix, IntoNodeReferences, NodeCount, NodeIndexable, NodeRef, + }; + + use crate::hugr::views::tests::sample_hugr; + use crate::ops::handle::NodeHandle; + use crate::HugrView; + + use super::PetgraphWrapper; + + #[test] + fn test_petgraph_wrapper() { + let (hugr, cx1, cx2) = sample_hugr(); + let wrapper = PetgraphWrapper::from(&hugr); + + assert_eq!(wrapper.node_count(), 5); + assert_eq!(wrapper.node_bound(), 5); + assert_eq!(wrapper.edge_count(), 7); + + let cx1_index = cx1.node().pg_index().index(); + assert_eq!(wrapper.to_index(cx1.node()), cx1_index); + assert_eq!(wrapper.from_index(cx1_index), cx1.node()); + + let cx1_ref = wrapper + .node_references() + .find(|n| n.id() == cx1.node()) + .unwrap(); + assert_eq!(cx1_ref.weight(), hugr.get_optype(cx1.node())); + + let adj = wrapper.adjacency_matrix(); + assert!(wrapper.is_adjacent(&adj, cx1.node(), cx2.node())); + } +} diff --git a/hugr-core/src/hugr/views/sibling.rs b/hugr-core/src/hugr/views/sibling.rs index 1125bad25..07aaba1da 100644 --- a/hugr-core/src/hugr/views/sibling.rs +++ b/hugr-core/src/hugr/views/sibling.rs @@ -2,9 +2,8 @@ use std::iter; -use context_iterators::{ContextIterator, IntoContextIterator, MapWithCtx}; -use itertools::{Itertools, MapInto}; -use portgraph::{LinkView, MultiPortGraph, PortIndex, PortView}; +use itertools::{Either, Itertools}; +use portgraph::{LinkView, MultiPortGraph, PortView}; use crate::hugr::internal::HugrMutInternals; use crate::hugr::{HugrError, HugrMut}; @@ -48,19 +47,6 @@ pub struct SiblingGraph<'g, Root = Node> { /// i.e. that rely only on [HugrInternals::base_hugr] macro_rules! impl_base_members { () => { - - type Nodes<'a> = iter::Chain, MapInto, Node>> - where - Self: 'a; - - type NodePorts<'a> = MapInto< as PortView>::NodePortOffsets<'a>, Port> - where - Self: 'a; - - type Children<'a> = MapInto, Node> - where - Self: 'a; - #[inline] fn node_count(&self) -> usize { self.base_hugr().hierarchy.child_count(self.root.pg_index()) + 1 @@ -75,7 +61,7 @@ macro_rules! impl_base_members { } #[inline] - fn nodes(&self) -> Self::Nodes<'_> { + fn nodes(&self) -> impl Iterator + Clone { // Faster implementation than filtering all the nodes in the internal graph. let children = self .base_hugr() @@ -85,10 +71,14 @@ macro_rules! impl_base_members { iter::once(self.root).chain(children) } - fn children(&self, node: Node) -> Self::Children<'_> { + fn children(&self, node: Node) -> impl DoubleEndedIterator + Clone { // Same as SiblingGraph match node == self.root { - true => self.base_hugr().hierarchy.children(node.pg_index()).map_into(), + true => self + .base_hugr() + .hierarchy + .children(node.pg_index()) + .map_into(), false => portgraph::hierarchy::Children::default().map_into(), } } @@ -96,24 +86,6 @@ macro_rules! impl_base_members { } impl<'g, Root: NodeHandle> HugrView for SiblingGraph<'g, Root> { - type Neighbours<'a> = MapInto< as LinkView>::Neighbours<'a>, Node> - where - Self: 'a; - - type PortLinks<'a> = MapWithCtx< - as LinkView>::PortLinks<'a>, - &'a Self, - (Node, Port), - > where - Self: 'a; - - type NodeConnections<'a> = MapWithCtx< - as LinkView>::NodeConnections<'a>, - &'a Self, - [Port; 2], - > where - Self: 'a; - impl_base_members! {} #[inline] @@ -122,41 +94,35 @@ impl<'g, Root: NodeHandle> HugrView for SiblingGraph<'g, Root> { } #[inline] - fn node_ports(&self, node: Node, dir: Direction) -> Self::NodePorts<'_> { + fn node_ports(&self, node: Node, dir: Direction) -> impl Iterator + Clone { self.graph.port_offsets(node.pg_index(), dir).map_into() } #[inline] - fn all_node_ports(&self, node: Node) -> Self::NodePorts<'_> { + fn all_node_ports(&self, node: Node) -> impl Iterator + Clone { self.graph.all_port_offsets(node.pg_index()).map_into() } - fn linked_ports(&self, node: Node, port: impl Into) -> Self::PortLinks<'_> { + fn linked_ports( + &self, + node: Node, + port: impl Into, + ) -> impl Iterator + Clone { let port = self .graph .port_index(node.pg_index(), port.into().pg_offset()) .unwrap(); - self.graph - .port_links(port) - .with_context(self) - .map_with_context(|(_, link), region| { - let port: PortIndex = link.into(); - let node = region.graph.port_node(port).unwrap(); - let offset = region.graph.port_offset(port).unwrap(); - (node.into(), offset.into()) - }) + self.graph.port_links(port).map(|(_, link)| { + let node = self.graph.port_node(link).unwrap(); + let offset = self.graph.port_offset(link).unwrap(); + (node.into(), offset.into()) + }) } - fn node_connections(&self, node: Node, other: Node) -> Self::NodeConnections<'_> { + fn node_connections(&self, node: Node, other: Node) -> impl Iterator + Clone { self.graph .get_connections(node.pg_index(), other.pg_index()) - .with_context(self) - .map_with_context(|(p1, p2), hugr| { - [p1, p2].map(|link| { - let offset = hugr.graph.port_offset(link).unwrap(); - offset.into() - }) - }) + .map(|(p1, p2)| [p1, p2].map(|link| self.graph.port_offset(link).unwrap().into())) } #[inline] @@ -165,12 +131,12 @@ impl<'g, Root: NodeHandle> HugrView for SiblingGraph<'g, Root> { } #[inline] - fn neighbours(&self, node: Node, dir: Direction) -> Self::Neighbours<'_> { + fn neighbours(&self, node: Node, dir: Direction) -> impl Iterator + Clone { self.graph.neighbours(node.pg_index(), dir).map_into() } #[inline] - fn all_neighbours(&self, node: Node) -> Self::Neighbours<'_> { + fn all_neighbours(&self, node: Node) -> impl Iterator + Clone { self.graph.all_neighbours(node.pg_index()).map_into() } } @@ -293,16 +259,6 @@ impl<'g, Root: NodeHandle> HugrInternals for SiblingMut<'g, Root> { } impl<'g, Root: NodeHandle> HugrView for SiblingMut<'g, Root> { - type Neighbours<'a> = as IntoIterator>::IntoIter - where - Self: 'a; - - type PortLinks<'a> = as IntoIterator>::IntoIter - where - Self: 'a; - - type NodeConnections<'a> = as IntoIterator>::IntoIter where Self: 'a; - impl_base_members! {} fn contains_node(&self, node: Node) -> bool { @@ -311,61 +267,54 @@ impl<'g, Root: NodeHandle> HugrView for SiblingMut<'g, Root> { node == self.root || self.base_hugr().get_parent(node) == Some(self.root) } - fn node_ports(&self, node: Node, dir: Direction) -> Self::NodePorts<'_> { - match self.contains_node(node) { - true => self.base_hugr().node_ports(node, dir), - false => ::NodePortOffsets::default().map_into(), - } + fn node_ports(&self, node: Node, dir: Direction) -> impl Iterator + Clone { + self.base_hugr().node_ports(node, dir) } - fn all_node_ports(&self, node: Node) -> Self::NodePorts<'_> { - match self.contains_node(node) { - true => self.base_hugr().all_node_ports(node), - false => ::NodePortOffsets::default().map_into(), - } + fn all_node_ports(&self, node: Node) -> impl Iterator + Clone { + self.base_hugr().all_node_ports(node) } - fn linked_ports(&self, node: Node, port: impl Into) -> Self::PortLinks<'_> { - // Need to filter only to links inside the sibling graph - SiblingGraph::<'_, Node>::new_unchecked(self.hugr, self.root) + fn linked_ports( + &self, + node: Node, + port: impl Into, + ) -> impl Iterator + Clone { + self.hugr .linked_ports(node, port) - .collect::>() - .into_iter() + .filter(|(n, _)| self.contains_node(*n)) } - fn node_connections(&self, node: Node, other: Node) -> Self::NodeConnections<'_> { - // Need to filter only to connections inside the sibling graph - SiblingGraph::<'_, Node>::new_unchecked(self.hugr, self.root) - .node_connections(node, other) - .collect::>() - .into_iter() + fn node_connections(&self, node: Node, other: Node) -> impl Iterator + Clone { + match self.contains_node(node) && self.contains_node(other) { + // The nodes are not in the sibling graph + false => Either::Left(iter::empty()), + // The nodes are in the sibling graph + true => Either::Right(self.hugr.node_connections(node, other)), + } } fn num_ports(&self, node: Node, dir: Direction) -> usize { - match self.contains_node(node) { - true => self.base_hugr().num_ports(node, dir), - false => 0, - } + self.base_hugr().num_ports(node, dir) } - fn neighbours(&self, node: Node, dir: Direction) -> Self::Neighbours<'_> { - // Need to filter to neighbours in the Sibling Graph - SiblingGraph::<'_, Node>::new_unchecked(self.hugr, self.root) + fn neighbours(&self, node: Node, dir: Direction) -> impl Iterator + Clone { + self.hugr .neighbours(node, dir) - .collect::>() - .into_iter() + .filter(|n| self.contains_node(*n)) } - fn all_neighbours(&self, node: Node) -> Self::Neighbours<'_> { - SiblingGraph::<'_, Node>::new_unchecked(self.hugr, self.root) + fn all_neighbours(&self, node: Node) -> impl Iterator + Clone { + self.hugr .all_neighbours(node) - .collect::>() - .into_iter() + .filter(|n| self.contains_node(*n)) } } + impl<'g, Root: NodeHandle> RootTagged for SiblingMut<'g, Root> { type RootHandle = Root; } + impl<'g, Root: NodeHandle> HugrMutInternals for SiblingMut<'g, Root> { fn hugr_mut(&mut self) -> &mut Hugr { self.hugr @@ -384,28 +333,116 @@ mod test { use crate::ops::handle::{CfgID, DataflowParentID, DfgID, FuncID}; use crate::ops::{dataflow::IOTrait, Input, OpTag, Output}; use crate::ops::{OpTrait, OpType}; - use crate::type_row; use crate::types::{Signature, Type}; + use crate::utils::test_quantum_extension::EXTENSION_ID; + use crate::{type_row, IncomingPort}; + + const NAT: Type = crate::extension::prelude::USIZE_T; + const QB: Type = crate::extension::prelude::QB_T; use super::super::descendants::test::make_module_hgr; use super::*; - #[test] - fn flat_region() -> Result<(), Box> { - let (hugr, def, inner) = make_module_hgr()?; - - let region: SiblingGraph = SiblingGraph::try_new(&hugr, def)?; + fn test_properties( + hugr: &Hugr, + def: Node, + inner: Node, + region: T, + inner_region: T, + ) -> Result<(), Box> + where + T: HugrView + Sized, + { + let def_io = region.get_io(def).unwrap(); assert_eq!(region.node_count(), 5); - assert!(region - .nodes() - .all(|n| n == def || hugr.get_parent(n) == Some(def))); + assert_eq!(region.portgraph().node_count(), 5); + assert!(region.nodes().all(|n| n == def + || hugr.get_parent(n) == Some(def) + || hugr.get_parent(n) == Some(inner))); assert_eq!(region.children(inner).count(), 0); + assert_eq!( + region.poly_func_type(), + Some( + Signature::new_endo(type_row![NAT, QB]) + .with_extension_delta(EXTENSION_ID) + .into() + ) + ); + + assert_eq!( + inner_region.inner_function_type(), + Some(Signature::new(type_row![NAT], type_row![NAT])) + ); + assert_eq!(inner_region.node_count(), 3); + assert_eq!(inner_region.edge_count(), 1); + assert_eq!(inner_region.children(inner).count(), 2); + assert_eq!(inner_region.children(hugr.root()).count(), 0); + assert_eq!( + inner_region.num_ports(inner, Direction::Outgoing), + inner_region.node_ports(inner, Direction::Outgoing).count() + ); + assert_eq!( + inner_region.num_ports(inner, Direction::Incoming) + + inner_region.num_ports(inner, Direction::Outgoing), + inner_region.all_node_ports(inner).count() + ); + + // The inner region filters out the connections to the main function I/O nodes, + // while the outer region includes them. + assert_eq!(inner_region.node_connections(inner, def_io[1]).count(), 0); + assert_eq!(region.node_connections(inner, def_io[1]).count(), 1); + assert_eq!( + inner_region + .linked_ports(inner, IncomingPort::from(0)) + .count(), + 0 + ); + assert_eq!(region.linked_ports(inner, IncomingPort::from(0)).count(), 1); + assert_eq!( + inner_region.neighbours(inner, Direction::Outgoing).count(), + 0 + ); + assert_eq!(inner_region.all_neighbours(inner).count(), 0); + assert_eq!( + inner_region + .linked_ports(inner, IncomingPort::from(0)) + .count(), + 0 + ); + Ok(()) } - const NAT: Type = crate::extension::prelude::USIZE_T; + #[rstest] + fn sibling_graph_properties() -> Result<(), Box> { + let (hugr, def, inner) = make_module_hgr()?; + + test_properties::( + &hugr, + def, + inner, + SiblingGraph::try_new(&hugr, def).unwrap(), + SiblingGraph::try_new(&hugr, inner).unwrap(), + ) + } + + #[rstest] + fn sibling_mut_properties() -> Result<(), Box> { + let (hugr, def, inner) = make_module_hgr()?; + let mut def_region_hugr = hugr.clone(); + let mut inner_region_hugr = hugr.clone(); + + test_properties::( + &hugr, + def, + inner, + SiblingMut::try_new(&mut def_region_hugr, def).unwrap(), + SiblingMut::try_new(&mut inner_region_hugr, inner).unwrap(), + ) + } + #[test] fn nested_flat() -> Result<(), Box> { let mut module_builder = ModuleBuilder::new(); @@ -417,11 +454,13 @@ mod test { let fun = fbuild.finish_with_outputs(sub_dfg.outputs())?; let h = module_builder.finish_hugr(&PRELUDE_REGISTRY)?; let sub_dfg = sub_dfg.node(); - // Can create a view from a child or grandchild of a hugr: + + // We can create a view from a child or grandchild of a hugr: let dfg_view: SiblingGraph<'_, DfgID> = SiblingGraph::try_new(&h, sub_dfg)?; let fun_view: SiblingGraph<'_, FuncID> = SiblingGraph::try_new(&h, fun.node())?; - assert_eq!(fun_view.children(sub_dfg).len(), 0); - // And can create a view from a child of another SiblingGraph + assert_eq!(fun_view.children(sub_dfg).count(), 0); + + // And also create a view from a child of another SiblingGraph let nested_dfg_view: SiblingGraph<'_, DfgID> = SiblingGraph::try_new(&fun_view, sub_dfg)?; // Both ways work: @@ -439,6 +478,7 @@ mod test { Ok(()) } + /// Mutate a SiblingMut wrapper #[rstest] fn flat_mut(mut simple_dfg_hugr: Hugr) { simple_dfg_hugr.update_validate(&PRELUDE_REGISTRY).unwrap(); diff --git a/hugr-core/src/hugr/views/sibling_subgraph.rs b/hugr-core/src/hugr/views/sibling_subgraph.rs index 167d76419..4fcebe179 100644 --- a/hugr-core/src/hugr/views/sibling_subgraph.rs +++ b/hugr-core/src/hugr/views/sibling_subgraph.rs @@ -211,7 +211,7 @@ impl SiblingSubgraph { /// The subgraph signature will be given by the types of the incoming and /// outgoing edges ordered by the node order in `nodes` and within each node /// by the port order. - + /// /// The in- and out-arity of the signature will match the /// number of incoming and outgoing edges respectively. In particular, the /// assumption is made that no two incoming edges have the same source @@ -238,6 +238,14 @@ impl SiblingSubgraph { checker: &impl ConvexChecker, ) -> Result { let nodes = nodes.into(); + + // If there's one or less nodes, we don't need to check convexity. + match nodes.as_slice() { + [] => return Err(InvalidSubgraph::EmptySubgraph), + [node] => return Ok(Self::from_node(*node, hugr)), + _ => {} + }; + let nodes_set = nodes.iter().copied().collect::>(); let incoming_edges = nodes .iter() @@ -265,6 +273,31 @@ impl SiblingSubgraph { Self::try_new_with_checker(inputs, outputs, hugr, checker) } + /// Create a subgraph containing a single node. + /// + /// The subgraph signature will be given by signature of the node. + pub fn from_node(node: Node, hugr: &impl HugrView) -> Self { + // TODO once https://github.com/CQCL/portgraph/issues/155 + // is fixed we can just call try_from_nodes here. + // Until then, doing this saves a lot of work. + let nodes = vec![node]; + let inputs = hugr + .node_inputs(node) + .filter(|&p| hugr.is_linked(node, p)) + .map(|p| vec![(node, p)]) + .collect_vec(); + let outputs = hugr + .node_outputs(node) + .filter_map(|p| hugr.is_linked(node, p).then_some((node, p))) + .collect_vec(); + + Self { + nodes, + inputs, + outputs, + } + } + /// An iterator over the nodes in the subgraph. pub fn nodes(&self) -> &[Node] { &self.nodes @@ -424,15 +457,21 @@ impl SiblingSubgraph { // Connect the inserted nodes in-between the input and output nodes. let [inp, out] = extracted.get_io(extracted.root()).unwrap(); - for (inp_port, repl_ports) in extracted.node_outputs(inp).zip(self.inputs.iter()) { + let inputs = extracted.node_outputs(inp).zip(self.inputs.iter()); + let outputs = extracted.node_inputs(out).zip(self.outputs.iter()); + let mut connections = Vec::with_capacity(inputs.size_hint().0 + outputs.size_hint().0); + + for (inp_port, repl_ports) in inputs { for (repl_node, repl_port) in repl_ports { - extracted.connect(inp, inp_port, node_map[repl_node], *repl_port); + connections.push((inp, inp_port, node_map[repl_node], *repl_port)); } } - for (out_port, (repl_node, repl_port)) in - extracted.node_inputs(out).zip(self.outputs.iter()) - { - extracted.connect(node_map[repl_node], *repl_port, out, out_port); + for (out_port, (repl_node, repl_port)) in outputs { + connections.push((node_map[repl_node], *repl_port, out, out_port)); + } + + for (src, src_port, dst, dst_port) in connections { + extracted.connect(src, src_port, dst, dst_port); } extracted @@ -1030,9 +1069,9 @@ mod tests { let (hugr, func_root) = build_3not_hugr().unwrap(); let func: SiblingGraph<'_> = SiblingGraph::try_new(&hugr, func_root).unwrap(); let [inp, _out] = hugr.get_io(func_root).unwrap(); - let not1 = hugr.output_neighbours(inp).exactly_one().unwrap(); - let not2 = hugr.output_neighbours(not1).exactly_one().unwrap(); - let not3 = hugr.output_neighbours(not2).exactly_one().unwrap(); + let not1 = hugr.output_neighbours(inp).exactly_one().ok().unwrap(); + let not2 = hugr.output_neighbours(not1).exactly_one().ok().unwrap(); + let not3 = hugr.output_neighbours(not2).exactly_one().ok().unwrap(); let not1_inp = hugr.node_inputs(not1).next().unwrap(); let not1_out = hugr.node_outputs(not1).next().unwrap(); let not3_inp = hugr.node_inputs(not3).next().unwrap(); @@ -1053,11 +1092,12 @@ mod tests { fn convex_multiports() { let (hugr, func_root) = build_multiport_hugr().unwrap(); let [inp, out] = hugr.get_io(func_root).unwrap(); - let not1 = hugr.output_neighbours(inp).exactly_one().unwrap(); + let not1 = hugr.output_neighbours(inp).exactly_one().ok().unwrap(); let not2 = hugr .output_neighbours(not1) .filter(|&n| n != out) .exactly_one() + .ok() .unwrap(); let subgraph = SiblingSubgraph::try_from_nodes([not1, not2], &hugr).unwrap(); diff --git a/hugr-core/src/hugr/views/tests.rs b/hugr-core/src/hugr/views/tests.rs index c52957b2c..c958c8b9f 100644 --- a/hugr-core/src/hugr/views/tests.rs +++ b/hugr-core/src/hugr/views/tests.rs @@ -17,8 +17,11 @@ use crate::{ Hugr, HugrView, }; +/// A Dataflow graph from two qubits to two qubits that applies two CX operations on them. +/// +/// Returns the Hugr and the two CX node ids. #[fixture] -fn sample_hugr() -> (Hugr, BuildHandle, BuildHandle) { +pub(crate) fn sample_hugr() -> (Hugr, BuildHandle, BuildHandle) { let mut dfg = DFGBuilder::new(endo_sig(type_row![QB_T, QB_T])).unwrap(); let [q1, q2] = dfg.input_wires_arr(); @@ -99,6 +102,7 @@ fn all_ports(sample_hugr: (Hugr, BuildHandle, BuildHandle, BuildHandle { /// A map from `NodeId` to the imported `Node`. nodes: FxHashMap, - /// The types of the local variables that are currently in scope. - local_variables: FxIndexMap<&'a str, model::TermId>, + /// The local variables that are currently in scope. + local_variables: FxIndexMap<&'a str, LocalVar>, custom_name_cache: FxHashMap<&'a str, (ExtensionId, SmolStr)>, } @@ -155,20 +155,20 @@ impl<'a> Context<'a> { .ok_or_else(|| model::ModelError::RegionNotFound(region_id).into()) } - /// Looks up a [`LocalRef`] within the current scope and returns its index and type. + /// Looks up a [`LocalRef`] within the current scope. fn resolve_local_ref( &self, local_ref: &model::LocalRef, - ) -> Result<(usize, model::TermId), ImportError> { + ) -> Result<(usize, LocalVar), ImportError> { let term = match local_ref { model::LocalRef::Index(_, index) => self .local_variables .get_index(*index as usize) - .map(|(_, term)| (*index as usize, *term)), + .map(|(_, v)| (*index as usize, *v)), model::LocalRef::Named(name) => self .local_variables .get_full(name) - .map(|(index, _, term)| (index, *term)), + .map(|(index, _, v)| (index, *v)), }; term.ok_or_else(|| model::ModelError::InvalidLocal(local_ref.to_string()).into()) @@ -810,9 +810,11 @@ impl<'a> Context<'a> { // Connect the input node to the tag node let input_outputs = self.hugr.node_outputs(node_input); let tag_inputs = self.hugr.node_inputs(node_tag); + let mut connections = + Vec::with_capacity(input_outputs.size_hint().0 + tag_inputs.size_hint().0); for (a, b) in input_outputs.zip(tag_inputs) { - self.hugr.connect(node_input, a, node_tag, b); + connections.push((node_input, a, node_tag, b)); } // Connect the tag node to the output node @@ -820,7 +822,11 @@ impl<'a> Context<'a> { let output_inputs = self.hugr.node_inputs(node_output); for (a, b) in tag_outputs.zip(output_inputs) { - self.hugr.connect(node_tag, a, node_output, b); + connections.push((node_tag, a, node_output, b)); + } + + for (src, src_port, dst, dst_port) in connections { + self.hugr.connect(src, src_port, dst, dst_port); } } @@ -892,41 +898,49 @@ impl<'a> Context<'a> { self.with_local_socpe(|ctx| { let mut imported_params = Vec::with_capacity(decl.params.len()); - for param in decl.params { - // TODO: `PolyFuncType` should be able to handle constraints - // and distinguish between implicit and explicit parameters. - match param { - model::Param::Implicit { name, r#type } => { - imported_params.push(ctx.import_type_param(*r#type)?); - ctx.local_variables.insert(name, *r#type); - } - model::Param::Explicit { name, r#type } => { - imported_params.push(ctx.import_type_param(*r#type)?); - ctx.local_variables.insert(name, *r#type); - } - model::Param::Constraint { constraint: _ } => { - return Err(error_unsupported!("constraints")); + ctx.local_variables.extend( + decl.params + .iter() + .map(|param| (param.name, LocalVar::new(param.r#type))), + ); + + for constraint in decl.constraints { + match ctx.get_term(*constraint)? { + model::Term::NonLinearConstraint { term } => { + let model::Term::Var(var) = ctx.get_term(*term)? else { + return Err(error_unsupported!( + "constraint on term that is not a variable" + )); + }; + + let var = ctx.resolve_local_ref(var)?.0; + ctx.local_variables[var].bound = TypeBound::Copyable; } + _ => return Err(error_unsupported!("constraint other than copy or discard")), } } + for (index, param) in decl.params.iter().enumerate() { + // NOTE: `PolyFuncType` only has explicit type parameters at present. + let bound = ctx.local_variables[index].bound; + imported_params.push(ctx.import_type_param(param.r#type, bound)?); + } + let body = ctx.import_func_type::(decl.signature)?; in_scope(ctx, PolyFuncTypeBase::new(imported_params, body)) }) } /// Import a [`TypeParam`] from a term that represents a static type. - fn import_type_param(&mut self, term_id: model::TermId) -> Result { + fn import_type_param( + &mut self, + term_id: model::TermId, + bound: TypeBound, + ) -> Result { match self.get_term(term_id)? { model::Term::Wildcard => Err(error_uninferred!("wildcard")), - model::Term::Type => { - // As part of the migration from `TypeBound`s to constraints, we pretend that all - // `TypeBound`s are copyable. - Ok(TypeParam::Type { - b: TypeBound::Copyable, - }) - } + model::Term::Type => Ok(TypeParam::Type { b: bound }), model::Term::StaticType => Err(error_unsupported!("`type` as `TypeParam`")), model::Term::Constraint => Err(error_unsupported!("`constraint` as `TypeParam`")), @@ -938,7 +952,9 @@ impl<'a> Context<'a> { model::Term::FuncType { .. } => Err(error_unsupported!("`(fn ...)` as `TypeParam`")), model::Term::ListType { item_type } => { - let param = Box::new(self.import_type_param(*item_type)?); + // At present `hugr-model` has no way to express that the item + // type of a list must be copyable. Therefore we import it as `Any`. + let param = Box::new(self.import_type_param(*item_type, TypeBound::Any)?); Ok(TypeParam::List { param }) } @@ -952,7 +968,10 @@ impl<'a> Context<'a> { | model::Term::List { .. } | model::Term::ExtSet { .. } | model::Term::Adt { .. } - | model::Term::Control { .. } => Err(model::ModelError::TypeError(term_id).into()), + | model::Term::Control { .. } + | model::Term::NonLinearConstraint { .. } => { + Err(model::ModelError::TypeError(term_id).into()) + } model::Term::ControlType => { Err(error_unsupported!("type of control types as `TypeParam`")) @@ -960,7 +979,7 @@ impl<'a> Context<'a> { } } - /// Import a `TypeArg` froma term that represents a static type or value. + /// Import a `TypeArg` from a term that represents a static type or value. fn import_type_arg(&mut self, term_id: model::TermId) -> Result { match self.get_term(term_id)? { model::Term::Wildcard => Err(error_uninferred!("wildcard")), @@ -969,8 +988,8 @@ impl<'a> Context<'a> { } model::Term::Var(var) => { - let (index, var_type) = self.resolve_local_ref(var)?; - let decl = self.import_type_param(var_type)?; + let (index, var) = self.resolve_local_ref(var)?; + let decl = self.import_type_param(var.r#type, var.bound)?; Ok(TypeArg::new_var_use(index, decl)) } @@ -1008,7 +1027,10 @@ impl<'a> Context<'a> { model::Term::FuncType { .. } | model::Term::Adt { .. } - | model::Term::Control { .. } => Err(model::ModelError::TypeError(term_id).into()), + | model::Term::Control { .. } + | model::Term::NonLinearConstraint { .. } => { + Err(model::ModelError::TypeError(term_id).into()) + } } } @@ -1109,7 +1131,10 @@ impl<'a> Context<'a> { | model::Term::List { .. } | model::Term::Control { .. } | model::Term::ControlType - | model::Term::Nat(_) => Err(model::ModelError::TypeError(term_id).into()), + | model::Term::Nat(_) + | model::Term::NonLinearConstraint { .. } => { + Err(model::ModelError::TypeError(term_id).into()) + } } } @@ -1285,3 +1310,21 @@ impl<'a> Names<'a> { Ok(Self { items }) } } + +/// Information about a local variable. +#[derive(Debug, Clone, Copy)] +struct LocalVar { + /// The type of the variable. + r#type: model::TermId, + /// The type bound of the variable. + bound: TypeBound, +} + +impl LocalVar { + pub fn new(r#type: model::TermId) -> Self { + Self { + r#type, + bound: TypeBound::Any, + } + } +} diff --git a/hugr-core/tests/model.rs b/hugr-core/tests/model.rs index 611eda660..d9ef0d2c9 100644 --- a/hugr-core/tests/model.rs +++ b/hugr-core/tests/model.rs @@ -58,3 +58,10 @@ pub fn test_roundtrip_params() { "../../hugr-model/tests/fixtures/model-params.edn" ))); } + +#[test] +pub fn test_roundtrip_constraints() { + insta::assert_snapshot!(roundtrip(include_str!( + "../../hugr-model/tests/fixtures/model-constraints.edn" + ))); +} diff --git a/hugr-core/tests/snapshots/model__roundtrip_constraints.snap b/hugr-core/tests/snapshots/model__roundtrip_constraints.snap new file mode 100644 index 000000000..f085c4785 --- /dev/null +++ b/hugr-core/tests/snapshots/model__roundtrip_constraints.snap @@ -0,0 +1,16 @@ +--- +source: hugr-core/tests/model.rs +expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-constraints.edn\"))" +--- +(hugr 0) + +(declare-func array.replicate + (forall ?0 type) + (forall ?1 nat) + (where (nonlinear ?0)) + [?0] [(@ array.Array ?0 ?1)] (ext)) + +(declare-func array.copy + (forall ?0 type) + (where (nonlinear ?0)) + [(@ array.Array ?0)] [(@ array.Array ?0) (@ array.Array ?0)] (ext)) diff --git a/hugr-model/capnp/hugr-v0.capnp b/hugr-model/capnp/hugr-v0.capnp index 95db81205..94341beba 100644 --- a/hugr-model/capnp/hugr-v0.capnp +++ b/hugr-model/capnp/hugr-v0.capnp @@ -56,13 +56,15 @@ struct Operation { struct FuncDefn { name @0 :Text; params @1 :List(Param); - signature @2 :TermId; + constraints @2 :List(TermId); + signature @3 :TermId; } struct FuncDecl { name @0 :Text; params @1 :List(Param); - signature @2 :TermId; + constraints @2 :List(TermId); + signature @3 :TermId; } struct AliasDefn { @@ -81,13 +83,15 @@ struct Operation { struct ConstructorDecl { name @0 :Text; params @1 :List(Param); - type @2 :TermId; + constraints @2 :List(TermId); + type @3 :TermId; } struct OperationDecl { name @0 :Text; params @1 :List(Param); - type @2 :TermId; + constraints @2 :List(TermId); + type @3 :TermId; } } @@ -157,6 +161,7 @@ struct Term { funcType @17 :FuncType; control @18 :TermId; controlType @19 :Void; + nonLinearConstraint @20 :TermId; } struct Apply { @@ -187,19 +192,12 @@ struct Term { } struct Param { - union { - implicit @0 :Implicit; - explicit @1 :Explicit; - constraint @2 :TermId; - } - - struct Implicit { - name @0 :Text; - type @1 :TermId; - } + name @0 :Text; + type @1 :TermId; + sort @2 :ParamSort; +} - struct Explicit { - name @0 :Text; - type @1 :TermId; - } +enum ParamSort { + implicit @0; + explicit @1; } diff --git a/hugr-model/src/v0/binary/read.rs b/hugr-model/src/v0/binary/read.rs index 681bd4ea9..5381a7dc8 100644 --- a/hugr-model/src/v0/binary/read.rs +++ b/hugr-model/src/v0/binary/read.rs @@ -140,10 +140,12 @@ fn read_operation<'a>( let reader = reader?; let name = bump.alloc_str(reader.get_name()?.to_str()?); let params = read_list!(bump, reader, get_params, read_param); + let constraints = read_scalar_list!(bump, reader, get_constraints, model::TermId); let signature = model::TermId(reader.get_signature()); let decl = bump.alloc(model::FuncDecl { name, params, + constraints, signature, }); model::Operation::DefineFunc { decl } @@ -152,10 +154,12 @@ fn read_operation<'a>( let reader = reader?; let name = bump.alloc_str(reader.get_name()?.to_str()?); let params = read_list!(bump, reader, get_params, read_param); + let constraints = read_scalar_list!(bump, reader, get_constraints, model::TermId); let signature = model::TermId(reader.get_signature()); let decl = bump.alloc(model::FuncDecl { name, params, + constraints, signature, }); model::Operation::DeclareFunc { decl } @@ -189,10 +193,12 @@ fn read_operation<'a>( let reader = reader?; let name = bump.alloc_str(reader.get_name()?.to_str()?); let params = read_list!(bump, reader, get_params, read_param); + let constraints = read_scalar_list!(bump, reader, get_constraints, model::TermId); let r#type = model::TermId(reader.get_type()); let decl = bump.alloc(model::ConstructorDecl { name, params, + constraints, r#type, }); model::Operation::DeclareConstructor { decl } @@ -201,10 +207,12 @@ fn read_operation<'a>( let reader = reader?; let name = bump.alloc_str(reader.get_name()?.to_str()?); let params = read_list!(bump, reader, get_params, read_param); + let constraints = read_scalar_list!(bump, reader, get_constraints, model::TermId); let r#type = model::TermId(reader.get_type()); let decl = bump.alloc(model::OperationDecl { name, params, + constraints, r#type, }); model::Operation::DeclareOperation { decl } @@ -332,6 +340,10 @@ fn read_term<'a>(bump: &'a Bump, reader: hugr_capnp::term::Reader) -> ReadResult Which::Control(values) => model::Term::Control { values: model::TermId(values), }, + + Which::NonLinearConstraint(term) => model::Term::NonLinearConstraint { + term: model::TermId(term), + }, }) } @@ -348,23 +360,13 @@ fn read_param<'a>( bump: &'a Bump, reader: hugr_capnp::param::Reader, ) -> ReadResult> { - use hugr_capnp::param::Which; - Ok(match reader.which()? { - Which::Implicit(reader) => { - let reader = reader?; - let name = bump.alloc_str(reader.get_name()?.to_str()?); - let r#type = model::TermId(reader.get_type()); - model::Param::Implicit { name, r#type } - } - Which::Explicit(reader) => { - let reader = reader?; - let name = bump.alloc_str(reader.get_name()?.to_str()?); - let r#type = model::TermId(reader.get_type()); - model::Param::Explicit { name, r#type } - } - Which::Constraint(constraint) => { - let constraint = model::TermId(constraint); - model::Param::Constraint { constraint } - } - }) + let name = bump.alloc_str(reader.get_name()?.to_str()?); + let r#type = model::TermId(reader.get_type()); + + let sort = match reader.get_sort()? { + hugr_capnp::ParamSort::Implicit => model::ParamSort::Implicit, + hugr_capnp::ParamSort::Explicit => model::ParamSort::Explicit, + }; + + Ok(model::Param { name, r#type, sort }) } diff --git a/hugr-model/src/v0/binary/write.rs b/hugr-model/src/v0/binary/write.rs index a4b64d646..f3a0a14d2 100644 --- a/hugr-model/src/v0/binary/write.rs +++ b/hugr-model/src/v0/binary/write.rs @@ -60,12 +60,14 @@ fn write_operation(mut builder: hugr_capnp::operation::Builder, operation: &mode let mut builder = builder.init_func_defn(); builder.set_name(decl.name); write_list!(builder, init_params, write_param, decl.params); + let _ = builder.set_constraints(model::TermId::unwrap_slice(decl.constraints)); builder.set_signature(decl.signature.0); } model::Operation::DeclareFunc { decl } => { let mut builder = builder.init_func_decl(); builder.set_name(decl.name); write_list!(builder, init_params, write_param, decl.params); + let _ = builder.set_constraints(model::TermId::unwrap_slice(decl.constraints)); builder.set_signature(decl.signature.0); } @@ -87,12 +89,14 @@ fn write_operation(mut builder: hugr_capnp::operation::Builder, operation: &mode let mut builder = builder.init_constructor_decl(); builder.set_name(decl.name); write_list!(builder, init_params, write_param, decl.params); + let _ = builder.set_constraints(model::TermId::unwrap_slice(decl.constraints)); builder.set_type(decl.r#type.0); } model::Operation::DeclareOperation { decl } => { let mut builder = builder.init_operation_decl(); builder.set_name(decl.name); write_list!(builder, init_params, write_param, decl.params); + let _ = builder.set_constraints(model::TermId::unwrap_slice(decl.constraints)); builder.set_type(decl.r#type.0); } @@ -101,19 +105,12 @@ fn write_operation(mut builder: hugr_capnp::operation::Builder, operation: &mode } fn write_param(mut builder: hugr_capnp::param::Builder, param: &model::Param) { - match param { - model::Param::Implicit { name, r#type } => { - let mut builder = builder.init_implicit(); - builder.set_name(name); - builder.set_type(r#type.0); - } - model::Param::Explicit { name, r#type } => { - let mut builder = builder.init_explicit(); - builder.set_name(name); - builder.set_type(r#type.0); - } - model::Param::Constraint { constraint } => builder.set_constraint(constraint.0), - } + builder.set_name(param.name); + builder.set_type(param.r#type.0); + builder.set_sort(match param.sort { + model::ParamSort::Implicit => hugr_capnp::ParamSort::Implicit, + model::ParamSort::Explicit => hugr_capnp::ParamSort::Explicit, + }); } fn write_global_ref(mut builder: hugr_capnp::global_ref::Builder, global_ref: &model::GlobalRef) { @@ -212,5 +209,9 @@ fn write_term(mut builder: hugr_capnp::term::Builder, term: &model::Term) { builder.set_outputs(outputs.0); builder.set_extensions(extensions.0); } + + model::Term::NonLinearConstraint { term } => { + builder.set_non_linear_constraint(term.0); + } } } diff --git a/hugr-model/src/v0/mod.rs b/hugr-model/src/v0/mod.rs index cb8713b32..16c7cb6c6 100644 --- a/hugr-model/src/v0/mod.rs +++ b/hugr-model/src/v0/mod.rs @@ -397,6 +397,8 @@ pub struct FuncDecl<'a> { pub name: &'a str, /// The static parameters of the function. pub params: &'a [Param<'a>], + /// The constraints on the static parameters. + pub constraints: &'a [TermId], /// The signature of the function. pub signature: TermId, } @@ -419,6 +421,8 @@ pub struct ConstructorDecl<'a> { pub name: &'a str, /// The static parameters of the constructor. pub params: &'a [Param<'a>], + /// The constraints on the static parameters. + pub constraints: &'a [TermId], /// The type of the constructed term. pub r#type: TermId, } @@ -430,6 +434,8 @@ pub struct OperationDecl<'a> { pub name: &'a str, /// The static parameters of the operation. pub params: &'a [Param<'a>], + /// The constraints on the static parameters. + pub constraints: &'a [TermId], /// The type of the operation. This must be a function type. pub r#type: TermId, } @@ -662,6 +668,12 @@ pub enum Term<'a> { /// /// `ctrl : static` ControlType, + + /// Constraint that requires a runtime type to be copyable and discardable. + NonLinearConstraint { + /// The runtime type that must be copyable and discardable. + term: TermId, + }, } /// A parameter to a function or alias. @@ -669,33 +681,23 @@ pub enum Term<'a> { /// Parameter names must be unique within a parameter list. /// Implicit and explicit parameters share a namespace. #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub enum Param<'a> { - /// An implicit parameter that should be inferred, unless a full application form is used +pub struct Param<'a> { + /// The name of the parameter. + pub name: &'a str, + /// The type of the parameter. + pub r#type: TermId, + /// The sort of the parameter (implicit or explicit). + pub sort: ParamSort, +} + +/// The sort of a parameter (implicit or explicit). +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub enum ParamSort { + /// The parameter is implicit and should be inferred, unless a full application form is used /// (see [`Term::ApplyFull`] and [`Operation::CustomFull`]). - Implicit { - /// The name of the parameter. - name: &'a str, - /// The type of the parameter. - /// - /// This must be a term of type `static`. - r#type: TermId, - }, - /// An explicit parameter that should always be provided. - Explicit { - /// The name of the parameter. - name: &'a str, - /// The type of the parameter. - /// - /// This must be a term of type `static`. - r#type: TermId, - }, - /// A constraint that should be satisfied by other parameters in a parameter list. - Constraint { - /// The constraint to be satisfied. - /// - /// This must be a term of type `constraint`. - constraint: TermId, - }, + Implicit, + /// The parameter is explicit and should always be provided. + Explicit, } /// Errors that can occur when traversing and interpreting the model. diff --git a/hugr-model/src/v0/text/hugr.pest b/hugr-model/src/v0/text/hugr.pest index 132d78567..d05e3d774 100644 --- a/hugr-model/src/v0/text/hugr.pest +++ b/hugr-model/src/v0/text/hugr.pest @@ -56,16 +56,16 @@ node_tag = { "(" ~ "tag" ~ tag ~ port_lists? ~ signature? ~ meta* node_custom = { "(" ~ (term_apply | term_apply_full) ~ port_lists? ~ signature? ~ meta* ~ region* ~ ")" } signature = { "(" ~ "signature" ~ term ~ ")" } -func_header = { symbol ~ param* ~ term ~ term ~ term } +func_header = { symbol ~ param* ~ where_clause* ~ term ~ term ~ term } alias_header = { symbol ~ param* ~ term } -ctr_header = { symbol ~ param* ~ term } -operation_header = { symbol ~ param* ~ term } +ctr_header = { symbol ~ param* ~ where_clause* ~ term } +operation_header = { symbol ~ param* ~ where_clause* ~ term } -param = { param_implicit | param_explicit | param_constraint } +param = { param_implicit | param_explicit } -param_implicit = { "(" ~ "forall" ~ term_var ~ term ~ ")" } -param_explicit = { "(" ~ "param" ~ term_var ~ term ~ ")" } -param_constraint = { "(" ~ "where" ~ term ~ ")" } +param_implicit = { "(" ~ "forall" ~ term_var ~ term ~ ")" } +param_explicit = { "(" ~ "param" ~ term_var ~ term ~ ")" } +where_clause = { "(" ~ "where" ~ term ~ ")" } region = { region_dfg | region_cfg } region_dfg = { "(" ~ "dfg" ~ port_lists? ~ signature? ~ meta* ~ node* ~ ")" } @@ -92,6 +92,7 @@ term = { | term_ctrl_type | term_apply_full | term_apply + | term_non_linear } term_wildcard = { "_" } @@ -114,3 +115,4 @@ term_adt = { "(" ~ "adt" ~ term ~ ")" } term_func_type = { "(" ~ "fn" ~ term ~ term ~ term ~ ")" } term_ctrl = { "(" ~ "ctrl" ~ term ~ ")" } term_ctrl_type = { "ctrl" } +term_non_linear = { "(" ~ "nonlinear" ~ term ~ ")" } diff --git a/hugr-model/src/v0/text/parse.rs b/hugr-model/src/v0/text/parse.rs index fa486454b..370dbeac0 100644 --- a/hugr-model/src/v0/text/parse.rs +++ b/hugr-model/src/v0/text/parse.rs @@ -7,7 +7,7 @@ use thiserror::Error; use crate::v0::{ AliasDecl, ConstructorDecl, FuncDecl, GlobalRef, LinkRef, LocalRef, MetaItem, Module, Node, - NodeId, Operation, OperationDecl, Param, Region, RegionId, RegionKind, Term, TermId, + NodeId, Operation, OperationDecl, Param, ParamSort, Region, RegionId, RegionKind, Term, TermId, }; mod pest_parser { @@ -209,6 +209,11 @@ impl<'a> ParseContext<'a> { Term::Control { values } } + Rule::term_non_linear => { + let term = self.parse_term(inner.next().unwrap())?; + Term::NonLinearConstraint { term } + } + r => unreachable!("term: {:?}", r), }; @@ -544,6 +549,7 @@ impl<'a> ParseContext<'a> { let mut inner = pair.into_inner(); let name = self.parse_symbol(&mut inner)?; let params = self.parse_params(&mut inner)?; + let constraints = self.parse_constraints(&mut inner)?; let inputs = self.parse_term(inner.next().unwrap())?; let outputs = self.parse_term(inner.next().unwrap())?; @@ -559,6 +565,7 @@ impl<'a> ParseContext<'a> { Ok(self.bump.alloc(FuncDecl { name, params, + constraints, signature: func, })) } @@ -584,11 +591,13 @@ impl<'a> ParseContext<'a> { let mut inner = pair.into_inner(); let name = self.parse_symbol(&mut inner)?; let params = self.parse_params(&mut inner)?; + let constraints = self.parse_constraints(&mut inner)?; let r#type = self.parse_term(inner.next().unwrap())?; Ok(self.bump.alloc(ConstructorDecl { name, params, + constraints, r#type, })) } @@ -599,11 +608,13 @@ impl<'a> ParseContext<'a> { let mut inner = pair.into_inner(); let name = self.parse_symbol(&mut inner)?; let params = self.parse_params(&mut inner)?; + let constraints = self.parse_constraints(&mut inner)?; let r#type = self.parse_term(inner.next().unwrap())?; Ok(self.bump.alloc(OperationDecl { name, params, + constraints, r#type, })) } @@ -619,18 +630,21 @@ impl<'a> ParseContext<'a> { let mut inner = param.into_inner(); let name = &inner.next().unwrap().as_str()[1..]; let r#type = self.parse_term(inner.next().unwrap())?; - Param::Implicit { name, r#type } + Param { + name, + r#type, + sort: ParamSort::Implicit, + } } Rule::param_explicit => { let mut inner = param.into_inner(); let name = &inner.next().unwrap().as_str()[1..]; let r#type = self.parse_term(inner.next().unwrap())?; - Param::Explicit { name, r#type } - } - Rule::param_constraint => { - let mut inner = param.into_inner(); - let constraint = self.parse_term(inner.next().unwrap())?; - Param::Constraint { constraint } + Param { + name, + r#type, + sort: ParamSort::Explicit, + } } _ => unreachable!(), }; @@ -641,6 +655,17 @@ impl<'a> ParseContext<'a> { Ok(self.bump.alloc_slice_copy(¶ms)) } + fn parse_constraints(&mut self, pairs: &mut Pairs<'a, Rule>) -> ParseResult<&'a [TermId]> { + let mut constraints = Vec::new(); + + for pair in filter_rule(pairs, Rule::where_clause) { + let constraint = self.parse_term(pair.into_inner().next().unwrap())?; + constraints.push(constraint); + } + + Ok(self.bump.alloc_slice_copy(&constraints)) + } + fn parse_signature(&mut self, pairs: &mut Pairs<'a, Rule>) -> ParseResult> { let Some(Rule::signature) = pairs.peek().map(|p| p.as_rule()) else { return Ok(None); diff --git a/hugr-model/src/v0/text/print.rs b/hugr-model/src/v0/text/print.rs index 01b9d7195..512f6d1e4 100644 --- a/hugr-model/src/v0/text/print.rs +++ b/hugr-model/src/v0/text/print.rs @@ -2,8 +2,8 @@ use pretty::{Arena, DocAllocator, RefDoc}; use std::borrow::Cow; use crate::v0::{ - GlobalRef, LinkRef, LocalRef, MetaItem, ModelError, Module, NodeId, Operation, Param, RegionId, - RegionKind, Term, TermId, + GlobalRef, LinkRef, LocalRef, MetaItem, ModelError, Module, NodeId, Operation, Param, + ParamSort, RegionId, RegionKind, Term, TermId, }; type PrintError = ModelError; @@ -122,15 +122,7 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { f: impl FnOnce(&mut Self) -> PrintResult, ) -> PrintResult { let locals = std::mem::take(&mut self.locals); - - for param in params { - match param { - Param::Implicit { name, .. } => self.locals.push(name), - Param::Explicit { name, .. } => self.locals.push(name), - Param::Constraint { .. } => {} - } - } - + self.locals.extend(params.iter().map(|param| param.name)); let result = f(self); self.locals = locals; result @@ -178,9 +170,8 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { this.print_text(decl.name); }); - for param in decl.params { - this.print_param(*param)?; - } + this.print_params(decl.params)?; + this.print_constraints(decl.constraints)?; match self.module.get_term(decl.signature) { Some(Term::FuncType { @@ -208,9 +199,8 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { this.print_text(decl.name); }); - for param in decl.params { - this.print_param(*param)?; - } + this.print_params(decl.params)?; + this.print_constraints(decl.constraints)?; match self.module.get_term(decl.signature) { Some(Term::FuncType { @@ -303,9 +293,7 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { this.print_text(decl.name); }); - for param in decl.params { - this.print_param(*param)?; - } + this.print_params(decl.params)?; this.print_term(decl.r#type)?; this.print_term(*value)?; @@ -318,9 +306,7 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { this.print_text(decl.name); }); - for param in decl.params { - this.print_param(*param)?; - } + this.print_params(decl.params)?; this.print_term(decl.r#type)?; this.print_meta(node_data.meta)?; @@ -333,9 +319,8 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { this.print_text(decl.name); }); - for param in decl.params { - this.print_param(*param)?; - } + this.print_params(decl.params)?; + this.print_constraints(decl.constraints)?; this.print_term(decl.r#type)?; this.print_meta(node_data.meta)?; @@ -348,9 +333,8 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { this.print_text(decl.name); }); - for param in decl.params { - this.print_param(*param)?; - } + this.print_params(decl.params)?; + this.print_constraints(decl.constraints)?; this.print_term(decl.r#type)?; this.print_meta(node_data.meta)?; @@ -384,10 +368,9 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { } fn print_regions(&mut self, regions: &'a [RegionId]) -> PrintResult<()> { - for region in regions { - self.print_region(*region)?; - } - Ok(()) + regions + .iter() + .try_for_each(|region| self.print_region(*region)) } fn print_region(&mut self, region: RegionId) -> PrintResult<()> { @@ -422,11 +405,10 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { .get_region(region) .ok_or(PrintError::RegionNotFound(region))?; - for node_id in region_data.children { - self.print_node(*node_id)?; - } - - Ok(()) + region_data + .children + .iter() + .try_for_each(|node_id| self.print_node(*node_id)) } fn print_port_lists( @@ -460,25 +442,33 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { } } + fn print_params(&mut self, params: &'a [Param<'a>]) -> PrintResult<()> { + params.iter().try_for_each(|param| self.print_param(*param)) + } + fn print_param(&mut self, param: Param<'a>) -> PrintResult<()> { - self.print_parens(|this| match param { - Param::Implicit { name, r#type } => { - this.print_text("forall"); - this.print_text(format!("?{}", name)); - this.print_term(r#type) - } - Param::Explicit { name, r#type } => { - this.print_text("param"); - this.print_text(format!("?{}", name)); - this.print_term(r#type) - } - Param::Constraint { constraint } => { - this.print_text("where"); - this.print_term(constraint) - } + self.print_parens(|this| { + match param.sort { + ParamSort::Implicit => this.print_text("forall"), + ParamSort::Explicit => this.print_text("param"), + }; + + this.print_text(format!("?{}", param.name)); + this.print_term(param.r#type) }) } + fn print_constraints(&mut self, terms: &'a [TermId]) -> PrintResult<()> { + for term in terms { + self.print_parens(|this| { + this.print_text("where"); + this.print_term(*term) + })?; + } + + Ok(()) + } + fn print_term(&mut self, term_id: TermId) -> PrintResult<()> { let term_data = self .module @@ -598,6 +588,10 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { self.print_text("ctrl"); Ok(()) } + Term::NonLinearConstraint { term } => self.print_parens(|this| { + this.print_text("nonlinear"); + this.print_term(*term) + }), } } diff --git a/hugr-model/tests/binary.rs b/hugr-model/tests/binary.rs index 043061677..93955fe6e 100644 --- a/hugr-model/tests/binary.rs +++ b/hugr-model/tests/binary.rs @@ -51,3 +51,8 @@ pub fn test_params() { pub fn test_decl_exts() { binary_roundtrip(include_str!("fixtures/model-decl-exts.edn")); } + +#[test] +pub fn test_constraints() { + binary_roundtrip(include_str!("fixtures/model-constraints.edn")); +} diff --git a/hugr-model/tests/fixtures/model-constraints.edn b/hugr-model/tests/fixtures/model-constraints.edn new file mode 100644 index 000000000..5db6b9886 --- /dev/null +++ b/hugr-model/tests/fixtures/model-constraints.edn @@ -0,0 +1,13 @@ +(hugr 0) + +(declare-func array.replicate + (forall ?t type) + (forall ?n nat) + (where (nonlinear ?t)) + [?t] [(@ array.Array ?t ?n)] + (ext)) + +(declare-func array.copy + (forall ?t type) + (where (nonlinear ?t)) + [(@ array.Array ?t)] [(@ array.Array ?t) (@ array.Array ?t)] (ext)) diff --git a/hugr-passes/src/half_node.rs b/hugr-passes/src/half_node.rs index 4bacac548..336d8992c 100644 --- a/hugr-passes/src/half_node.rs +++ b/hugr-passes/src/half_node.rs @@ -64,7 +64,6 @@ impl> HalfNodeView { } impl> CfgNodeMap for HalfNodeView { - type Iterator<'c> = as IntoIterator>::IntoIter where Self: 'c; fn entry_node(&self) -> HalfNode { HalfNode::N(self.entry) } @@ -72,7 +71,7 @@ impl> CfgNodeMap for HalfNodeView assert!(self.bb_succs(self.exit).count() == 0); HalfNode::N(self.exit) } - fn predecessors(&self, h: HalfNode) -> Self::Iterator<'_> { + fn predecessors(&self, h: HalfNode) -> impl Iterator { let mut ps = Vec::new(); match h { HalfNode::N(ni) => ps.extend(self.bb_preds(ni).map(|n| self.resolve_out(n))), @@ -83,7 +82,7 @@ impl> CfgNodeMap for HalfNodeView } ps.into_iter() } - fn successors(&self, n: HalfNode) -> Self::Iterator<'_> { + fn successors(&self, n: HalfNode) -> impl Iterator { let mut succs = Vec::new(); match n { HalfNode::N(ni) if self.is_multi_node(ni) => succs.push(HalfNode::X(ni)), diff --git a/hugr-passes/src/merge_bbs.rs b/hugr-passes/src/merge_bbs.rs index 51fd07d9b..249eed6b4 100644 --- a/hugr-passes/src/merge_bbs.rs +++ b/hugr-passes/src/merge_bbs.rs @@ -276,7 +276,7 @@ mod test { .nodes() .filter(|n| h.get_optype(*n).cast::().is_some()); let (entry_nop, expected_backedge_target) = if self_loop { - assert_eq!(h.children(r).len(), 2); + assert_eq!(h.children(r).count(), 2); (nops.exactly_one().ok().unwrap(), entry) } else { let [_, _, no_b2] = h.children(r).collect::>().try_into().unwrap(); diff --git a/hugr-passes/src/nest_cfgs.rs b/hugr-passes/src/nest_cfgs.rs index 31f91b9e4..fa7106432 100644 --- a/hugr-passes/src/nest_cfgs.rs +++ b/hugr-passes/src/nest_cfgs.rs @@ -67,15 +67,10 @@ pub trait CfgNodeMap { fn entry_node(&self) -> T; /// The unique exit node of the CFG. The only node to have no successors. fn exit_node(&self) -> T; - /// Allows the trait implementor to define a type of iterator it will return from - /// `successors` and `predecessors`. - type Iterator<'c>: Iterator - where - Self: 'c; /// Returns an iterator over the successors of the specified basic block. - fn successors(&self, node: T) -> Self::Iterator<'_>; + fn successors(&self, node: T) -> impl Iterator; /// Returns an iterator over the predecessors of the specified basic block. - fn predecessors(&self, node: T) -> Self::Iterator<'_>; + fn predecessors(&self, node: T) -> impl Iterator; } /// Extension of [CfgNodeMap] to that can perform (mutable/destructive) @@ -242,15 +237,11 @@ impl CfgNodeMap for IdentityCfgMap { self.exit } - type Iterator<'c> = ::Neighbours<'c> - where - Self: 'c; - - fn successors(&self, node: Node) -> Self::Iterator<'_> { + fn successors(&self, node: Node) -> impl Iterator { self.h.neighbours(node, Direction::Outgoing) } - fn predecessors(&self, node: Node) -> Self::Iterator<'_> { + fn predecessors(&self, node: Node) -> impl Iterator { self.h.neighbours(node, Direction::Incoming) } } @@ -731,9 +722,9 @@ pub(crate) mod test { // | \-> right -/ | // \---<---<---<---<---<---<---<---<---/ // split is unique successor of head - let split = h.output_neighbours(head).exactly_one().unwrap(); + let split = h.output_neighbours(head).exactly_one().ok().unwrap(); // merge is unique predecessor of tail - let merge = h.input_neighbours(tail).exactly_one().unwrap(); + let merge = h.input_neighbours(tail).exactly_one().ok().unwrap(); // There's no need to use a view of a region here but we do so just to check // that we *can* (as we'll need to for "real" module Hugr's)