diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index b814b6440..8dfd98081 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -87,7 +87,7 @@ impl Machine { context .get_optype(*n) .as_func_defn() - .is_some_and(|f| f.name() == "main") + .is_some_and(|f| f.name == "main") }); if main.is_none() && in_values.next().is_some() { panic!("Cannot give inputs to module with no 'main'"); diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index c0fbf395a..dafdb8046 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -1,6 +1,6 @@ use ascent::{lattice::BoundedLattice, Lattice}; -use hugr_core::builder::{CFGBuilder, Container, DataflowHugr}; +use hugr_core::builder::{CFGBuilder, Container, DataflowHugr, ModuleBuilder}; use hugr_core::hugr::views::{DescendantsGraph, HierarchyView}; use hugr_core::ops::handle::DfgID; use hugr_core::ops::TailLoop; @@ -483,3 +483,70 @@ fn test_region() { ); } } + +#[test] +fn test_module() { + let mut modb = ModuleBuilder::new(); + let leaf_fn = modb + .define_function("leaf", Signature::new_endo(type_row![BOOL_T; 2])) + .unwrap(); + let outs = leaf_fn.input_wires(); + let leaf_fn = leaf_fn.finish_with_outputs(outs).unwrap(); + + let mut f2 = modb + .define_function("f2", Signature::new(BOOL_T, type_row![BOOL_T; 2])) + .unwrap(); + let [inp] = f2.input_wires_arr(); + let cst_true = f2.add_load_value(Value::true_val()); + let f2_call = f2 + .call(&leaf_fn.handle(), &[], [inp, cst_true], &EMPTY_REG) + .unwrap(); + let f2 = f2.finish_with_outputs(f2_call.outputs()).unwrap(); + + let mut main = modb + .define_function("main", Signature::new(BOOL_T, type_row![BOOL_T; 2])) + .unwrap(); + let [inp] = main.input_wires_arr(); + let cst_false = main.add_load_value(Value::false_val()); + let main_call = main + .call(&leaf_fn.handle(), &[], [inp, cst_false], &EMPTY_REG) + .unwrap(); + main.finish_with_outputs(main_call.outputs()).unwrap(); + let hugr = modb.finish_hugr(&EMPTY_REG).unwrap(); + let [f2_inp, _] = hugr.get_io(f2.node()).unwrap(); + + let results_just_main = Machine::default().run(TestContext(&hugr), [(0.into(), pv_true())]); + assert_eq!( + results_just_main.read_out_wire(Wire::new(f2_inp, 0)), + Some(PartialValue::Bottom) + ); + for call in [f2_call, main_call] { + // The first output of the Call comes from `main` because no value was fed in from f2 + assert_eq!( + results_just_main.read_out_wire(Wire::new(call.node(), 0)), + Some(pv_true().into()) + ); + // (Without reachability) the second output of the Call is the join of the two constant inputs from the two calls + assert_eq!( + results_just_main.read_out_wire(Wire::new(call.node(), 1)), + Some(pv_true_or_false().into()) + ); + } + + let results_two_calls = { + let mut m = Machine::default(); + m.prepopulate_df_inputs(&hugr, f2.node(), [(0.into(), pv_true())]); + m.run(TestContext(&hugr), [(0.into(), pv_false())]) + }; + + for call in [f2_call, main_call] { + assert_eq!( + results_two_calls.read_out_wire(Wire::new(call.node(), 0)), + Some(pv_true_or_false().into()) + ); + assert_eq!( + results_two_calls.read_out_wire(Wire::new(call.node(), 1)), + Some(pv_true_or_false().into()) + ); + } +}