From 73dc13cefc036e1e553c019696834e6b879a3187 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 18 Mar 2024 10:54:50 +0000 Subject: [PATCH 1/2] refactor: Extension Inference: make fewer things public, rm Meta::new (#883) I think these are generally `pub` methods of *non-pub* classes, so not generally accessible anyway, but (a) just in case and (b) makes it clearer that they aren't `pub` Also remove `pub fn new` from `Meta` as if it's not there to make things more-pub, then it really doesn't do much. --- quantinuum-hugr/src/extension/infer.rs | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/quantinuum-hugr/src/extension/infer.rs b/quantinuum-hugr/src/extension/infer.rs index 941ebb010..64fca0265 100644 --- a/quantinuum-hugr/src/extension/infer.rs +++ b/quantinuum-hugr/src/extension/infer.rs @@ -55,12 +55,6 @@ pub fn infer_extensions( #[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] struct Meta(u32); -impl Meta { - pub fn new(m: u32) -> Self { - Meta(m) - } -} - #[derive(Clone, Debug, Eq, PartialEq, Hash)] /// Things we know about metavariables enum Constraint { @@ -190,7 +184,7 @@ struct UnificationContext { impl UnificationContext { /// Create a new unification context, and populate it with constraints from /// traversing the hugr which is passed in. - pub fn new(hugr: &impl HugrView) -> Self { + fn new(hugr: &impl HugrView) -> Self { let mut ctx = Self { constraints: HashMap::new(), extensions: HashMap::new(), @@ -206,7 +200,7 @@ impl UnificationContext { /// Create a fresh metavariable, and increment `fresh_name` for next time fn fresh_meta(&mut self) -> Meta { - let fresh = Meta::new(self.fresh_name); + let fresh = Meta(self.fresh_name); self.fresh_name += 1; self.constraints.insert(fresh, HashSet::new()); fresh @@ -544,7 +538,7 @@ impl UnificationContext { /// available. When there are variables, we should leave the graph as it is, /// but make sure that no matter what they're instantiated to, the graph /// still makes sense (should pass the extension validation check) - pub fn results(&self) -> Result { + fn results(&self) -> Result { // Check that all of the metavariables associated with nodes of the // graph are solved let depended_upon = { @@ -612,7 +606,7 @@ impl UnificationContext { /// where it was possible to infer them. If it wasn't possible to infer a /// *concrete* `ExtensionSet`, e.g. if the ExtensionSet relies on an open /// variable in the toplevel graph, don't include that location in the map - pub fn main_loop(&mut self) -> Result { + fn main_loop(&mut self) -> Result { let mut remaining = HashSet::::from_iter(self.constraints.keys().cloned()); // Keep going as long as we're making progress (= merging and solving nodes) @@ -674,7 +668,7 @@ impl UnificationContext { /// 2 = 1 + x, ... /// then 1 and 2 both definitely contain X, even if we don't know what else. /// So instead of instantiating to the empty set, we'll instantiate to `{X}` - pub fn instantiate_variables(&mut self) { + fn instantiate_variables(&mut self) { // A directed graph to keep track of `Plus` constraint relationships let mut relations = GraphContainer::::new(); let mut solutions: HashMap = HashMap::new(); From 9ef780864046392a49bb13f00c6fb45faeb091c9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Borgna?= <121866228+aborgna-q@users.noreply.github.com> Date: Mon, 18 Mar 2024 11:08:27 +0000 Subject: [PATCH 2/2] feat!: Return the type of FuncDecl in `HugrView::get_function_type` (#880) `get_function_type` was confusing, as it didn't specify that it only worked for dataflow containers, and required manually checking for function declarations on the side. BREAKING CHANGE: `HugrView::get_function_type` now returns a `PolyFuncType`. --- quantinuum-hugr/src/hugr/views.rs | 38 ++++++++++++++++--- quantinuum-hugr/src/hugr/views/descendants.rs | 4 +- quantinuum-hugr/src/hugr/views/sibling.rs | 2 +- quantinuum-hugr/src/types/check.rs | 16 +------- 4 files changed, 37 insertions(+), 23 deletions(-) diff --git a/quantinuum-hugr/src/hugr/views.rs b/quantinuum-hugr/src/hugr/views.rs index 0a0cf7d1b..649ae522e 100644 --- a/quantinuum-hugr/src/hugr/views.rs +++ b/quantinuum-hugr/src/hugr/views.rs @@ -28,8 +28,8 @@ use super::{Hugr, HugrError, NodeMetadata, NodeMetadataMap, NodeType, DEFAULT_NO use crate::ops::handle::NodeHandle; use crate::ops::{OpParent, OpTag, OpTrait, OpType}; -use crate::types::Type; use crate::types::{EdgeKind, FunctionType}; +use crate::types::{PolyFuncType, Type}; use crate::{Direction, IncomingPort, Node, OutgoingPort, Port}; use itertools::Either; @@ -327,11 +327,37 @@ pub trait HugrView: sealed::HugrInternals { } } - /// For HUGRs with a [`DataflowParent`][crate::ops::DataflowParent] root operation, report the - /// signature of the inner dataflow sibling graph. Otherwise return None. - fn get_function_type(&self) -> Option { - let op = self.get_nodetype(self.root()); - op.op.inner_function_type() + /// Returns the function type defined by this dataflow HUGR. + /// + /// If the root of the Hugr is a + /// [`DataflowParent`][crate::ops::DataflowParent] operation, report the + /// signature corresponding to the input and output node of its sibling + /// graph. Otherwise, returns `None`. + /// + /// In contrast to [`get_function_type`][HugrView::get_function_type], this + /// method always return a concrete [`FunctionType`]. + fn get_df_function_type(&self) -> Option { + let op = self.get_optype(self.root()); + op.inner_function_type() + } + + /// Returns the function type defined by this HUGR. + /// + /// For HUGRs with a [`DataflowParent`][crate::ops::DataflowParent] root + /// operation, report the signature of the inner dataflow sibling graph. + /// + /// For HUGRS with a [`FuncDecl`][crate::ops::FuncDecl] or + /// [`FuncDefn`][crate::ops::FuncDefn] root operation, report the signature + /// of the function. + /// + /// Otherwise, returns `None`. + fn get_function_type(&self) -> Option { + let op = self.get_optype(self.root()); + match op { + OpType::FuncDecl(decl) => Some(decl.signature.clone()), + OpType::FuncDefn(defn) => Some(defn.signature.clone()), + _ => op.inner_function_type().map(PolyFuncType::from), + } } /// Return a wrapper over the view that can be used in petgraph algorithms. diff --git a/quantinuum-hugr/src/hugr/views/descendants.rs b/quantinuum-hugr/src/hugr/views/descendants.rs index 3a60f4755..453503509 100644 --- a/quantinuum-hugr/src/hugr/views/descendants.rs +++ b/quantinuum-hugr/src/hugr/views/descendants.rs @@ -261,12 +261,12 @@ pub(super) mod test { assert_eq!( region.get_function_type(), - Some(FunctionType::new_endo(type_row![NAT, QB])) + Some(FunctionType::new_endo(type_row![NAT, QB]).into()) ); let inner_region: DescendantsGraph = DescendantsGraph::try_new(&hugr, inner)?; assert_eq!( inner_region.get_function_type(), - Some(FunctionType::new(type_row![NAT], type_row![NAT])) + Some(FunctionType::new(type_row![NAT], type_row![NAT]).into()) ); Ok(()) diff --git a/quantinuum-hugr/src/hugr/views/sibling.rs b/quantinuum-hugr/src/hugr/views/sibling.rs index e07897ac1..7d69cc4b0 100644 --- a/quantinuum-hugr/src/hugr/views/sibling.rs +++ b/quantinuum-hugr/src/hugr/views/sibling.rs @@ -439,7 +439,7 @@ mod test { fn flat_mut(mut simple_dfg_hugr: Hugr) { simple_dfg_hugr.update_validate(&PRELUDE_REGISTRY).unwrap(); let root = simple_dfg_hugr.root(); - let signature = simple_dfg_hugr.get_function_type().unwrap(); + let signature = simple_dfg_hugr.get_df_function_type().unwrap().clone(); let sib_mut = SiblingMut::::try_new(&mut simple_dfg_hugr, root); assert_eq!( diff --git a/quantinuum-hugr/src/types/check.rs b/quantinuum-hugr/src/types/check.rs index 2a97f61d7..ebc5175c6 100644 --- a/quantinuum-hugr/src/types/check.rs +++ b/quantinuum-hugr/src/types/check.rs @@ -2,11 +2,7 @@ use thiserror::Error; -use crate::{ - ops::{FuncDecl, FuncDefn, OpType}, - values::Value, - Hugr, HugrView, -}; +use crate::{values::Value, Hugr, HugrView}; use super::{CustomType, PolyFuncType, Type, TypeEnum}; @@ -57,15 +53,7 @@ pub enum ConstTypeError { fn type_sig_equal(v: &Hugr, t: &PolyFuncType) -> bool { // exact signature equality, in future this may need to be // relaxed to be compatibility checks between the signatures. - let root_op = v.get_optype(v.root()); - if let OpType::FuncDecl(FuncDecl { signature, .. }) - | OpType::FuncDefn(FuncDefn { signature, .. }) = root_op - { - signature == t - } else { - v.get_function_type() - .is_some_and(|ft| &PolyFuncType::from(ft) == t) - } + v.get_function_type().is_some_and(|ft| &ft == t) } impl super::SumType {