-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
refactor: Separate hierarchy views (#500)
* Move views/hierarchy/petgraph.rs to views/petgraph.rs * Move views/sibling.rs to views/sibling_subgraph.rs. I think the fact that it's a(n arbitrary) subgraph of a sibling subgraph is important, and more significant than that they are siblings (the USP, if you like - we have sibling views elsewhere), so emphasize this. * Move HierarchyView from views/hierarchy to views * Split hierarchy.rs into sibling.rs and descendants.rs
- Loading branch information
Showing
6 changed files
with
1,348 additions
and
1,332 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,316 @@ | ||
//! DescendantsGraph: view onto the subgraph of the HUGR starting from a root | ||
//! (all descendants at all depths). | ||
use context_iterators::{ContextIterator, IntoContextIterator, MapWithCtx}; | ||
use itertools::{Itertools, MapInto}; | ||
use portgraph::{LinkView, MultiPortGraph, PortIndex, PortView}; | ||
|
||
use crate::ops::handle::NodeHandle; | ||
use crate::ops::OpTrait; | ||
use crate::{hugr::NodeType, hugr::OpType, Direction, Hugr, Node, Port}; | ||
|
||
use super::{sealed::HugrInternals, HierarchyView, HugrView, NodeMetadata}; | ||
|
||
type RegionGraph<'g> = portgraph::view::Region<'g, &'g MultiPortGraph>; | ||
|
||
/// View of a HUGR descendants graph. | ||
/// | ||
/// Includes the root node (which uniquely has no parent) and all its descendants. | ||
/// | ||
/// See [`SiblingGraph`] for a view that includes only the root and | ||
/// its immediate children. Prefer using [`SiblingGraph`] when possible, | ||
/// as it is more efficient. | ||
/// | ||
/// Implements the [`HierarchyView`] trait, as well as [`HugrView`] and petgraph's | ||
/// _visit_ traits, so can be used interchangeably with [`SiblingGraph`]. | ||
/// | ||
/// [`SiblingGraph`]: super::SiblingGraph | ||
pub struct DescendantsGraph<'g, Root = Node, Base = Hugr> | ||
where | ||
Base: HugrInternals, | ||
{ | ||
/// The chosen root node. | ||
root: Node, | ||
|
||
/// The graph encoding the adjacency structure of the HUGR. | ||
graph: RegionGraph<'g>, | ||
|
||
/// The node hierarchy. | ||
hugr: &'g Base, | ||
|
||
/// The operation handle of the root node. | ||
_phantom: std::marker::PhantomData<Root>, | ||
} | ||
|
||
impl<'g, Root, Base: Clone> Clone for DescendantsGraph<'g, Root, Base> | ||
where | ||
Root: NodeHandle, | ||
Base: HugrInternals + HugrView, | ||
{ | ||
fn clone(&self) -> Self { | ||
DescendantsGraph::new(self.hugr, self.root) | ||
} | ||
} | ||
|
||
impl<'g, Root, Base> HugrView for DescendantsGraph<'g, Root, Base> | ||
where | ||
Root: NodeHandle, | ||
Base: HugrInternals + HugrView, | ||
{ | ||
type RootHandle = Root; | ||
|
||
type Nodes<'a> = MapInto<<RegionGraph<'g> as PortView>::Nodes<'a>, Node> | ||
where | ||
Self: 'a; | ||
|
||
type NodePorts<'a> = MapInto<<RegionGraph<'g> as PortView>::NodePortOffsets<'a>, Port> | ||
where | ||
Self: 'a; | ||
|
||
type Children<'a> = MapInto<portgraph::hierarchy::Children<'a>, Node> | ||
where | ||
Self: 'a; | ||
|
||
type Neighbours<'a> = MapInto<<RegionGraph<'g> as LinkView>::Neighbours<'a>, Node> | ||
where | ||
Self: 'a; | ||
|
||
type PortLinks<'a> = MapWithCtx< | ||
<RegionGraph<'g> as LinkView>::PortLinks<'a>, | ||
&'a Self, | ||
(Node, Port), | ||
> where | ||
Self: 'a; | ||
|
||
type NodeConnections<'a> = MapWithCtx< | ||
<RegionGraph<'g> as LinkView>::NodeConnections<'a>, | ||
&'a Self, | ||
[Port; 2], | ||
> where | ||
Self: 'a; | ||
|
||
#[inline] | ||
fn contains_node(&self, node: Node) -> bool { | ||
self.graph.contains_node(node.index) | ||
} | ||
|
||
#[inline] | ||
fn get_parent(&self, node: Node) -> Option<Node> { | ||
self.hugr | ||
.get_parent(node) | ||
.filter(|&parent| self.graph.contains_node(parent.index)) | ||
.map(Into::into) | ||
} | ||
|
||
#[inline] | ||
fn get_optype(&self, node: Node) -> &OpType { | ||
self.hugr.get_optype(node) | ||
} | ||
|
||
#[inline] | ||
fn get_nodetype(&self, node: Node) -> &NodeType { | ||
self.hugr.get_nodetype(node) | ||
} | ||
|
||
#[inline] | ||
fn get_metadata(&self, node: Node) -> &NodeMetadata { | ||
self.hugr.get_metadata(node) | ||
} | ||
|
||
#[inline] | ||
fn node_count(&self) -> usize { | ||
self.graph.node_count() | ||
} | ||
|
||
#[inline] | ||
fn edge_count(&self) -> usize { | ||
self.graph.link_count() | ||
} | ||
|
||
#[inline] | ||
fn nodes(&self) -> Self::Nodes<'_> { | ||
self.graph.nodes_iter().map_into() | ||
} | ||
|
||
#[inline] | ||
fn node_ports(&self, node: Node, dir: Direction) -> Self::NodePorts<'_> { | ||
self.graph.port_offsets(node.index, dir).map_into() | ||
} | ||
|
||
#[inline] | ||
fn all_node_ports(&self, node: Node) -> Self::NodePorts<'_> { | ||
self.graph.all_port_offsets(node.index).map_into() | ||
} | ||
|
||
fn linked_ports(&self, node: Node, port: Port) -> Self::PortLinks<'_> { | ||
let port = self.graph.port_index(node.index, port.offset).unwrap(); | ||
self.graph | ||
.port_links(port) | ||
.with_context(self) | ||
.map_with_context(|(_, link), region| { | ||
let port: PortIndex = link.into(); | ||
let node = region.graph.port_node(port).unwrap(); | ||
let offset = region.graph.port_offset(port).unwrap(); | ||
(node.into(), offset.into()) | ||
}) | ||
} | ||
|
||
fn node_connections(&self, node: Node, other: Node) -> Self::NodeConnections<'_> { | ||
self.graph | ||
.get_connections(node.index, other.index) | ||
.with_context(self) | ||
.map_with_context(|(p1, p2), hugr| { | ||
[p1, p2].map(|link| { | ||
let offset = hugr.graph.port_offset(link).unwrap(); | ||
offset.into() | ||
}) | ||
}) | ||
} | ||
|
||
#[inline] | ||
fn num_ports(&self, node: Node, dir: Direction) -> usize { | ||
self.graph.num_ports(node.index, dir) | ||
} | ||
|
||
#[inline] | ||
fn children(&self, node: Node) -> Self::Children<'_> { | ||
match self.graph.contains_node(node.index) { | ||
true => self.base_hugr().hierarchy.children(node.index).map_into(), | ||
false => portgraph::hierarchy::Children::default().map_into(), | ||
} | ||
} | ||
|
||
#[inline] | ||
fn neighbours(&self, node: Node, dir: Direction) -> Self::Neighbours<'_> { | ||
self.graph.neighbours(node.index, dir).map_into() | ||
} | ||
|
||
#[inline] | ||
fn all_neighbours(&self, node: Node) -> Self::Neighbours<'_> { | ||
self.graph.all_neighbours(node.index).map_into() | ||
} | ||
|
||
#[inline] | ||
fn get_io(&self, node: Node) -> Option<[Node; 2]> { | ||
self.base_hugr().get_io(node) | ||
} | ||
|
||
fn get_function_type(&self) -> Option<&crate::types::FunctionType> { | ||
self.base_hugr().get_function_type() | ||
} | ||
} | ||
|
||
impl<'a, Root, Base> HierarchyView<'a> for DescendantsGraph<'a, Root, Base> | ||
where | ||
Root: NodeHandle, | ||
Base: HugrView, | ||
{ | ||
type Base = Base; | ||
|
||
fn new(hugr: &'a Base, root: Node) -> Self { | ||
let root_tag = hugr.get_optype(root).tag(); | ||
if !Root::TAG.is_superset(root_tag) { | ||
// TODO: Return an error | ||
panic!("Root node must have the correct operation type tag.") | ||
} | ||
Self { | ||
root, | ||
graph: RegionGraph::new_region( | ||
&hugr.base_hugr().graph, | ||
&hugr.base_hugr().hierarchy, | ||
root.index, | ||
), | ||
hugr, | ||
_phantom: std::marker::PhantomData, | ||
} | ||
} | ||
} | ||
|
||
impl<'g, Root, Base> super::sealed::HugrInternals for DescendantsGraph<'g, Root, Base> | ||
where | ||
Root: NodeHandle, | ||
Base: HugrInternals, | ||
{ | ||
type Portgraph<'p> = &'p RegionGraph<'g> where Self: 'p; | ||
|
||
#[inline] | ||
fn portgraph(&self) -> Self::Portgraph<'_> { | ||
&self.graph | ||
} | ||
|
||
#[inline] | ||
fn base_hugr(&self) -> &Hugr { | ||
self.hugr.base_hugr() | ||
} | ||
|
||
#[inline] | ||
fn root_node(&self) -> Node { | ||
self.root | ||
} | ||
} | ||
|
||
#[cfg(test)] | ||
pub(super) mod test { | ||
use crate::{ | ||
builder::{Container, Dataflow, DataflowSubContainer, HugrBuilder, ModuleBuilder}, | ||
ops::handle::NodeHandle, | ||
std_extensions::quantum::test::h_gate, | ||
type_row, | ||
types::{FunctionType, Type}, | ||
}; | ||
|
||
use super::*; | ||
|
||
const NAT: Type = crate::extension::prelude::USIZE_T; | ||
const QB: Type = crate::extension::prelude::QB_T; | ||
|
||
/// Make a module hugr with a fn definition containing an inner dfg node. | ||
/// | ||
/// Returns the hugr, the fn node id, and the nested dgf node id. | ||
pub(in crate::hugr::views) fn make_module_hgr( | ||
) -> Result<(Hugr, Node, Node), Box<dyn std::error::Error>> { | ||
let mut module_builder = ModuleBuilder::new(); | ||
|
||
let (f_id, inner_id) = { | ||
let mut func_builder = module_builder.define_function( | ||
"main", | ||
FunctionType::new(type_row![NAT, QB], type_row![NAT, QB]).pure(), | ||
)?; | ||
|
||
let [int, qb] = func_builder.input_wires_arr(); | ||
|
||
let q_out = func_builder.add_dataflow_op(h_gate(), vec![qb])?; | ||
|
||
let inner_id = { | ||
let inner_builder = func_builder.dfg_builder( | ||
FunctionType::new(type_row![NAT], type_row![NAT]), | ||
None, | ||
[int], | ||
)?; | ||
let w = inner_builder.input_wires(); | ||
inner_builder.finish_with_outputs(w) | ||
}?; | ||
|
||
let f_id = | ||
func_builder.finish_with_outputs(inner_id.outputs().chain(q_out.outputs()))?; | ||
(f_id, inner_id) | ||
}; | ||
let hugr = module_builder.finish_prelude_hugr()?; | ||
Ok((hugr, f_id.handle().node(), inner_id.handle().node())) | ||
} | ||
|
||
#[test] | ||
fn full_region() -> Result<(), Box<dyn std::error::Error>> { | ||
let (hugr, def, inner) = make_module_hgr()?; | ||
|
||
let region: DescendantsGraph = DescendantsGraph::new(&hugr, def); | ||
|
||
assert_eq!(region.node_count(), 7); | ||
assert!(region.nodes().all(|n| n == def | ||
|| hugr.get_parent(n) == Some(def) | ||
|| hugr.get_parent(n) == Some(inner))); | ||
assert_eq!(region.children(inner).count(), 2); | ||
|
||
Ok(()) | ||
} | ||
} |
Oops, something went wrong.