Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

HugrView: validate nodes, and remove Base #523

Merged
merged 11 commits into from
Sep 13, 2023
2 changes: 1 addition & 1 deletion src/algorithm/nest_cfgs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -449,7 +449,7 @@ pub(crate) mod test {
let (split, merge, head, tail) = (split.node(), merge.node(), head.node(), tail.node());
// There's no need to use a FlatRegionView here but we do so just to check
// that we *can* (as we'll need to for "real" module Hugr's).
let v: SiblingGraph = SiblingGraph::new(&h, h.root());
let v: SiblingGraph = SiblingGraph::new(&h, h.root()).unwrap();
let edge_classes = EdgeClassifier::get_edge_classes(&SimpleCfgView::new(&v));
let [&left, &right] = edge_classes
.keys()
Expand Down
8 changes: 7 additions & 1 deletion src/hugr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ pub use self::views::HugrView;
use crate::extension::{
infer_extensions, ExtensionRegistry, ExtensionSet, ExtensionSolution, InferExtensionError,
};
use crate::ops::{OpTag, OpTrait, OpType};
use crate::ops::{OpTag, OpTrait, OpType, DEFAULT_OPTYPE};
use crate::types::{FunctionType, Signature};

use delegate::delegate;
Expand Down Expand Up @@ -64,6 +64,12 @@ pub struct NodeType {
input_extensions: Option<ExtensionSet>,
}

/// The default NodeType, with open extensions
pub const DEFAULT_NODETYPE: NodeType = NodeType {
op: DEFAULT_OPTYPE,
input_extensions: None, // Default for any Option
};

impl NodeType {
/// Create a new optype with some ExtensionSet
pub fn new(op: impl Into<OpType>, input_extensions: impl Into<Option<ExtensionSet>>) -> Self {
Expand Down
24 changes: 0 additions & 24 deletions src/hugr/hugrmut.rs
Original file line number Diff line number Diff line change
Expand Up @@ -312,30 +312,6 @@ pub(crate) mod sealed {
/// Returns the Hugr at the base of a chain of views.
fn hugr_mut(&mut self) -> &mut Hugr;

/// Validates that a node is valid in the graph.
///
/// Returns a [`HugrError::InvalidNode`] otherwise.
#[inline]
fn valid_node(&self, node: Node) -> Result<(), HugrError> {
match self.contains_node(node) {
true => Ok(()),
false => Err(HugrError::InvalidNode(node)),
}
}

/// Validates that a node is a valid root descendant in the graph.
///
/// To include the root node use [`HugrMutInternals::valid_node`] instead.
///
/// Returns a [`HugrError::InvalidNode`] otherwise.
#[inline]
fn valid_non_root(&self, node: Node) -> Result<(), HugrError> {
match self.root() == node {
true => Err(HugrError::InvalidNode(node)),
false => self.valid_node(node),
}
}

/// Add a node to the graph, with the default conversion from OpType to NodeType
fn add_op(&mut self, op: impl Into<OpType>) -> Node {
self.hugr_mut().add_op(op)
Expand Down
4 changes: 2 additions & 2 deletions src/hugr/validate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ impl<'a, 'b> ValidationContext<'a, 'b> {
/// The results of this computation should be cached in `self.dominators`.
/// We don't do it here to avoid mutable borrows.
fn compute_dominator(&self, parent: Node) -> Dominators<Node> {
let region: SiblingGraph = SiblingGraph::new(self.hugr, parent);
let region: SiblingGraph = SiblingGraph::new(self.hugr, parent).unwrap();
let entry_node = self.hugr.children(parent).next().unwrap();
dominators::simple_fast(&region.as_petgraph(), entry_node)
}
Expand Down Expand Up @@ -374,7 +374,7 @@ impl<'a, 'b> ValidationContext<'a, 'b> {
return Ok(());
};

let region: SiblingGraph = SiblingGraph::new(self.hugr, parent);
let region: SiblingGraph = SiblingGraph::new(self.hugr, parent).unwrap();
let postorder = Topo::new(&region.as_petgraph());
let nodes_visited = postorder
.iter(&region.as_petgraph())
Expand Down
90 changes: 60 additions & 30 deletions src/hugr/views.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use itertools::{Itertools, MapInto};
use portgraph::dot::{DotFormat, EdgeStyle, NodeStyle, PortStyle};
use portgraph::{multiportgraph, LinkView, MultiPortGraph, PortView};

use super::{Hugr, NodeMetadata, NodeType};
use super::{Hugr, HugrError, NodeMetadata, NodeType, DEFAULT_NODETYPE};
use crate::ops::handle::NodeHandle;
use crate::ops::{FuncDecl, FuncDefn, OpName, OpTag, OpType, DFG};
use crate::types::{EdgeKind, FunctionType};
Expand Down Expand Up @@ -80,17 +80,67 @@ pub trait HugrView: sealed::HugrInternals {
/// Returns whether the node exists.
fn contains_node(&self, node: Node) -> bool;

/// Validates that a node is valid in the graph.
///
/// Returns a [`HugrError::InvalidNode`] otherwise.
#[inline]
fn valid_node(&self, node: Node) -> Result<(), HugrError> {
match self.contains_node(node) {
true => Ok(()),
false => Err(HugrError::InvalidNode(node)),
}
}

/// Validates that a node is a valid root descendant in the graph.
///
/// To include the root node use [`HugrView::valid_node`] instead.
///
/// Returns a [`HugrError::InvalidNode`] otherwise.
#[inline]
fn valid_non_root(&self, node: Node) -> Result<(), HugrError> {
match self.root() == node {
true => Err(HugrError::InvalidNode(node)),
false => self.valid_node(node),
}
}

/// Returns the parent of a node.
fn get_parent(&self, node: Node) -> Option<Node>;
#[inline]
fn get_parent(&self, node: Node) -> Option<Node> {
self.valid_non_root(node).ok()?;
self.base_hugr()
.hierarchy
.parent(node.index)
.map(Into::into)
}

/// Returns the operation type of a node.
fn get_optype(&self, node: Node) -> &OpType;
#[inline]
fn get_optype(&self, node: Node) -> &OpType {
&self.get_nodetype(node).op
}

/// Returns the type of a node.
fn get_nodetype(&self, node: Node) -> &NodeType;
#[inline]
fn get_nodetype(&self, node: Node) -> &NodeType {
match self.contains_node(node) {
true => self.base_hugr().op_types.get(node.index),
false => &DEFAULT_NODETYPE,
}
}

/// Returns the metadata associated with a node.
fn get_metadata(&self, node: Node) -> &NodeMetadata;
#[inline]
fn get_metadata(&self, node: Node) -> &NodeMetadata {
// The other way to do it - exploit the UnmanagedDenseMap's get() returning &default
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This comment is a bit cryptic to me, I'm not sure what it means.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, sorry, this is where I'm not sure what to do. Hopefully @aborgna-q can give a shout, I'm wondering about a PR to add something to portgraph's UnmanagedDenseMap....

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Alan's referring to the behaviour of the metadata container, which returns a reference to the default value if it doesn't know about the node index.

I'd avoid fiddling with portgraph assumptions when you can directly

        static DEFAULT_METADATA: NodeMetadata = NodeMetadata::Null;
        match self.contains_node(node) {
            true => self.base_hugr().metadata.get(node.index),
            false => &DEFAULT_METADATA,
        }

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, I can use NodeMetadata::Null, indeed.

So I guess this will do for now - I don't think this have to declare const DEFAULT_NODETYPE etc. is a good general solution (when I know an appropriately-lifetimed default value is sitting there in the map), but it's just about OK for the specific cases here. IOW I might come back to this in another PR later...

let md = &self.base_hugr().metadata;

let idx = match self.contains_node(node) {
true => node.index,
false => portgraph::NodeIndex::new(md.capacity() + 1),
};
md.get(idx)
}

/// Returns the number of nodes in the hugr.
fn node_count(&self) -> usize;
Expand Down Expand Up @@ -249,12 +299,12 @@ pub trait HugrView: sealed::HugrInternals {
}

/// A common trait for views of a HUGR hierarchical subgraph.
pub trait HierarchyView<'a>: HugrView {
/// The base from which the subgraph is derived.
type Base;

pub trait HierarchyView<'a>: HugrView + Sized {
/// Create a hierarchical view of a HUGR given a root node.
fn new(hugr: &'a Self::Base, root: Node) -> Self;
///
/// # Errors
/// Returns [`HugrError::InvalidNode`] if the root isn't a node of the required [OpTag]
fn new(hugr: &'a impl HugrView, root: Node) -> Result<Self, HugrError>;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Again, naming questions that I am not the authority on but will still give my opinion -- use it as you wish. I try to prefix function names that can fail with try_, so here I would choose try_new.

@aborgna-q what is your preference? You're more of a rustacean than me.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, this should be try_new

Copy link
Contributor Author

@acl-cqc acl-cqc Sep 13, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good to me, done. I wonder about a utility function new that calls try_new(...).unwrap() - even if that were #[cfg(test)]. I have not done the latter ATM but please shout if you find repeated try_new...unwrap sufficiently annoying to justify this idea!

}

impl<T> HugrView for T
Expand Down Expand Up @@ -287,21 +337,6 @@ where
self.as_ref().graph.contains_node(node.index)
}

#[inline]
fn get_parent(&self, node: Node) -> Option<Node> {
self.as_ref().hierarchy.parent(node.index).map(Into::into)
}

#[inline]
fn get_optype(&self, node: Node) -> &OpType {
&self.as_ref().op_types.get(node.index).op
}

#[inline]
fn get_nodetype(&self, node: Node) -> &NodeType {
self.as_ref().op_types.get(node.index)
}

#[inline]
fn node_count(&self) -> usize {
self.as_ref().graph.node_count()
Expand Down Expand Up @@ -386,11 +421,6 @@ where
None
}
}

#[inline]
fn get_metadata(&self, node: Node) -> &NodeMetadata {
self.as_ref().metadata.get(node.index)
}
}

pub(crate) mod sealed {
Expand Down
80 changes: 19 additions & 61 deletions src/hugr/views/descendants.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@ use context_iterators::{ContextIterator, IntoContextIterator, MapWithCtx};
use itertools::{Itertools, MapInto};
use portgraph::{LinkView, MultiPortGraph, PortIndex, PortView};

use crate::hugr::HugrError;
use crate::ops::handle::NodeHandle;
use crate::ops::OpTrait;
use crate::{hugr::NodeType, hugr::OpType, Direction, Hugr, Node, Port};
use crate::{Direction, Hugr, Node, Port};

use super::{sealed::HugrInternals, HierarchyView, HugrView, NodeMetadata};
use super::{sealed::HugrInternals, HierarchyView, HugrView};

type RegionGraph<'g> = portgraph::view::Region<'g, &'g MultiPortGraph>;

Expand All @@ -25,37 +26,24 @@ type RegionGraph<'g> = portgraph::view::Region<'g, &'g MultiPortGraph>;
/// used interchangeably with [`SiblingGraph`].
///
/// [`SiblingGraph`]: super::SiblingGraph
pub struct DescendantsGraph<'g, Root = Node, Base = Hugr>
where
Base: HugrInternals,
{
#[derive(Clone)]
pub struct DescendantsGraph<'g, Root = Node> {
/// The chosen root node.
root: Node,

/// The graph encoding the adjacency structure of the HUGR.
graph: RegionGraph<'g>,

/// The node hierarchy.
hugr: &'g Base,
hugr: &'g Hugr,

/// The operation handle of the root node.
_phantom: std::marker::PhantomData<Root>,
}

impl<'g, Root, Base: Clone> Clone for DescendantsGraph<'g, Root, Base>
where
Root: NodeHandle,
Base: HugrInternals + HugrView,
{
fn clone(&self) -> Self {
DescendantsGraph::new(self.hugr, self.root)
}
}

impl<'g, Root, Base> HugrView for DescendantsGraph<'g, Root, Base>
impl<'g, Root> HugrView for DescendantsGraph<'g, Root>
where
Root: NodeHandle,
Base: HugrInternals + HugrView,
{
type RootHandle = Root;

Expand Down Expand Up @@ -94,29 +82,6 @@ where
self.graph.contains_node(node.index)
}

#[inline]
fn get_parent(&self, node: Node) -> Option<Node> {
self.hugr
.get_parent(node)
.filter(|&parent| self.graph.contains_node(parent.index))
.map(Into::into)
}

#[inline]
fn get_optype(&self, node: Node) -> &OpType {
self.hugr.get_optype(node)
}

#[inline]
fn get_nodetype(&self, node: Node) -> &NodeType {
self.hugr.get_nodetype(node)
}

#[inline]
fn get_metadata(&self, node: Node) -> &NodeMetadata {
self.hugr.get_metadata(node)
}

#[inline]
fn node_count(&self) -> usize {
self.graph.node_count()
Expand Down Expand Up @@ -196,36 +161,29 @@ where
}
}

impl<'a, Root, Base> HierarchyView<'a> for DescendantsGraph<'a, Root, Base>
impl<'a, Root> HierarchyView<'a> for DescendantsGraph<'a, Root>
where
Root: NodeHandle,
Base: HugrView,
{
type Base = Base;

fn new(hugr: &'a Base, root: Node) -> Self {
fn new(hugr: &'a impl HugrView, root: Node) -> Result<Self, HugrError> {
hugr.valid_node(root)?;
let root_tag = hugr.get_optype(root).tag();
if !Root::TAG.is_superset(root_tag) {
// TODO: Return an error
panic!("Root node must have the correct operation type tag.")
return Err(HugrError::InvalidNode(root));
}
Self {
let hugr = hugr.base_hugr();
Ok(Self {
root,
graph: RegionGraph::new_region(
&hugr.base_hugr().graph,
&hugr.base_hugr().hierarchy,
root.index,
),
graph: RegionGraph::new_region(&hugr.graph, &hugr.hierarchy, root.index),
hugr,
_phantom: std::marker::PhantomData,
}
})
}
}

impl<'g, Root, Base> super::sealed::HugrInternals for DescendantsGraph<'g, Root, Base>
impl<'g, Root> super::sealed::HugrInternals for DescendantsGraph<'g, Root>
where
Root: NodeHandle,
Base: HugrInternals,
{
type Portgraph<'p> = &'p RegionGraph<'g> where Self: 'p;

Expand All @@ -236,7 +194,7 @@ where

#[inline]
fn base_hugr(&self) -> &Hugr {
self.hugr.base_hugr()
self.hugr
}

#[inline]
Expand Down Expand Up @@ -299,7 +257,7 @@ pub(super) mod test {
fn full_region() -> Result<(), Box<dyn std::error::Error>> {
let (hugr, def, inner) = make_module_hgr()?;

let region: DescendantsGraph = DescendantsGraph::new(&hugr, def);
let region: DescendantsGraph = DescendantsGraph::new(&hugr, def)?;

assert_eq!(region.node_count(), 7);
assert!(region.nodes().all(|n| n == def
Expand All @@ -311,7 +269,7 @@ pub(super) mod test {
region.get_function_type(),
Some(&FunctionType::new(type_row![NAT, QB], type_row![NAT, QB]))
);
let inner_region: DescendantsGraph = DescendantsGraph::new(&hugr, inner);
let inner_region: DescendantsGraph = DescendantsGraph::new(&hugr, inner)?;
assert_eq!(
inner_region.get_function_type(),
Some(&FunctionType::new(type_row![NAT], type_row![NAT]))
Expand Down
Loading