Skip to content

Commit

Permalink
Machine::run owns DFContext. Or &mut ?? (TODO update interpret_leaf_o…
Browse files Browse the repository at this point in the history
…p to &mut self)
  • Loading branch information
acl-cqc committed Dec 4, 2024
1 parent 071c7dd commit c5bd7b0
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 16 deletions.
6 changes: 3 additions & 3 deletions hugr-passes/src/dataflow/datalog.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ impl<H: HugrView, V: AbstractValue> Machine<H, V> {
/// or if any `in_values` are provided for a module-rooted Hugr without a function `"main"`.
pub fn run(
mut self,
context: &impl DFContext<V>,
context: impl DFContext<V>,
in_values: impl IntoIterator<Item = (IncomingPort, PartialValue<V>)>,
) -> AnalysisResults<V, H> {
let mut in_values = in_values.into_iter();
Expand Down Expand Up @@ -127,7 +127,7 @@ impl<H: HugrView, V: AbstractValue> Machine<H, V> {
}

pub(super) fn run_datalog<V: AbstractValue, H: HugrView>(
ctx: &impl DFContext<V>,
ctx: impl DFContext<V>,
hugr: H,
in_wire_value_proto: Vec<(Node, IncomingPort, PV<V>)>,
) -> AnalysisResults<V, H> {
Expand Down Expand Up @@ -187,7 +187,7 @@ pub(super) fn run_datalog<V: AbstractValue, H: HugrView>(
if !op_t.is_container(),
if let Some(sig) = op_t.dataflow_signature(),
node_in_value_row(n, vs),
if let Some(outs) = propagate_leaf_op(ctx, &hugr, *n, &vs[..], sig.output_count()),
if let Some(outs) = propagate_leaf_op(&ctx, &hugr, *n, &vs[..], sig.output_count()),
for (p, v) in (0..).map(OutgoingPort::from).zip(outs);

// DFG --------------------
Expand Down
26 changes: 13 additions & 13 deletions hugr-passes/src/dataflow/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ fn test_make_tuple() {
let v3 = builder.make_tuple([v1, v2]).unwrap();
let hugr = builder.finish_hugr(&PRELUDE_REGISTRY).unwrap();

let results = Machine::new(&hugr).run(&TestContext, []);
let results = Machine::new(&hugr).run(TestContext, []);

let x: Value = results.try_read_wire_concrete(v3).unwrap();
assert_eq!(x, Value::tuple([Value::false_val(), Value::true_val()]));
Expand All @@ -76,7 +76,7 @@ fn test_unpack_tuple_const() {
.outputs_arr();
let hugr = builder.finish_hugr(&PRELUDE_REGISTRY).unwrap();

let results = Machine::new(&hugr).run(&TestContext, []);
let results = Machine::new(&hugr).run(TestContext, []);

let o1_r: Value = results.try_read_wire_concrete(o1).unwrap();
assert_eq!(o1_r, Value::false_val());
Expand All @@ -102,7 +102,7 @@ fn test_tail_loop_never_iterates() {
let [tl_o] = tail_loop.outputs_arr();
let hugr = builder.finish_hugr(&EMPTY_REG).unwrap();

let results = Machine::new(&hugr).run(&TestContext, []);
let results = Machine::new(&hugr).run(TestContext, []);

let o_r: Value = results.try_read_wire_concrete(tl_o).unwrap();
assert_eq!(o_r, r_v);
Expand Down Expand Up @@ -137,7 +137,7 @@ fn test_tail_loop_always_iterates() {
let [tl_o1, tl_o2] = tail_loop.outputs_arr();
let hugr = builder.finish_hugr(&EMPTY_REG).unwrap();

let results = Machine::new(&hugr).run(&TestContext, []);
let results = Machine::new(&hugr).run(TestContext, []);

let o_r1 = results.read_out_wire(tl_o1).unwrap();
assert_eq!(o_r1, PartialValue::bottom());
Expand Down Expand Up @@ -175,7 +175,7 @@ fn test_tail_loop_two_iters() {
let hugr = builder.finish_hugr(&EMPTY_REG).unwrap();
let [o_w1, o_w2] = tail_loop.outputs_arr();

let results = Machine::new(&hugr).run(&TestContext, []);
let results = Machine::new(&hugr).run(TestContext, []);

let o_r1 = results.read_out_wire(o_w1).unwrap();
assert_eq!(o_r1, pv_true_or_false());
Expand Down Expand Up @@ -238,7 +238,7 @@ fn test_tail_loop_containing_conditional() {
let hugr = builder.finish_hugr(&EMPTY_REG).unwrap();
let [o_w1, o_w2] = tail_loop.outputs_arr();

let results = Machine::new(&hugr).run(&TestContext, []);
let results = Machine::new(&hugr).run(TestContext, []);

let o_r1 = results.read_out_wire(o_w1).unwrap();
assert_eq!(o_r1, pv_true());
Expand Down Expand Up @@ -290,7 +290,7 @@ fn test_conditional() {
2,
[PartialValue::new_variant(0, [])],
));
let results = Machine::new(&hugr).run(&TestContext, [(0.into(), arg_pv)]);
let results = Machine::new(&hugr).run(TestContext, [(0.into(), arg_pv)]);

let cond_r1: Value = results.try_read_wire_concrete(cond_o1).unwrap();
assert_eq!(cond_r1, Value::false_val());
Expand Down Expand Up @@ -388,7 +388,7 @@ fn test_cfg(
) {
let root = xor_and_cfg.root();
let results =
Machine::new(&xor_and_cfg).run(&TestContext, [(0.into(), inp0), (1.into(), inp1)]);
Machine::new(&xor_and_cfg).run(TestContext, [(0.into(), inp0), (1.into(), inp1)]);

assert_eq!(results.read_out_wire(Wire::new(root, 0)).unwrap(), out0);
assert_eq!(results.read_out_wire(Wire::new(root, 1)).unwrap(), out1);
Expand Down Expand Up @@ -422,7 +422,7 @@ fn test_call(
.finish_hugr_with_outputs([a2, b2], &EMPTY_REG)
.unwrap();

let results = Machine::new(&hugr).run(&TestContext, [(0.into(), inp0), (1.into(), inp1)]);
let results = Machine::new(&hugr).run(TestContext, [(0.into(), inp0), (1.into(), inp1)]);

let [res0, res1] = [0, 1].map(|i| results.read_out_wire(Wire::new(hugr.root(), i)).unwrap());
// The two calls alias so both results will be the same:
Expand All @@ -444,7 +444,7 @@ fn test_region() {
.finish_prelude_hugr_with_outputs(nested.outputs())
.unwrap();
let [nested_input, _] = hugr.get_io(nested.node()).unwrap();
let whole_hugr_results = Machine::new(&hugr).run(&TestContext, [(0.into(), pv_true())]);
let whole_hugr_results = Machine::new(&hugr).run(TestContext, [(0.into(), pv_true())]);
assert_eq!(
whole_hugr_results.read_out_wire(Wire::new(nested_input, 0)),
Some(pv_true())
Expand All @@ -464,7 +464,7 @@ fn test_region() {

let subview = DescendantsGraph::<DfgID>::try_new(&hugr, nested.node()).unwrap();
// Do not provide a value on the second input (constant false in the whole hugr, above)
let sub_hugr_results = Machine::new(subview).run(&TestContext, [(0.into(), pv_true())]);
let sub_hugr_results = Machine::new(subview).run(TestContext, [(0.into(), pv_true())]);
assert_eq!(
sub_hugr_results.read_out_wire(Wire::new(nested_input, 0)),
Some(pv_true())
Expand Down Expand Up @@ -512,7 +512,7 @@ fn test_module() {
let hugr = modb.finish_hugr(&EMPTY_REG).unwrap();
let [f2_inp, _] = hugr.get_io(f2.node()).unwrap();

let results_just_main = Machine::new(&hugr).run(&TestContext, [(0.into(), pv_true())]);
let results_just_main = Machine::new(&hugr).run(TestContext, [(0.into(), pv_true())]);
assert_eq!(
results_just_main.read_out_wire(Wire::new(f2_inp, 0)),
Some(PartialValue::Bottom)
Expand All @@ -533,7 +533,7 @@ fn test_module() {
let results_two_calls = {
let mut m = Machine::new(&hugr);
m.prepopulate_df_inputs(f2.node(), [(0.into(), pv_true())]);
m.run(&TestContext, [(0.into(), pv_false())])
m.run(TestContext, [(0.into(), pv_false())])
};

for call in [f2_call, main_call] {
Expand Down

0 comments on commit c5bd7b0

Please sign in to comment.