diff --git a/hugr-core/src/hugr/monomorphize.rs b/hugr-core/src/hugr/monomorphize.rs index 9c2d6e56e..8a83850d4 100644 --- a/hugr-core/src/hugr/monomorphize.rs +++ b/hugr-core/src/hugr/monomorphize.rs @@ -209,17 +209,19 @@ mod test { use crate::builder::test::simple_dfg_hugr; use crate::builder::{ - BuildHandle, Container, Dataflow, DataflowSubContainer, HugrBuilder, ModuleBuilder, + BuildHandle, Container, Dataflow, DataflowHugr, DataflowSubContainer, FunctionBuilder, + HugrBuilder, ModuleBuilder, }; - use crate::extension::prelude::{UnpackTuple, USIZE_T}; - use crate::extension::{EMPTY_REG, PRELUDE_REGISTRY}; - use crate::hugr::monomorphize::remove_polyfuncs; + use crate::extension::prelude::{ConstUsize, UnpackTuple, USIZE_T}; + use crate::extension::{ExtensionRegistry, EMPTY_REG, PRELUDE, PRELUDE_REGISTRY}; + use crate::hugr::monomorphize::mangle_inner_func; use crate::ops::handle::FuncID; use crate::ops::Tag; + use crate::std_extensions::arithmetic::int_types::{self, INT_TYPES}; use crate::types::{PolyFuncType, Signature, Type, TypeBound}; - use crate::{Hugr, HugrView}; + use crate::{type_row, Hugr, HugrView}; - use super::{mangle_name, monomorphize}; + use super::{mangle_name, monomorphize, remove_polyfuncs}; #[rstest] fn test_null(simple_dfg_hugr: Hugr) { @@ -358,8 +360,100 @@ mod test { } #[test] - fn test_flattening() {} + fn test_flattening() { + //polyf1 contains polyf2 contains monof -> pf1 and pf1 share pf2's and they share monof + + let reg = ExtensionRegistry::try_new([int_types::EXTENSION.to_owned(), PRELUDE.to_owned()]) + .unwrap(); + let tv0 = || Type::new_var_use(0, TypeBound::Any); + let ity = || INT_TYPES[3].clone(); + + let mut outer = FunctionBuilder::new("mainish", Signature::new(ity(), USIZE_T)).unwrap(); + let sig = PolyFuncType::new( + [TypeBound::Any.into()], + Signature::new(tv0(), vec![tv0(), USIZE_T, USIZE_T]), + ); + let mut pf1 = outer.define_function("pf1", sig).unwrap(); + + let sig = PolyFuncType::new( + [TypeBound::Any.into()], + Signature::new(tv0(), vec![tv0(), USIZE_T]), + ); + let mut pf2 = pf1.define_function("pf2", sig).unwrap(); + + let mono_func = { + let mut mono_b = pf2 + .define_function("get_usz", Signature::new(type_row![], USIZE_T)) + .unwrap(); + let cst0 = mono_b.add_load_value(ConstUsize::new(1)); + mono_b.finish_with_outputs([cst0]).unwrap() + }; + let pf2 = { + let [inw] = pf2.input_wires_arr(); + + let [usz] = pf2 + .call(mono_func.handle(), &[], [], ®) + .unwrap() + .outputs_arr(); + + pf2.finish_with_outputs([inw, usz]).unwrap() + }; + // pf1: Two calls to pf2, one depending on pf1's TypeArg, the other not + let [a, u] = pf1 + .call(pf2.handle(), &[tv0().into()], pf1.input_wires(), ®) + .unwrap() + .outputs_arr(); + let [u1, u2] = pf1 + .call(pf2.handle(), &[USIZE_T.into()], [u], ®) + .unwrap() + .outputs_arr(); + let pf1 = pf1.finish_with_outputs([a, u1, u2]).unwrap(); + // Outer: two calls to pf1 with different TypeArgs + let [_, u, _] = outer + .call(pf1.handle(), &[ity().into()], outer.input_wires(), ®) + .unwrap() + .outputs_arr(); + let [_, u, _] = outer + .call(pf1.handle(), &[USIZE_T.into()], [u], ®) + .unwrap() + .outputs_arr(); + let hugr = outer.finish_hugr_with_outputs([u], ®).unwrap(); + + let mono_hugr = monomorphize(hugr, ®); + mono_hugr.validate(®).unwrap(); + let funcs = mono_hugr + .nodes() + .filter_map(|n| mono_hugr.get_optype(n).as_func_defn().map(|fd| (n, fd))) + .collect_vec(); + let pf2_name = mangle_inner_func("pf1", "pf2"); + assert_eq!( + funcs.iter().map(|(_, fd)| &fd.name).sorted().collect_vec(), + vec![ + &mangle_name("pf1", &[ity().into()]), + &mangle_name("pf1", &[USIZE_T.into()]), + &mangle_name(&pf2_name, &[ity().into()]), // from pf1 + &mangle_name(&pf2_name, &[USIZE_T.into()]), // from pf1 and (2*)pf1 + &mangle_inner_func(&pf2_name, "get_usz"), + "mainish" + ] + .into_iter() + .sorted() + .collect_vec() + ); + for (n, fd) in funcs { + assert!(fd.signature.params().is_empty()); + assert!(mono_hugr.get_parent(n) == (fd.name != "mainish").then_some(mono_hugr.root())); + } + } #[test] - fn test_recursive() {} + fn test_not_flattened() { + //monof2 contains polyf3 (and instantiates) - not moved + //polyf4 contains polyf5 but not instantiated -> not moved + } + + #[test] + fn test_recursive() { + // make map, + } }