From 4429152e7cd999b9e3f14451e0d283fe5881c9eb Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 2 Dec 2024 20:31:30 +0000 Subject: [PATCH] tidy --- hugr-core/src/hugr/monomorphize.rs | 179 +++++++++++------------------ 1 file changed, 66 insertions(+), 113 deletions(-) diff --git a/hugr-core/src/hugr/monomorphize.rs b/hugr-core/src/hugr/monomorphize.rs index 4fa290507..1751ccc86 100644 --- a/hugr-core/src/hugr/monomorphize.rs +++ b/hugr-core/src/hugr/monomorphize.rs @@ -204,6 +204,8 @@ fn mangle_inner_func(outer_name: &str, inner_name: &str) -> String { #[cfg(test)] mod test { + use std::collections::HashMap; + use itertools::Itertools; use rstest::rstest; @@ -251,82 +253,57 @@ mod test { } #[test] - fn test_module() { + fn test_module() -> Result<(), Box> { let mut mb = ModuleBuilder::new(); - let doub = add_double(&mut mb); - let trip = { + let db = add_double(&mut mb); + let tr = { let tv0 = || Type::new_var_use(0, TypeBound::Copyable); let pfty = PolyFuncType::new( [TypeBound::Copyable.into()], Signature::new(tv0(), Type::new_tuple(vec![tv0(); 3])), ); - let mut fb = mb.define_function("triple", pfty).unwrap(); + let mut fb = mb.define_function("triple", pfty)?; let [elem] = fb.input_wires_arr(); - let pair = fb - .call(doub.handle(), &[tv0().into()], [elem], &PRELUDE_REGISTRY) - .unwrap(); + let pair = fb.call(db.handle(), &[tv0().into()], [elem], &PRELUDE_REGISTRY)?; let [elem1, elem2] = fb - .add_dataflow_op(UnpackTuple(vec![tv0(); 2].into()), pair.outputs()) - .unwrap() + .add_dataflow_op(UnpackTuple(vec![tv0(); 2].into()), pair.outputs())? .outputs_arr(); let tag = Tag::new(0, vec![vec![tv0(); 3].into()]); - let trip = fb.add_dataflow_op(tag, [elem1, elem2, elem]).unwrap(); - fb.finish_with_outputs(trip.outputs()).unwrap() + let trip = fb.add_dataflow_op(tag, [elem1, elem2, elem])?; + fb.finish_with_outputs(trip.outputs())? }; { - let mut fb = mb - .define_function( - "main", - Signature::new( - usize_t(), - vec![triple_type(usize_t()), triple_type(pair_type(usize_t()))], - ), - ) - .unwrap(); + let sig = Signature::new( + usize_t(), + vec![triple_type(usize_t()), triple_type(pair_type(usize_t()))], + ); + let mut fb = mb.define_function("main", sig)?; let [elem] = fb.input_wires_arr(); let [res1] = fb - .call( - trip.handle(), - &[usize_t().into()], - [elem], - &PRELUDE_REGISTRY, - ) - .unwrap() + .call(tr.handle(), &[usize_t().into()], [elem], &PRELUDE_REGISTRY)? .outputs_arr(); - let pair = fb - .call( - doub.handle(), - &[usize_t().into()], - [elem], - &PRELUDE_REGISTRY, - ) - .unwrap(); + let pair = fb.call(db.handle(), &[usize_t().into()], [elem], &PRELUDE_REGISTRY)?; + let pty = pair_type(usize_t()).into(); let [res2] = fb - .call( - trip.handle(), - &[pair_type(usize_t()).into()], - pair.outputs(), - &PRELUDE_REGISTRY, - ) - .unwrap() + .call(tr.handle(), &[pty], pair.outputs(), &PRELUDE_REGISTRY)? .outputs_arr(); - fb.finish_with_outputs([res1, res2]).unwrap(); + fb.finish_with_outputs([res1, res2])?; } - let hugr = mb.finish_hugr(&PRELUDE_REGISTRY).unwrap(); + let hugr = mb.finish_hugr(&PRELUDE_REGISTRY)?; assert_eq!( hugr.nodes() .filter(|n| hugr.get_optype(*n).is_func_defn()) .count(), 3 ); - let mono_hugr = monomorphize(hugr, &PRELUDE_REGISTRY); - mono_hugr.validate(&PRELUDE_REGISTRY).unwrap(); + let mono = monomorphize(hugr, &PRELUDE_REGISTRY); + mono.validate(&PRELUDE_REGISTRY)?; - let funcs = mono_hugr + let mut funcs = mono .nodes() - .filter_map(|n| mono_hugr.get_optype(n).as_func_defn()) - .collect_vec(); + .filter_map(|n| mono.get_optype(n).as_func_defn().map(|fd| (&fd.name, fd))) + .collect::>(); let expected_mangled_names = [ mangle_name("double", &[usize_t().into()]), mangle_name("triple", &[usize_t().into()]), @@ -334,103 +311,78 @@ mod test { mangle_name("triple", &[pair_type(usize_t()).into()]), ]; - assert_eq!( - funcs.iter().map(|fd| &fd.name).sorted().collect_vec(), - ["main", "double", "triple"] - .into_iter() - .chain(expected_mangled_names.iter().map(String::as_str)) - .sorted() - .collect_vec() - ); for n in expected_mangled_names.iter() { - let mono_fn = funcs.iter().find(|fd| &fd.name == n).unwrap(); - assert!(mono_fn.signature.params().is_empty()); + let mono_fn = funcs.remove(n).unwrap(); + assert!((*mono_fn).signature.params().is_empty()); } assert_eq!( - monomorphize(mono_hugr.clone(), &PRELUDE_REGISTRY), - mono_hugr - ); // Idempotent + funcs.keys().sorted().collect_vec(), + ["double", "main", "triple"].iter().collect_vec() + ); - let nopoly = remove_polyfuncs(mono_hugr); - let funcs = nopoly + assert_eq!(monomorphize(mono.clone(), &PRELUDE_REGISTRY), mono); // Idempotent + + let nopoly = remove_polyfuncs(mono); + let mut funcs = nopoly .nodes() - .filter_map(|n| nopoly.get_optype(n).as_func_defn()) - .collect_vec(); + .filter_map(|n| nopoly.get_optype(n).as_func_defn().map(|fd| (&fd.name, fd))) + .collect::>(); - assert_eq!( - funcs.iter().map(|fd| &fd.name).sorted().collect_vec(), - expected_mangled_names - .iter() - .chain(std::iter::once(&"main".into())) - .sorted() - .collect_vec() - ); - assert!(funcs.into_iter().all(|fd| fd.signature.params().is_empty())); + assert!(funcs.values().all(|fd| (*fd).signature.params().is_empty())); + for n in expected_mangled_names { + assert!(funcs.remove(&n).is_some()); + } + assert_eq!(funcs.keys().collect_vec(), vec![&"main"]); + Ok(()) } #[test] - fn test_flattening() { - //polyf1 contains polyf2 contains monof -> pf1 and pf1 share pf2's and they share monof + fn test_flattening() -> Result<(), Box> { + //pf1 contains pf2 contains mono_func -> pf1 and pf1 share pf2's and they share mono_func - let reg = ExtensionRegistry::try_new([int_types::EXTENSION.to_owned(), PRELUDE.to_owned()]) - .unwrap(); + let reg = + ExtensionRegistry::try_new([int_types::EXTENSION.to_owned(), PRELUDE.to_owned()])?; let tv0 = || Type::new_var_use(0, TypeBound::Any); + let pf_any = |sig: Signature| PolyFuncType::new([TypeBound::Any.into()], sig); let ity = || INT_TYPES[3].clone(); - let mut outer = FunctionBuilder::new("mainish", Signature::new(ity(), usize_t())).unwrap(); - let sig = PolyFuncType::new( - [TypeBound::Any.into()], - Signature::new(tv0(), vec![tv0(), usize_t(), usize_t()]), - ); - let mut pf1 = outer.define_function("pf1", sig).unwrap(); + let mut outer = FunctionBuilder::new("mainish", Signature::new(ity(), usize_t()))?; + let sig = pf_any(Signature::new(tv0(), vec![tv0(), usize_t(), usize_t()])); + let mut pf1 = outer.define_function("pf1", sig)?; - let sig = PolyFuncType::new( - [TypeBound::Any.into()], - Signature::new(tv0(), vec![tv0(), usize_t()]), - ); - let mut pf2 = pf1.define_function("pf2", sig).unwrap(); + let sig = pf_any(Signature::new(tv0(), vec![tv0(), usize_t()])); + let mut pf2 = pf1.define_function("pf2", sig)?; let mono_func = { - let mut mono_b = pf2 - .define_function("get_usz", Signature::new(type_row![], usize_t())) - .unwrap(); - let cst0 = mono_b.add_load_value(ConstUsize::new(1)); - mono_b.finish_with_outputs([cst0]).unwrap() + let mut fb = pf2.define_function("get_usz", Signature::new(type_row![], usize_t()))?; + let cst0 = fb.add_load_value(ConstUsize::new(1)); + fb.finish_with_outputs([cst0])? }; let pf2 = { let [inw] = pf2.input_wires_arr(); - - let [usz] = pf2 - .call(mono_func.handle(), &[], [], ®) - .unwrap() - .outputs_arr(); - - pf2.finish_with_outputs([inw, usz]).unwrap() + let [usz] = pf2.call(mono_func.handle(), &[], [], ®)?.outputs_arr(); + pf2.finish_with_outputs([inw, usz])? }; // pf1: Two calls to pf2, one depending on pf1's TypeArg, the other not let [a, u] = pf1 - .call(pf2.handle(), &[tv0().into()], pf1.input_wires(), ®) - .unwrap() + .call(pf2.handle(), &[tv0().into()], pf1.input_wires(), ®)? .outputs_arr(); let [u1, u2] = pf1 - .call(pf2.handle(), &[usize_t().into()], [u], ®) - .unwrap() + .call(pf2.handle(), &[usize_t().into()], [u], ®)? .outputs_arr(); - let pf1 = pf1.finish_with_outputs([a, u1, u2]).unwrap(); + let pf1 = pf1.finish_with_outputs([a, u1, u2])?; // Outer: two calls to pf1 with different TypeArgs let [_, u, _] = outer - .call(pf1.handle(), &[ity().into()], outer.input_wires(), ®) - .unwrap() + .call(pf1.handle(), &[ity().into()], outer.input_wires(), ®)? .outputs_arr(); let [_, u, _] = outer - .call(pf1.handle(), &[usize_t().into()], [u], ®) - .unwrap() + .call(pf1.handle(), &[usize_t().into()], [u], ®)? .outputs_arr(); - let hugr = outer.finish_hugr_with_outputs([u], ®).unwrap(); + let hugr = outer.finish_hugr_with_outputs([u], ®)?; let mono_hugr = monomorphize(hugr, ®); - mono_hugr.validate(®).unwrap(); + mono_hugr.validate(®)?; let funcs = mono_hugr .nodes() .filter_map(|n| mono_hugr.get_optype(n).as_func_defn().map(|fd| (n, fd))) @@ -454,6 +406,7 @@ mod test { assert!(fd.signature.params().is_empty()); assert!(mono_hugr.get_parent(n) == (fd.name != "mainish").then_some(mono_hugr.root())); } + Ok(()) } #[test]