Skip to content

Commit

Permalink
Use row_contains_bottom for CFG+DFG, and augment unpack_first(=>_no_b…
Browse files Browse the repository at this point in the history
…ottom)
  • Loading branch information
acl-cqc committed Nov 19, 2024
1 parent 497686a commit e34c7be
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 12 deletions.
17 changes: 11 additions & 6 deletions hugr-passes/src/dataflow/datalog.rs
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,10 @@ pub(super) fn run_datalog<V: AbstractValue, C: DFContext<V>>(
dfg_node(n) <-- node(n), if ctx.get_optype(*n).is_dfg();

out_wire_value(i, OutgoingPort::from(p.index()), v) <-- dfg_node(dfg),
input_child(dfg, i), in_wire_value(dfg, p, v);
input_child(dfg, i),
node_in_value_row(dfg, row),
if !row_contains_bottom(&row[..]), // Treat the DFG as a scheduling barrier
for (p, v) in row[..].iter().enumerate();

out_wire_value(dfg, OutgoingPort::from(p.index()), v) <-- dfg_node(dfg),
output_child(dfg, o), in_wire_value(o, p, v);
Expand All @@ -213,7 +216,7 @@ pub(super) fn run_datalog<V: AbstractValue, C: DFContext<V>>(
output_child(tl, out_n),
node_in_value_row(out_n, out_in_row), // get the whole input row for the output node...
// ...and select just what's possible for CONTINUE_TAG, if anything
if let Some(fields) = out_in_row.unpack_first(TailLoop::CONTINUE_TAG, tailloop.just_inputs.len()),
if let Some(fields) = out_in_row.unpack_first_no_bottom(TailLoop::CONTINUE_TAG, tailloop.just_inputs.len()),
for (out_p, v) in fields.enumerate();

// Output node of child region propagate to outputs of tail loop
Expand All @@ -222,7 +225,7 @@ pub(super) fn run_datalog<V: AbstractValue, C: DFContext<V>>(
output_child(tl, out_n),
node_in_value_row(out_n, out_in_row), // get the whole input row for the output node...
// ... and select just what's possible for BREAK_TAG, if anything
if let Some(fields) = out_in_row.unpack_first(TailLoop::BREAK_TAG, tailloop.just_outputs.len()),
if let Some(fields) = out_in_row.unpack_first_no_bottom(TailLoop::BREAK_TAG, tailloop.just_outputs.len()),
for (out_p, v) in fields.enumerate();

// Conditional --------------------
Expand All @@ -239,7 +242,7 @@ pub(super) fn run_datalog<V: AbstractValue, C: DFContext<V>>(
input_child(case, i_node),
node_in_value_row(cond, in_row),
let conditional = ctx.get_optype(*cond).as_conditional().unwrap(),
if let Some(fields) = in_row.unpack_first(*case_index, conditional.sum_rows[*case_index].len()),
if let Some(fields) = in_row.unpack_first_no_bottom(*case_index, conditional.sum_rows[*case_index].len()),
for (out_p, v) in fields.enumerate();

// outputs of case nodes propagate to outputs of conditional *if* case reachable
Expand Down Expand Up @@ -274,7 +277,9 @@ pub(super) fn run_datalog<V: AbstractValue, C: DFContext<V>>(
cfg_node(cfg),
if let Some(entry) = ctx.children(*cfg).next(),
input_child(entry, i_node),
in_wire_value(cfg, p, v);
node_in_value_row(cfg, row),
if !row_contains_bottom(&row[..]),
for (p, v) in row[..].iter().enumerate();

// In `CFG` <Node>, values fed along a control-flow edge to <Node>
// come out of Value outports of <Node>:
Expand All @@ -293,7 +298,7 @@ pub(super) fn run_datalog<V: AbstractValue, C: DFContext<V>>(
output_child(pred, out_n),
_cfg_succ_dest(cfg, succ, dest),
node_in_value_row(out_n, out_in_row),
if let Some(fields) = out_in_row.unpack_first(succ_n, df_block.sum_rows.get(succ_n).unwrap().len()),
if let Some(fields) = out_in_row.unpack_first_no_bottom(succ_n, df_block.sum_rows.get(succ_n).unwrap().len()),
for (out_p, v) in fields.enumerate();

// Call --------------------
Expand Down
15 changes: 9 additions & 6 deletions hugr-passes/src/dataflow/value_row.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use std::{
use ascent::{lattice::BoundedLattice, Lattice};
use itertools::zip_eq;

use super::{AbstractValue, PartialValue};
use super::{row_contains_bottom, AbstractValue, PartialValue};

#[derive(PartialEq, Clone, Debug, Eq, Hash)]
pub(super) struct ValueRow<V>(Vec<PartialValue<V>>);
Expand All @@ -25,16 +25,19 @@ impl<V: AbstractValue> ValueRow<V> {
r
}

/// The first value in this ValueRow must be a sum;
/// returns a new ValueRow given by unpacking the elements of the specified variant of said first value,
/// then appending the rest of the values in this row.
pub fn unpack_first(
/// If the first value in this ValueRow is a sum, that might contain
/// the specified tag, then unpack the elements of that tag, append the rest
/// of this ValueRow, and if none of the elements of that row [contain bottom](PartialValue::contains_bottom),
/// return it.
/// Otherwise (if no such tag, or values contain bottom), return None.
pub fn unpack_first_no_bottom(
&self,
variant: usize,
len: usize,
) -> Option<impl Iterator<Item = PartialValue<V>>> {
let vals = self[0].variant_values(variant, len)?;
Some(vals.into_iter().chain(self.0[1..].to_owned()))
(!row_contains_bottom(vals.iter().chain(self.0[1..].iter())))
.then(|| vals.into_iter().chain(self.0[1..].to_owned()))
}
}

Expand Down

0 comments on commit e34c7be

Please sign in to comment.