Skip to content

Commit

Permalink
Store HugrView (not DFContext) in Machine
Browse files Browse the repository at this point in the history
  • Loading branch information
acl-cqc committed Dec 4, 2024
1 parent 650cdfd commit 071c7dd
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 45 deletions.
51 changes: 25 additions & 26 deletions hugr-passes/src/dataflow/datalog.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,27 +19,28 @@ use super::{
type PV<V> = PartialValue<V>;

/// Basic structure for performing an analysis. Usage:
/// 1. Get a new instance via [Self::default()]
/// 1. Make a new instance via [Self::new()]
/// 2. (Optionally) zero or more calls to [Self::prepopulate_wire] and/or
/// [Self::prepopulate_df_inputs] with initial values.
/// For example, to analyse a [Module](OpType::Module)-rooted Hugr as a library,
/// [Self::prepopulate_df_inputs] can be used on each externally-callable
/// [FuncDefn](OpType::FuncDefn) to set all inputs to [PartialValue::Top].
/// 3. Call [Self::run] to produce [AnalysisResults]
pub struct Machine<V: AbstractValue>(Vec<(Node, IncomingPort, PartialValue<V>)>);
pub struct Machine<H: HugrView, V: AbstractValue>(H, Vec<(Node, IncomingPort, PartialValue<V>)>);

/// derived-Default requires the context to be Defaultable, which is unnecessary
impl<V: AbstractValue> Default for Machine<V> {
fn default() -> Self {
Self(Default::default())
impl<H: HugrView, V: AbstractValue> Machine<H, V> {
/// Create a new Machine to analyse the given Hugr(View)
pub fn new(hugr: H) -> Self {
Self(hugr, Default::default())
}
}

impl<V: AbstractValue> Machine<V> {
impl<H: HugrView, V: AbstractValue> Machine<H, V> {
/// Provide initial values for a wire - these will be `join`d with any computed.
pub fn prepopulate_wire(&mut self, h: &impl HugrView, w: Wire, v: PartialValue<V>) {
self.0.extend(
h.linked_inputs(w.node(), w.source())
pub fn prepopulate_wire(&mut self, w: Wire, v: PartialValue<V>) {
self.1.extend(
self.0
.linked_inputs(w.node(), w.source())
.map(|(n, inp)| (n, inp, v.clone())),
);
}
Expand All @@ -49,18 +50,17 @@ impl<V: AbstractValue> Machine<V> {
/// Any out-ports of said same `Input` node, not given values by `in_values`, are set to [PartialValue::Top].
pub fn prepopulate_df_inputs(
&mut self,
h: &impl HugrView,
parent: Node,
in_values: impl IntoIterator<Item = (OutgoingPort, PartialValue<V>)>,
) {
// Put values onto out-wires of Input node
let [inp, _] = h.get_io(parent).unwrap();
let mut vals = vec![PartialValue::Top; h.signature(inp).unwrap().output_types().len()];
let [inp, _] = self.0.get_io(parent).unwrap();
let mut vals = vec![PartialValue::Top; self.0.signature(inp).unwrap().output_types().len()];
for (ip, v) in in_values {
vals[ip.index()] = v;
}
for (i, v) in vals.into_iter().enumerate() {
self.prepopulate_wire(h, Wire::new(inp, i), v);
self.prepopulate_wire(Wire::new(inp, i), v);
}
}

Expand All @@ -72,20 +72,20 @@ impl<V: AbstractValue> Machine<V> {
/// # Panics
/// May panic in various ways if the Hugr is invalid;
/// or if any `in_values` are provided for a module-rooted Hugr without a function `"main"`.
pub fn run<H: HugrView>(
pub fn run(
mut self,
context: &impl DFContext<V>,
hugr: H,
in_values: impl IntoIterator<Item = (IncomingPort, PartialValue<V>)>,
) -> AnalysisResults<V, H> {
let mut in_values = in_values.into_iter();
let root = hugr.root();
let root = self.0.root();
// Some nodes do not accept values as dataflow inputs - for these
// we must find the corresponding Input node.
let input_node_parent = match hugr.get_optype(root) {
let input_node_parent = match self.0.get_optype(root) {
OpType::Module(_) => {
let main = hugr.children(root).find(|n| {
hugr.get_optype(*n)
let main = self.0.children(root).find(|n| {
self.0
.get_optype(*n)
.as_func_defn()
.is_some_and(|f| f.name == "main")
});
Expand All @@ -103,27 +103,26 @@ impl<V: AbstractValue> Machine<V> {
// analysis must produce Top == we-know-nothing, not `V` !)
if let Some(p) = input_node_parent {
self.prepopulate_df_inputs(
&hugr,
p,
in_values.map(|(p, v)| (OutgoingPort::from(p.index()), v)),
);
} else {
// Put values onto in-wires of root node, datalog will do the rest
self.0.extend(in_values.map(|(p, v)| (root, p, v)));
self.1.extend(in_values.map(|(p, v)| (root, p, v)));
let got_inputs: HashSet<_, RandomState> = self
.0
.1
.iter()
.filter_map(|(n, p, _)| (n == &root).then_some(*p))
.collect();
for p in hugr.signature(root).unwrap_or_default().input_ports() {
for p in self.0.signature(root).unwrap_or_default().input_ports() {
if !got_inputs.contains(&p) {
self.0.push((root, p, PartialValue::Top));
self.1.push((root, p, PartialValue::Top));
}
}
}
// Note/TODO, if analysis is running on a subregion then we should do similar
// for any nonlocal edges providing values from outside the region.
run_datalog(context, hugr, self.0)
run_datalog(context, self.0, self.1)
}
}

Expand Down
35 changes: 16 additions & 19 deletions hugr-passes/src/dataflow/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ fn test_make_tuple() {
let v3 = builder.make_tuple([v1, v2]).unwrap();
let hugr = builder.finish_hugr(&PRELUDE_REGISTRY).unwrap();

let results = Machine::default().run(&TestContext, &hugr, []);
let results = Machine::new(&hugr).run(&TestContext, []);

let x: Value = results.try_read_wire_concrete(v3).unwrap();
assert_eq!(x, Value::tuple([Value::false_val(), Value::true_val()]));
Expand All @@ -76,7 +76,7 @@ fn test_unpack_tuple_const() {
.outputs_arr();
let hugr = builder.finish_hugr(&PRELUDE_REGISTRY).unwrap();

let results = Machine::default().run(&TestContext, &hugr, []);
let results = Machine::new(&hugr).run(&TestContext, []);

let o1_r: Value = results.try_read_wire_concrete(o1).unwrap();
assert_eq!(o1_r, Value::false_val());
Expand All @@ -102,7 +102,7 @@ fn test_tail_loop_never_iterates() {
let [tl_o] = tail_loop.outputs_arr();
let hugr = builder.finish_hugr(&EMPTY_REG).unwrap();

let results = Machine::default().run(&TestContext, &hugr, []);
let results = Machine::new(&hugr).run(&TestContext, []);

let o_r: Value = results.try_read_wire_concrete(tl_o).unwrap();
assert_eq!(o_r, r_v);
Expand Down Expand Up @@ -137,7 +137,7 @@ fn test_tail_loop_always_iterates() {
let [tl_o1, tl_o2] = tail_loop.outputs_arr();
let hugr = builder.finish_hugr(&EMPTY_REG).unwrap();

let results = Machine::default().run(&TestContext, &hugr, []);
let results = Machine::new(&hugr).run(&TestContext, []);

let o_r1 = results.read_out_wire(tl_o1).unwrap();
assert_eq!(o_r1, PartialValue::bottom());
Expand Down Expand Up @@ -175,7 +175,7 @@ fn test_tail_loop_two_iters() {
let hugr = builder.finish_hugr(&EMPTY_REG).unwrap();
let [o_w1, o_w2] = tail_loop.outputs_arr();

let results = Machine::default().run(&TestContext, &hugr, []);
let results = Machine::new(&hugr).run(&TestContext, []);

let o_r1 = results.read_out_wire(o_w1).unwrap();
assert_eq!(o_r1, pv_true_or_false());
Expand Down Expand Up @@ -238,7 +238,7 @@ fn test_tail_loop_containing_conditional() {
let hugr = builder.finish_hugr(&EMPTY_REG).unwrap();
let [o_w1, o_w2] = tail_loop.outputs_arr();

let results = Machine::default().run(&TestContext, &hugr, []);
let results = Machine::new(&hugr).run(&TestContext, []);

let o_r1 = results.read_out_wire(o_w1).unwrap();
assert_eq!(o_r1, pv_true());
Expand Down Expand Up @@ -290,7 +290,7 @@ fn test_conditional() {
2,
[PartialValue::new_variant(0, [])],
));
let results = Machine::default().run(&TestContext, &hugr, [(0.into(), arg_pv)]);
let results = Machine::new(&hugr).run(&TestContext, [(0.into(), arg_pv)]);

let cond_r1: Value = results.try_read_wire_concrete(cond_o1).unwrap();
assert_eq!(cond_r1, Value::false_val());
Expand Down Expand Up @@ -387,11 +387,8 @@ fn test_cfg(
xor_and_cfg: Hugr,
) {
let root = xor_and_cfg.root();
let results = Machine::default().run(
&TestContext,
&xor_and_cfg,
[(0.into(), inp0), (1.into(), inp1)],
);
let results =
Machine::new(&xor_and_cfg).run(&TestContext, [(0.into(), inp0), (1.into(), inp1)]);

assert_eq!(results.read_out_wire(Wire::new(root, 0)).unwrap(), out0);
assert_eq!(results.read_out_wire(Wire::new(root, 1)).unwrap(), out1);
Expand Down Expand Up @@ -425,7 +422,7 @@ fn test_call(
.finish_hugr_with_outputs([a2, b2], &EMPTY_REG)
.unwrap();

let results = Machine::default().run(&TestContext, &hugr, [(0.into(), inp0), (1.into(), inp1)]);
let results = Machine::new(&hugr).run(&TestContext, [(0.into(), inp0), (1.into(), inp1)]);

let [res0, res1] = [0, 1].map(|i| results.read_out_wire(Wire::new(hugr.root(), i)).unwrap());
// The two calls alias so both results will be the same:
Expand All @@ -447,7 +444,7 @@ fn test_region() {
.finish_prelude_hugr_with_outputs(nested.outputs())
.unwrap();
let [nested_input, _] = hugr.get_io(nested.node()).unwrap();
let whole_hugr_results = Machine::default().run(&TestContext, &hugr, [(0.into(), pv_true())]);
let whole_hugr_results = Machine::new(&hugr).run(&TestContext, [(0.into(), pv_true())]);
assert_eq!(
whole_hugr_results.read_out_wire(Wire::new(nested_input, 0)),
Some(pv_true())
Expand All @@ -467,7 +464,7 @@ fn test_region() {

let subview = DescendantsGraph::<DfgID>::try_new(&hugr, nested.node()).unwrap();
// Do not provide a value on the second input (constant false in the whole hugr, above)
let sub_hugr_results = Machine::default().run(&TestContext, subview, [(0.into(), pv_true())]);
let sub_hugr_results = Machine::new(subview).run(&TestContext, [(0.into(), pv_true())]);
assert_eq!(
sub_hugr_results.read_out_wire(Wire::new(nested_input, 0)),
Some(pv_true())
Expand Down Expand Up @@ -515,7 +512,7 @@ fn test_module() {
let hugr = modb.finish_hugr(&EMPTY_REG).unwrap();
let [f2_inp, _] = hugr.get_io(f2.node()).unwrap();

let results_just_main = Machine::default().run(&TestContext, &hugr, [(0.into(), pv_true())]);
let results_just_main = Machine::new(&hugr).run(&TestContext, [(0.into(), pv_true())]);
assert_eq!(
results_just_main.read_out_wire(Wire::new(f2_inp, 0)),
Some(PartialValue::Bottom)
Expand All @@ -534,9 +531,9 @@ fn test_module() {
}

let results_two_calls = {
let mut m = Machine::default();
m.prepopulate_df_inputs(&hugr, f2.node(), [(0.into(), pv_true())]);
m.run(&TestContext, &hugr, [(0.into(), pv_false())])
let mut m = Machine::new(&hugr);
m.prepopulate_df_inputs(f2.node(), [(0.into(), pv_true())]);
m.run(&TestContext, [(0.into(), pv_false())])
};

for call in [f2_call, main_call] {
Expand Down

0 comments on commit 071c7dd

Please sign in to comment.