From f7a6dca1918baed839aaa3951cb7afea7ac3814a Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 4 Dec 2024 14:50:16 +0000 Subject: [PATCH] Move monomorphize.rs into hugr-passes, pub Substitution::new --- hugr-core/src/hugr.rs | 3 - hugr-core/src/types.rs | 8 ++- hugr-passes/src/lib.rs | 2 + .../hugr => hugr-passes/src}/monomorphize.rs | 57 ++++++++++--------- 4 files changed, 40 insertions(+), 30 deletions(-) rename {hugr-core/src/hugr => hugr-passes/src}/monomorphize.rs (91%) diff --git a/hugr-core/src/hugr.rs b/hugr-core/src/hugr.rs index dc8321f2f..b42622745 100644 --- a/hugr-core/src/hugr.rs +++ b/hugr-core/src/hugr.rs @@ -4,7 +4,6 @@ pub mod hugrmut; pub(crate) mod ident; pub mod internal; -mod monomorphize; pub mod rewrite; pub mod serialize; pub mod validate; @@ -33,8 +32,6 @@ use crate::ops::{OpTag, OpTrait}; pub use crate::ops::{OpType, DEFAULT_OPTYPE}; use crate::{Direction, Node}; -pub use monomorphize::{monomorphize, remove_polyfuncs}; - /// The Hugr data structure. #[derive(Clone, Debug, PartialEq)] pub struct Hugr { diff --git a/hugr-core/src/types.rs b/hugr-core/src/types.rs index a8cd3d9ba..b8e3a96e4 100644 --- a/hugr-core/src/types.rs +++ b/hugr-core/src/types.rs @@ -547,7 +547,13 @@ impl From for TypeRV { pub struct Substitution<'a>(&'a [TypeArg], &'a ExtensionRegistry); impl<'a> Substitution<'a> { - pub(crate) fn new(items: &'a [TypeArg], exts: &'a ExtensionRegistry) -> Self { + /// Create a new Substitution given the replacement values (indexed + /// as the variables they replace). `exts` must contain the [TypeDef] + /// for every custom [Type] (to which the Substitution is applied) + /// containing a type-variable. + /// + /// [TypeDef]: crate::extension::TypeDef + pub fn new(items: &'a [TypeArg], exts: &'a ExtensionRegistry) -> Self { Self(items, exts) } diff --git a/hugr-passes/src/lib.rs b/hugr-passes/src/lib.rs index 9042850d4..030e1b5b1 100644 --- a/hugr-passes/src/lib.rs +++ b/hugr-passes/src/lib.rs @@ -5,6 +5,8 @@ pub mod force_order; mod half_node; pub mod lower; pub mod merge_bbs; +mod monomorphize; +pub use monomorphize::{monomorphize, remove_polyfuncs}; pub mod nest_cfgs; pub mod non_local; pub mod validation; diff --git a/hugr-core/src/hugr/monomorphize.rs b/hugr-passes/src/monomorphize.rs similarity index 91% rename from hugr-core/src/hugr/monomorphize.rs rename to hugr-passes/src/monomorphize.rs index ac3747703..7eda5d180 100644 --- a/hugr-core/src/hugr/monomorphize.rs +++ b/hugr-passes/src/monomorphize.rs @@ -1,13 +1,13 @@ use std::collections::{hash_map::Entry, HashMap}; -use crate::{ +use hugr_core::{ extension::ExtensionRegistry, ops::{Call, FuncDefn, OpTrait}, types::{Signature, Substitution, TypeArg}, Node, }; -use super::{internal::HugrMutInternals, Hugr, HugrMut, HugrView, OpType}; +use hugr_core::hugr::{hugrmut::HugrMut, internal::HugrMutInternals, Hugr, HugrView, OpType}; /// Replaces calls to polymorphic functions with calls to new monomorphic /// instantiations of the polymorphic ones. @@ -114,8 +114,8 @@ fn mono_scan( h.disconnect(ch, fn_inp); // No-op if copying+substituting h.connect(new_tgt, fn_out, ch, fn_inp); - *h.op_types.get_mut(ch.pg_index()) = - Call::try_new(mono_sig.into(), vec![], reg).unwrap().into(); + h.replace_op(ch, Call::try_new(mono_sig.into(), vec![], reg).unwrap()) + .unwrap(); } } @@ -132,8 +132,12 @@ fn instantiate( let outer_name = h.get_optype(poly_func).as_func_defn().unwrap().name.clone(); let mut to_scan = Vec::from_iter(h.children(poly_func)); while let Some(n) = to_scan.pop() { - if let OpType::FuncDefn(fd) = h.op_types.get_mut(n.pg_index()) { - fd.name = mangle_inner_func(&outer_name, &fd.name); + if let OpType::FuncDefn(fd) = h.get_optype(n) { + let fd = FuncDefn { + name: mangle_inner_func(&outer_name, &fd.name), + signature: fd.signature.clone(), + }; + h.replace_op(n, fd).unwrap(); h.move_after_sibling(n, poly_func); } else { to_scan.extend(h.children(n)) @@ -210,29 +214,20 @@ mod test { use std::collections::HashMap; use itertools::Itertools; - use rstest::rstest; - use crate::builder::test::simple_dfg_hugr; - use crate::builder::{ + use hugr_core::builder::{ Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, FunctionBuilder, HugrBuilder, ModuleBuilder, }; - use crate::extension::prelude::{usize_t, ConstUsize, UnpackTuple, PRELUDE_ID}; - use crate::extension::{ExtensionRegistry, EMPTY_REG, PRELUDE, PRELUDE_REGISTRY}; - use crate::hugr::monomorphize::mangle_inner_func; - use crate::ops::handle::{FuncID, NodeHandle}; - use crate::ops::{FuncDefn, Tag}; - use crate::std_extensions::arithmetic::int_types::{self, INT_TYPES}; - use crate::types::{PolyFuncType, Signature, Type, TypeBound, TypeRow}; - use crate::{Hugr, HugrView, Node}; - - use super::{is_polymorphic, mangle_name, monomorphize, remove_polyfuncs}; - - #[rstest] - fn test_null(simple_dfg_hugr: Hugr) { - let mono = monomorphize(simple_dfg_hugr.clone(), &EMPTY_REG); - assert_eq!(simple_dfg_hugr, mono); - } + use hugr_core::extension::prelude::{usize_t, ConstUsize, UnpackTuple, PRELUDE_ID}; + use hugr_core::extension::{ExtensionRegistry, EMPTY_REG, PRELUDE, PRELUDE_REGISTRY}; + use hugr_core::ops::handle::{FuncID, NodeHandle}; + use hugr_core::ops::{FuncDefn, Tag}; + use hugr_core::std_extensions::arithmetic::int_types::{self, INT_TYPES}; + use hugr_core::types::{PolyFuncType, Signature, Type, TypeBound, TypeRow}; + use hugr_core::{Hugr, HugrView, Node}; + + use super::{is_polymorphic, mangle_inner_func, mangle_name, monomorphize, remove_polyfuncs}; fn pair_type(ty: Type) -> Type { Type::new_tuple(vec![ty.clone(), ty]) @@ -246,6 +241,16 @@ mod test { Signature::new(ins, outs).with_extension_delta(PRELUDE_ID) } + #[test] + fn test_null() { + let dfg_builder = + DFGBuilder::new(Signature::new(vec![usize_t()], vec![usize_t()])).unwrap(); + let [i1] = dfg_builder.input_wires_arr(); + let hugr = dfg_builder.finish_prelude_hugr_with_outputs([i1]).unwrap(); + let hugr2 = monomorphize(hugr.clone(), &PRELUDE_REGISTRY); + assert_eq!(hugr, hugr2); + } + #[test] fn test_module() -> Result<(), Box> { let tv0 = || Type::new_var_use(0, TypeBound::Copyable); @@ -280,7 +285,7 @@ mod test { let pair = fb.call(db.handle(), &[tv0().into()], [elem], &PRELUDE_REGISTRY)?; let [elem1, elem2] = fb - .add_dataflow_op(UnpackTuple(vec![tv0(); 2].into()), pair.outputs())? + .add_dataflow_op(UnpackTuple::new(vec![tv0(); 2].into()), pair.outputs())? .outputs_arr(); let tag = Tag::new(0, vec![vec![tv0(); 3].into()]); let trip = fb.add_dataflow_op(tag, [elem1, elem2, elem])?;