Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

HugrView: validate nodes, and remove Base #523

Merged
merged 11 commits into from
Sep 13, 2023
2 changes: 1 addition & 1 deletion src/algorithm/nest_cfgs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -449,7 +449,7 @@ pub(crate) mod test {
let (split, merge, head, tail) = (split.node(), merge.node(), head.node(), tail.node());
// There's no need to use a FlatRegionView here but we do so just to check
// that we *can* (as we'll need to for "real" module Hugr's).
let v: SiblingGraph = SiblingGraph::new(&h, h.root());
let v: SiblingGraph = SiblingGraph::try_new(&h, h.root()).unwrap();
let edge_classes = EdgeClassifier::get_edge_classes(&SimpleCfgView::new(&v));
let [&left, &right] = edge_classes
.keys()
Expand Down
8 changes: 7 additions & 1 deletion src/hugr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ pub use self::views::HugrView;
use crate::extension::{
infer_extensions, ExtensionRegistry, ExtensionSet, ExtensionSolution, InferExtensionError,
};
use crate::ops::{OpTag, OpTrait, OpType};
use crate::ops::{OpTag, OpTrait, OpType, DEFAULT_OPTYPE};
use crate::types::{FunctionType, Signature};

use delegate::delegate;
Expand Down Expand Up @@ -64,6 +64,12 @@ pub struct NodeType {
input_extensions: Option<ExtensionSet>,
}

/// The default NodeType, with open extensions
pub const DEFAULT_NODETYPE: NodeType = NodeType {
op: DEFAULT_OPTYPE,
input_extensions: None, // Default for any Option
};

impl NodeType {
/// Create a new optype with some ExtensionSet
pub fn new(op: impl Into<OpType>, input_extensions: impl Into<Option<ExtensionSet>>) -> Self {
Expand Down
24 changes: 0 additions & 24 deletions src/hugr/hugrmut.rs
Original file line number Diff line number Diff line change
Expand Up @@ -322,30 +322,6 @@ pub(crate) mod sealed {
/// Returns the Hugr at the base of a chain of views.
fn hugr_mut(&mut self) -> &mut Hugr;

/// Validates that a node is valid in the graph.
///
/// Returns a [`HugrError::InvalidNode`] otherwise.
#[inline]
fn valid_node(&self, node: Node) -> Result<(), HugrError> {
match self.contains_node(node) {
true => Ok(()),
false => Err(HugrError::InvalidNode(node)),
}
}

/// Validates that a node is a valid root descendant in the graph.
///
/// To include the root node use [`HugrMutInternals::valid_node`] instead.
///
/// Returns a [`HugrError::InvalidNode`] otherwise.
#[inline]
fn valid_non_root(&self, node: Node) -> Result<(), HugrError> {
match self.root() == node {
true => Err(HugrError::InvalidNode(node)),
false => self.valid_node(node),
}
}

/// Set the number of ports on a node. This may invalidate the node's `PortIndex`.
fn set_num_ports(&mut self, node: Node, incoming: usize, outgoing: usize) {
self.valid_node(node).unwrap_or_else(|e| panic!("{}", e));
Expand Down
4 changes: 2 additions & 2 deletions src/hugr/validate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ impl<'a, 'b> ValidationContext<'a, 'b> {
/// The results of this computation should be cached in `self.dominators`.
/// We don't do it here to avoid mutable borrows.
fn compute_dominator(&self, parent: Node) -> Dominators<Node> {
let region: SiblingGraph = SiblingGraph::new(self.hugr, parent);
let region: SiblingGraph = SiblingGraph::try_new(self.hugr, parent).unwrap();
let entry_node = self.hugr.children(parent).next().unwrap();
dominators::simple_fast(&region.as_petgraph(), entry_node)
}
Expand Down Expand Up @@ -374,7 +374,7 @@ impl<'a, 'b> ValidationContext<'a, 'b> {
return Ok(());
};

let region: SiblingGraph = SiblingGraph::new(self.hugr, parent);
let region: SiblingGraph = SiblingGraph::try_new(self.hugr, parent).unwrap();
let postorder = Topo::new(&region.as_petgraph());
let nodes_visited = postorder
.iter(&region.as_petgraph())
Expand Down
86 changes: 56 additions & 30 deletions src/hugr/views.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use itertools::{Itertools, MapInto};
use portgraph::dot::{DotFormat, EdgeStyle, NodeStyle, PortStyle};
use portgraph::{multiportgraph, LinkView, MultiPortGraph, PortView};

use super::{Hugr, NodeMetadata, NodeType};
use super::{Hugr, HugrError, NodeMetadata, NodeType, DEFAULT_NODETYPE};
use crate::ops::handle::NodeHandle;
use crate::ops::{FuncDecl, FuncDefn, OpName, OpTag, OpType, DFG};
use crate::types::{EdgeKind, FunctionType};
Expand Down Expand Up @@ -80,17 +80,63 @@ pub trait HugrView: sealed::HugrInternals {
/// Returns whether the node exists.
fn contains_node(&self, node: Node) -> bool;

/// Validates that a node is valid in the graph.
///
/// Returns a [`HugrError::InvalidNode`] otherwise.
#[inline]
fn valid_node(&self, node: Node) -> Result<(), HugrError> {
match self.contains_node(node) {
true => Ok(()),
false => Err(HugrError::InvalidNode(node)),
}
}

/// Validates that a node is a valid root descendant in the graph.
///
/// To include the root node use [`HugrView::valid_node`] instead.
///
/// Returns a [`HugrError::InvalidNode`] otherwise.
#[inline]
fn valid_non_root(&self, node: Node) -> Result<(), HugrError> {
match self.root() == node {
true => Err(HugrError::InvalidNode(node)),
false => self.valid_node(node),
}
}

/// Returns the parent of a node.
fn get_parent(&self, node: Node) -> Option<Node>;
#[inline]
fn get_parent(&self, node: Node) -> Option<Node> {
self.valid_non_root(node).ok()?;
self.base_hugr()
.hierarchy
.parent(node.index)
.map(Into::into)
}

/// Returns the operation type of a node.
fn get_optype(&self, node: Node) -> &OpType;
#[inline]
fn get_optype(&self, node: Node) -> &OpType {
&self.get_nodetype(node).op
}

/// Returns the type of a node.
fn get_nodetype(&self, node: Node) -> &NodeType;
#[inline]
fn get_nodetype(&self, node: Node) -> &NodeType {
match self.contains_node(node) {
true => self.base_hugr().op_types.get(node.index),
false => &DEFAULT_NODETYPE,
}
}

/// Returns the metadata associated with a node.
fn get_metadata(&self, node: Node) -> &NodeMetadata;
#[inline]
fn get_metadata(&self, node: Node) -> &NodeMetadata {
match self.contains_node(node) {
true => self.base_hugr().metadata.get(node.index),
false => &NodeMetadata::Null,
}
}

/// Returns the number of nodes in the hugr.
fn node_count(&self) -> usize;
Expand Down Expand Up @@ -249,12 +295,12 @@ pub trait HugrView: sealed::HugrInternals {
}

/// A common trait for views of a HUGR hierarchical subgraph.
pub trait HierarchyView<'a>: HugrView {
/// The base from which the subgraph is derived.
type Base;

pub trait HierarchyView<'a>: HugrView + Sized {
/// Create a hierarchical view of a HUGR given a root node.
fn new(hugr: &'a Self::Base, root: Node) -> Self;
///
/// # Errors
/// Returns [`HugrError::InvalidNode`] if the root isn't a node of the required [OpTag]
fn try_new(hugr: &'a impl HugrView, root: Node) -> Result<Self, HugrError>;
}

impl<T> HugrView for T
Expand Down Expand Up @@ -287,21 +333,6 @@ where
self.as_ref().graph.contains_node(node.index)
}

#[inline]
fn get_parent(&self, node: Node) -> Option<Node> {
self.as_ref().hierarchy.parent(node.index).map(Into::into)
}

#[inline]
fn get_optype(&self, node: Node) -> &OpType {
&self.as_ref().op_types.get(node.index).op
}

#[inline]
fn get_nodetype(&self, node: Node) -> &NodeType {
self.as_ref().op_types.get(node.index)
}

#[inline]
fn node_count(&self) -> usize {
self.as_ref().graph.node_count()
Expand Down Expand Up @@ -386,11 +417,6 @@ where
None
}
}

#[inline]
fn get_metadata(&self, node: Node) -> &NodeMetadata {
self.as_ref().metadata.get(node.index)
}
}

pub(crate) mod sealed {
Expand Down
80 changes: 19 additions & 61 deletions src/hugr/views/descendants.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@ use context_iterators::{ContextIterator, IntoContextIterator, MapWithCtx};
use itertools::{Itertools, MapInto};
use portgraph::{LinkView, MultiPortGraph, PortIndex, PortView};

use crate::hugr::HugrError;
use crate::ops::handle::NodeHandle;
use crate::ops::OpTrait;
use crate::{hugr::NodeType, hugr::OpType, Direction, Hugr, Node, Port};
use crate::{Direction, Hugr, Node, Port};

use super::{sealed::HugrInternals, HierarchyView, HugrView, NodeMetadata};
use super::{sealed::HugrInternals, HierarchyView, HugrView};

type RegionGraph<'g> = portgraph::view::Region<'g, &'g MultiPortGraph>;

Expand All @@ -25,37 +26,24 @@ type RegionGraph<'g> = portgraph::view::Region<'g, &'g MultiPortGraph>;
/// used interchangeably with [`SiblingGraph`].
///
/// [`SiblingGraph`]: super::SiblingGraph
pub struct DescendantsGraph<'g, Root = Node, Base = Hugr>
where
Base: HugrInternals,
{
#[derive(Clone)]
pub struct DescendantsGraph<'g, Root = Node> {
/// The chosen root node.
root: Node,

/// The graph encoding the adjacency structure of the HUGR.
graph: RegionGraph<'g>,

/// The node hierarchy.
hugr: &'g Base,
hugr: &'g Hugr,

/// The operation handle of the root node.
_phantom: std::marker::PhantomData<Root>,
}

impl<'g, Root, Base: Clone> Clone for DescendantsGraph<'g, Root, Base>
where
Root: NodeHandle,
Base: HugrInternals + HugrView,
{
fn clone(&self) -> Self {
DescendantsGraph::new(self.hugr, self.root)
}
}

impl<'g, Root, Base> HugrView for DescendantsGraph<'g, Root, Base>
impl<'g, Root> HugrView for DescendantsGraph<'g, Root>
where
Root: NodeHandle,
Base: HugrInternals + HugrView,
{
type RootHandle = Root;

Expand Down Expand Up @@ -94,29 +82,6 @@ where
self.graph.contains_node(node.index)
}

#[inline]
fn get_parent(&self, node: Node) -> Option<Node> {
self.hugr
.get_parent(node)
.filter(|&parent| self.graph.contains_node(parent.index))
.map(Into::into)
}

#[inline]
fn get_optype(&self, node: Node) -> &OpType {
self.hugr.get_optype(node)
}

#[inline]
fn get_nodetype(&self, node: Node) -> &NodeType {
self.hugr.get_nodetype(node)
}

#[inline]
fn get_metadata(&self, node: Node) -> &NodeMetadata {
self.hugr.get_metadata(node)
}

#[inline]
fn node_count(&self) -> usize {
self.graph.node_count()
Expand Down Expand Up @@ -196,36 +161,29 @@ where
}
}

impl<'a, Root, Base> HierarchyView<'a> for DescendantsGraph<'a, Root, Base>
impl<'a, Root> HierarchyView<'a> for DescendantsGraph<'a, Root>
where
Root: NodeHandle,
Base: HugrView,
{
type Base = Base;

fn new(hugr: &'a Base, root: Node) -> Self {
fn try_new(hugr: &'a impl HugrView, root: Node) -> Result<Self, HugrError> {
hugr.valid_node(root)?;
let root_tag = hugr.get_optype(root).tag();
if !Root::TAG.is_superset(root_tag) {
// TODO: Return an error
panic!("Root node must have the correct operation type tag.")
return Err(HugrError::InvalidNode(root));
}
Self {
let hugr = hugr.base_hugr();
Ok(Self {
root,
graph: RegionGraph::new_region(
&hugr.base_hugr().graph,
&hugr.base_hugr().hierarchy,
root.index,
),
graph: RegionGraph::new_region(&hugr.graph, &hugr.hierarchy, root.index),
hugr,
_phantom: std::marker::PhantomData,
}
})
}
}

impl<'g, Root, Base> super::sealed::HugrInternals for DescendantsGraph<'g, Root, Base>
impl<'g, Root> super::sealed::HugrInternals for DescendantsGraph<'g, Root>
where
Root: NodeHandle,
Base: HugrInternals,
{
type Portgraph<'p> = &'p RegionGraph<'g> where Self: 'p;

Expand All @@ -236,7 +194,7 @@ where

#[inline]
fn base_hugr(&self) -> &Hugr {
self.hugr.base_hugr()
self.hugr
}

#[inline]
Expand Down Expand Up @@ -299,7 +257,7 @@ pub(super) mod test {
fn full_region() -> Result<(), Box<dyn std::error::Error>> {
let (hugr, def, inner) = make_module_hgr()?;

let region: DescendantsGraph = DescendantsGraph::new(&hugr, def);
let region: DescendantsGraph = DescendantsGraph::try_new(&hugr, def)?;

assert_eq!(region.node_count(), 7);
assert!(region.nodes().all(|n| n == def
Expand All @@ -311,7 +269,7 @@ pub(super) mod test {
region.get_function_type(),
Some(&FunctionType::new(type_row![NAT, QB], type_row![NAT, QB]))
);
let inner_region: DescendantsGraph = DescendantsGraph::new(&hugr, inner);
let inner_region: DescendantsGraph = DescendantsGraph::try_new(&hugr, inner)?;
assert_eq!(
inner_region.get_function_type(),
Some(&FunctionType::new(type_row![NAT], type_row![NAT]))
Expand Down
Loading