Skip to content

Commit

Permalink
tidy
Browse files Browse the repository at this point in the history
  • Loading branch information
acl-cqc committed Dec 2, 2024
1 parent 6ff4c3b commit 4429152
Showing 1 changed file with 66 additions and 113 deletions.
179 changes: 66 additions & 113 deletions hugr-core/src/hugr/monomorphize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,8 @@ fn mangle_inner_func(outer_name: &str, inner_name: &str) -> String {

#[cfg(test)]
mod test {
use std::collections::HashMap;

use itertools::Itertools;
use rstest::rstest;

Expand Down Expand Up @@ -251,186 +253,136 @@ mod test {
}

#[test]
fn test_module() {
fn test_module() -> Result<(), Box<dyn std::error::Error>> {
let mut mb = ModuleBuilder::new();
let doub = add_double(&mut mb);
let trip = {
let db = add_double(&mut mb);
let tr = {
let tv0 = || Type::new_var_use(0, TypeBound::Copyable);
let pfty = PolyFuncType::new(
[TypeBound::Copyable.into()],
Signature::new(tv0(), Type::new_tuple(vec![tv0(); 3])),
);
let mut fb = mb.define_function("triple", pfty).unwrap();
let mut fb = mb.define_function("triple", pfty)?;
let [elem] = fb.input_wires_arr();
let pair = fb
.call(doub.handle(), &[tv0().into()], [elem], &PRELUDE_REGISTRY)
.unwrap();
let pair = fb.call(db.handle(), &[tv0().into()], [elem], &PRELUDE_REGISTRY)?;

let [elem1, elem2] = fb
.add_dataflow_op(UnpackTuple(vec![tv0(); 2].into()), pair.outputs())
.unwrap()
.add_dataflow_op(UnpackTuple(vec![tv0(); 2].into()), pair.outputs())?
.outputs_arr();
let tag = Tag::new(0, vec![vec![tv0(); 3].into()]);
let trip = fb.add_dataflow_op(tag, [elem1, elem2, elem]).unwrap();
fb.finish_with_outputs(trip.outputs()).unwrap()
let trip = fb.add_dataflow_op(tag, [elem1, elem2, elem])?;
fb.finish_with_outputs(trip.outputs())?
};
{
let mut fb = mb
.define_function(
"main",
Signature::new(
usize_t(),
vec![triple_type(usize_t()), triple_type(pair_type(usize_t()))],
),
)
.unwrap();
let sig = Signature::new(
usize_t(),
vec![triple_type(usize_t()), triple_type(pair_type(usize_t()))],
);
let mut fb = mb.define_function("main", sig)?;
let [elem] = fb.input_wires_arr();
let [res1] = fb
.call(
trip.handle(),
&[usize_t().into()],
[elem],
&PRELUDE_REGISTRY,
)
.unwrap()
.call(tr.handle(), &[usize_t().into()], [elem], &PRELUDE_REGISTRY)?
.outputs_arr();
let pair = fb
.call(
doub.handle(),
&[usize_t().into()],
[elem],
&PRELUDE_REGISTRY,
)
.unwrap();
let pair = fb.call(db.handle(), &[usize_t().into()], [elem], &PRELUDE_REGISTRY)?;
let pty = pair_type(usize_t()).into();
let [res2] = fb
.call(
trip.handle(),
&[pair_type(usize_t()).into()],
pair.outputs(),
&PRELUDE_REGISTRY,
)
.unwrap()
.call(tr.handle(), &[pty], pair.outputs(), &PRELUDE_REGISTRY)?
.outputs_arr();
fb.finish_with_outputs([res1, res2]).unwrap();
fb.finish_with_outputs([res1, res2])?;
}
let hugr = mb.finish_hugr(&PRELUDE_REGISTRY).unwrap();
let hugr = mb.finish_hugr(&PRELUDE_REGISTRY)?;
assert_eq!(
hugr.nodes()
.filter(|n| hugr.get_optype(*n).is_func_defn())
.count(),
3
);
let mono_hugr = monomorphize(hugr, &PRELUDE_REGISTRY);
mono_hugr.validate(&PRELUDE_REGISTRY).unwrap();
let mono = monomorphize(hugr, &PRELUDE_REGISTRY);
mono.validate(&PRELUDE_REGISTRY)?;

let funcs = mono_hugr
let mut funcs = mono
.nodes()
.filter_map(|n| mono_hugr.get_optype(n).as_func_defn())
.collect_vec();
.filter_map(|n| mono.get_optype(n).as_func_defn().map(|fd| (&fd.name, fd)))
.collect::<HashMap<_, _>>();
let expected_mangled_names = [
mangle_name("double", &[usize_t().into()]),
mangle_name("triple", &[usize_t().into()]),
mangle_name("double", &[pair_type(usize_t()).into()]),
mangle_name("triple", &[pair_type(usize_t()).into()]),
];

assert_eq!(
funcs.iter().map(|fd| &fd.name).sorted().collect_vec(),
["main", "double", "triple"]
.into_iter()
.chain(expected_mangled_names.iter().map(String::as_str))
.sorted()
.collect_vec()
);
for n in expected_mangled_names.iter() {
let mono_fn = funcs.iter().find(|fd| &fd.name == n).unwrap();
assert!(mono_fn.signature.params().is_empty());
let mono_fn = funcs.remove(n).unwrap();
assert!((*mono_fn).signature.params().is_empty());
}

assert_eq!(
monomorphize(mono_hugr.clone(), &PRELUDE_REGISTRY),
mono_hugr
); // Idempotent
funcs.keys().sorted().collect_vec(),
["double", "main", "triple"].iter().collect_vec()
);

let nopoly = remove_polyfuncs(mono_hugr);
let funcs = nopoly
assert_eq!(monomorphize(mono.clone(), &PRELUDE_REGISTRY), mono); // Idempotent

let nopoly = remove_polyfuncs(mono);
let mut funcs = nopoly
.nodes()
.filter_map(|n| nopoly.get_optype(n).as_func_defn())
.collect_vec();
.filter_map(|n| nopoly.get_optype(n).as_func_defn().map(|fd| (&fd.name, fd)))
.collect::<HashMap<_, _>>();

assert_eq!(
funcs.iter().map(|fd| &fd.name).sorted().collect_vec(),
expected_mangled_names
.iter()
.chain(std::iter::once(&"main".into()))
.sorted()
.collect_vec()
);
assert!(funcs.into_iter().all(|fd| fd.signature.params().is_empty()));
assert!(funcs.values().all(|fd| (*fd).signature.params().is_empty()));
for n in expected_mangled_names {
assert!(funcs.remove(&n).is_some());
}
assert_eq!(funcs.keys().collect_vec(), vec![&"main"]);
Ok(())
}

#[test]
fn test_flattening() {
//polyf1 contains polyf2 contains monof -> pf1<a> and pf1<b> share pf2's and they share monof
fn test_flattening() -> Result<(), Box<dyn std::error::Error>> {
//pf1 contains pf2 contains mono_func -> pf1<a> and pf1<b> share pf2's and they share mono_func

let reg = ExtensionRegistry::try_new([int_types::EXTENSION.to_owned(), PRELUDE.to_owned()])
.unwrap();
let reg =
ExtensionRegistry::try_new([int_types::EXTENSION.to_owned(), PRELUDE.to_owned()])?;
let tv0 = || Type::new_var_use(0, TypeBound::Any);
let pf_any = |sig: Signature| PolyFuncType::new([TypeBound::Any.into()], sig);
let ity = || INT_TYPES[3].clone();

let mut outer = FunctionBuilder::new("mainish", Signature::new(ity(), usize_t())).unwrap();
let sig = PolyFuncType::new(
[TypeBound::Any.into()],
Signature::new(tv0(), vec![tv0(), usize_t(), usize_t()]),
);
let mut pf1 = outer.define_function("pf1", sig).unwrap();
let mut outer = FunctionBuilder::new("mainish", Signature::new(ity(), usize_t()))?;
let sig = pf_any(Signature::new(tv0(), vec![tv0(), usize_t(), usize_t()]));
let mut pf1 = outer.define_function("pf1", sig)?;

let sig = PolyFuncType::new(
[TypeBound::Any.into()],
Signature::new(tv0(), vec![tv0(), usize_t()]),
);
let mut pf2 = pf1.define_function("pf2", sig).unwrap();
let sig = pf_any(Signature::new(tv0(), vec![tv0(), usize_t()]));
let mut pf2 = pf1.define_function("pf2", sig)?;

let mono_func = {
let mut mono_b = pf2
.define_function("get_usz", Signature::new(type_row![], usize_t()))
.unwrap();
let cst0 = mono_b.add_load_value(ConstUsize::new(1));
mono_b.finish_with_outputs([cst0]).unwrap()
let mut fb = pf2.define_function("get_usz", Signature::new(type_row![], usize_t()))?;
let cst0 = fb.add_load_value(ConstUsize::new(1));
fb.finish_with_outputs([cst0])?
};
let pf2 = {
let [inw] = pf2.input_wires_arr();

let [usz] = pf2
.call(mono_func.handle(), &[], [], &reg)
.unwrap()
.outputs_arr();

pf2.finish_with_outputs([inw, usz]).unwrap()
let [usz] = pf2.call(mono_func.handle(), &[], [], &reg)?.outputs_arr();
pf2.finish_with_outputs([inw, usz])?
};
// pf1: Two calls to pf2, one depending on pf1's TypeArg, the other not
let [a, u] = pf1
.call(pf2.handle(), &[tv0().into()], pf1.input_wires(), &reg)
.unwrap()
.call(pf2.handle(), &[tv0().into()], pf1.input_wires(), &reg)?
.outputs_arr();
let [u1, u2] = pf1
.call(pf2.handle(), &[usize_t().into()], [u], &reg)
.unwrap()
.call(pf2.handle(), &[usize_t().into()], [u], &reg)?
.outputs_arr();
let pf1 = pf1.finish_with_outputs([a, u1, u2]).unwrap();
let pf1 = pf1.finish_with_outputs([a, u1, u2])?;
// Outer: two calls to pf1 with different TypeArgs
let [_, u, _] = outer
.call(pf1.handle(), &[ity().into()], outer.input_wires(), &reg)
.unwrap()
.call(pf1.handle(), &[ity().into()], outer.input_wires(), &reg)?
.outputs_arr();
let [_, u, _] = outer
.call(pf1.handle(), &[usize_t().into()], [u], &reg)
.unwrap()
.call(pf1.handle(), &[usize_t().into()], [u], &reg)?
.outputs_arr();
let hugr = outer.finish_hugr_with_outputs([u], &reg).unwrap();
let hugr = outer.finish_hugr_with_outputs([u], &reg)?;

let mono_hugr = monomorphize(hugr, &reg);
mono_hugr.validate(&reg).unwrap();
mono_hugr.validate(&reg)?;
let funcs = mono_hugr
.nodes()
.filter_map(|n| mono_hugr.get_optype(n).as_func_defn().map(|fd| (n, fd)))
Expand All @@ -454,6 +406,7 @@ mod test {
assert!(fd.signature.params().is_empty());
assert!(mono_hugr.get_parent(n) == (fd.name != "mainish").then_some(mono_hugr.root()));
}
Ok(())
}

#[test]
Expand Down

0 comments on commit 4429152

Please sign in to comment.