diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index 74e4a0b9e..74dfc68ac 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -19,27 +19,28 @@ use super::{ type PV = PartialValue; /// Basic structure for performing an analysis. Usage: -/// 1. Get a new instance via [Self::default()] +/// 1. Make a new instance via [Self::new()] /// 2. (Optionally) zero or more calls to [Self::prepopulate_wire] and/or /// [Self::prepopulate_df_inputs] with initial values. /// For example, to analyse a [Module](OpType::Module)-rooted Hugr as a library, /// [Self::prepopulate_df_inputs] can be used on each externally-callable /// [FuncDefn](OpType::FuncDefn) to set all inputs to [PartialValue::Top]. /// 3. Call [Self::run] to produce [AnalysisResults] -pub struct Machine(Vec<(Node, IncomingPort, PartialValue)>); +pub struct Machine(H, Vec<(Node, IncomingPort, PartialValue)>); -/// derived-Default requires the context to be Defaultable, which is unnecessary -impl Default for Machine { - fn default() -> Self { - Self(Default::default()) +impl Machine { + /// Create a new Machine to analyse the given Hugr(View) + pub fn new(hugr: H) -> Self { + Self(hugr, Default::default()) } } -impl Machine { +impl Machine { /// Provide initial values for a wire - these will be `join`d with any computed. - pub fn prepopulate_wire(&mut self, h: &impl HugrView, w: Wire, v: PartialValue) { - self.0.extend( - h.linked_inputs(w.node(), w.source()) + pub fn prepopulate_wire(&mut self, w: Wire, v: PartialValue) { + self.1.extend( + self.0 + .linked_inputs(w.node(), w.source()) .map(|(n, inp)| (n, inp, v.clone())), ); } @@ -49,18 +50,17 @@ impl Machine { /// Any out-ports of said same `Input` node, not given values by `in_values`, are set to [PartialValue::Top]. pub fn prepopulate_df_inputs( &mut self, - h: &impl HugrView, parent: Node, in_values: impl IntoIterator)>, ) { // Put values onto out-wires of Input node - let [inp, _] = h.get_io(parent).unwrap(); - let mut vals = vec![PartialValue::Top; h.signature(inp).unwrap().output_types().len()]; + let [inp, _] = self.0.get_io(parent).unwrap(); + let mut vals = vec![PartialValue::Top; self.0.signature(inp).unwrap().output_types().len()]; for (ip, v) in in_values { vals[ip.index()] = v; } for (i, v) in vals.into_iter().enumerate() { - self.prepopulate_wire(h, Wire::new(inp, i), v); + self.prepopulate_wire(Wire::new(inp, i), v); } } @@ -72,20 +72,20 @@ impl Machine { /// # Panics /// May panic in various ways if the Hugr is invalid; /// or if any `in_values` are provided for a module-rooted Hugr without a function `"main"`. - pub fn run( + pub fn run( mut self, context: &impl DFContext, - hugr: H, in_values: impl IntoIterator)>, ) -> AnalysisResults { let mut in_values = in_values.into_iter(); - let root = hugr.root(); + let root = self.0.root(); // Some nodes do not accept values as dataflow inputs - for these // we must find the corresponding Input node. - let input_node_parent = match hugr.get_optype(root) { + let input_node_parent = match self.0.get_optype(root) { OpType::Module(_) => { - let main = hugr.children(root).find(|n| { - hugr.get_optype(*n) + let main = self.0.children(root).find(|n| { + self.0 + .get_optype(*n) .as_func_defn() .is_some_and(|f| f.name == "main") }); @@ -103,27 +103,26 @@ impl Machine { // analysis must produce Top == we-know-nothing, not `V` !) if let Some(p) = input_node_parent { self.prepopulate_df_inputs( - &hugr, p, in_values.map(|(p, v)| (OutgoingPort::from(p.index()), v)), ); } else { // Put values onto in-wires of root node, datalog will do the rest - self.0.extend(in_values.map(|(p, v)| (root, p, v))); + self.1.extend(in_values.map(|(p, v)| (root, p, v))); let got_inputs: HashSet<_, RandomState> = self - .0 + .1 .iter() .filter_map(|(n, p, _)| (n == &root).then_some(*p)) .collect(); - for p in hugr.signature(root).unwrap_or_default().input_ports() { + for p in self.0.signature(root).unwrap_or_default().input_ports() { if !got_inputs.contains(&p) { - self.0.push((root, p, PartialValue::Top)); + self.1.push((root, p, PartialValue::Top)); } } } // Note/TODO, if analysis is running on a subregion then we should do similar // for any nonlocal edges providing values from outside the region. - run_datalog(context, hugr, self.0) + run_datalog(context, self.0, self.1) } } diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index 7889b364b..8f9d18bf9 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -60,7 +60,7 @@ fn test_make_tuple() { let v3 = builder.make_tuple([v1, v2]).unwrap(); let hugr = builder.finish_hugr(&PRELUDE_REGISTRY).unwrap(); - let results = Machine::default().run(&TestContext, &hugr, []); + let results = Machine::new(&hugr).run(&TestContext, []); let x: Value = results.try_read_wire_concrete(v3).unwrap(); assert_eq!(x, Value::tuple([Value::false_val(), Value::true_val()])); @@ -76,7 +76,7 @@ fn test_unpack_tuple_const() { .outputs_arr(); let hugr = builder.finish_hugr(&PRELUDE_REGISTRY).unwrap(); - let results = Machine::default().run(&TestContext, &hugr, []); + let results = Machine::new(&hugr).run(&TestContext, []); let o1_r: Value = results.try_read_wire_concrete(o1).unwrap(); assert_eq!(o1_r, Value::false_val()); @@ -102,7 +102,7 @@ fn test_tail_loop_never_iterates() { let [tl_o] = tail_loop.outputs_arr(); let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); - let results = Machine::default().run(&TestContext, &hugr, []); + let results = Machine::new(&hugr).run(&TestContext, []); let o_r: Value = results.try_read_wire_concrete(tl_o).unwrap(); assert_eq!(o_r, r_v); @@ -137,7 +137,7 @@ fn test_tail_loop_always_iterates() { let [tl_o1, tl_o2] = tail_loop.outputs_arr(); let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); - let results = Machine::default().run(&TestContext, &hugr, []); + let results = Machine::new(&hugr).run(&TestContext, []); let o_r1 = results.read_out_wire(tl_o1).unwrap(); assert_eq!(o_r1, PartialValue::bottom()); @@ -175,7 +175,7 @@ fn test_tail_loop_two_iters() { let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); let [o_w1, o_w2] = tail_loop.outputs_arr(); - let results = Machine::default().run(&TestContext, &hugr, []); + let results = Machine::new(&hugr).run(&TestContext, []); let o_r1 = results.read_out_wire(o_w1).unwrap(); assert_eq!(o_r1, pv_true_or_false()); @@ -238,7 +238,7 @@ fn test_tail_loop_containing_conditional() { let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); let [o_w1, o_w2] = tail_loop.outputs_arr(); - let results = Machine::default().run(&TestContext, &hugr, []); + let results = Machine::new(&hugr).run(&TestContext, []); let o_r1 = results.read_out_wire(o_w1).unwrap(); assert_eq!(o_r1, pv_true()); @@ -290,7 +290,7 @@ fn test_conditional() { 2, [PartialValue::new_variant(0, [])], )); - let results = Machine::default().run(&TestContext, &hugr, [(0.into(), arg_pv)]); + let results = Machine::new(&hugr).run(&TestContext, [(0.into(), arg_pv)]); let cond_r1: Value = results.try_read_wire_concrete(cond_o1).unwrap(); assert_eq!(cond_r1, Value::false_val()); @@ -387,11 +387,8 @@ fn test_cfg( xor_and_cfg: Hugr, ) { let root = xor_and_cfg.root(); - let results = Machine::default().run( - &TestContext, - &xor_and_cfg, - [(0.into(), inp0), (1.into(), inp1)], - ); + let results = + Machine::new(&xor_and_cfg).run(&TestContext, [(0.into(), inp0), (1.into(), inp1)]); assert_eq!(results.read_out_wire(Wire::new(root, 0)).unwrap(), out0); assert_eq!(results.read_out_wire(Wire::new(root, 1)).unwrap(), out1); @@ -425,7 +422,7 @@ fn test_call( .finish_hugr_with_outputs([a2, b2], &EMPTY_REG) .unwrap(); - let results = Machine::default().run(&TestContext, &hugr, [(0.into(), inp0), (1.into(), inp1)]); + let results = Machine::new(&hugr).run(&TestContext, [(0.into(), inp0), (1.into(), inp1)]); let [res0, res1] = [0, 1].map(|i| results.read_out_wire(Wire::new(hugr.root(), i)).unwrap()); // The two calls alias so both results will be the same: @@ -447,7 +444,7 @@ fn test_region() { .finish_prelude_hugr_with_outputs(nested.outputs()) .unwrap(); let [nested_input, _] = hugr.get_io(nested.node()).unwrap(); - let whole_hugr_results = Machine::default().run(&TestContext, &hugr, [(0.into(), pv_true())]); + let whole_hugr_results = Machine::new(&hugr).run(&TestContext, [(0.into(), pv_true())]); assert_eq!( whole_hugr_results.read_out_wire(Wire::new(nested_input, 0)), Some(pv_true()) @@ -467,7 +464,7 @@ fn test_region() { let subview = DescendantsGraph::::try_new(&hugr, nested.node()).unwrap(); // Do not provide a value on the second input (constant false in the whole hugr, above) - let sub_hugr_results = Machine::default().run(&TestContext, subview, [(0.into(), pv_true())]); + let sub_hugr_results = Machine::new(subview).run(&TestContext, [(0.into(), pv_true())]); assert_eq!( sub_hugr_results.read_out_wire(Wire::new(nested_input, 0)), Some(pv_true()) @@ -515,7 +512,7 @@ fn test_module() { let hugr = modb.finish_hugr(&EMPTY_REG).unwrap(); let [f2_inp, _] = hugr.get_io(f2.node()).unwrap(); - let results_just_main = Machine::default().run(&TestContext, &hugr, [(0.into(), pv_true())]); + let results_just_main = Machine::new(&hugr).run(&TestContext, [(0.into(), pv_true())]); assert_eq!( results_just_main.read_out_wire(Wire::new(f2_inp, 0)), Some(PartialValue::Bottom) @@ -534,9 +531,9 @@ fn test_module() { } let results_two_calls = { - let mut m = Machine::default(); - m.prepopulate_df_inputs(&hugr, f2.node(), [(0.into(), pv_true())]); - m.run(&TestContext, &hugr, [(0.into(), pv_false())]) + let mut m = Machine::new(&hugr); + m.prepopulate_df_inputs(f2.node(), [(0.into(), pv_true())]); + m.run(&TestContext, [(0.into(), pv_false())]) }; for call in [f2_call, main_call] {