Skip to content

Commit

Permalink
Fix/make-compile total_context.rs
Browse files Browse the repository at this point in the history
  • Loading branch information
acl-cqc committed Oct 21, 2024
1 parent 0b71236 commit 502d4a2
Showing 1 changed file with 28 additions and 21 deletions.
49 changes: 28 additions & 21 deletions hugr-passes/src/dataflow/total_context.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,24 @@
use std::hash::Hash;

use hugr_core::ops::ExtensionOp;
use hugr_core::{ops::OpTrait, Hugr, HugrView, IncomingPort, Node, OutgoingPort, PortIndex};

use super::partial_value::{AbstractValue, PartialValue, ValueOrSum};
use super::partial_value::{AbstractValue, PartialValue, Sum};
use super::DFContext;

/// A simpler interface like [DFContext] but where the context only cares about
/// values that are completely known (in the lattice `V`)
/// rather than e.g. Sums potentially of two variants each of known values.
/// values that are completely known (in the lattice `V`) rather than partially
/// (e.g. no [PartialSum]s of more than one variant, no top/bottom)
pub trait TotalContext<V>: Clone + Eq + Hash + std::ops::Deref<Target = Hugr> {
type InterpretableVal: TryFrom<ValueOrSum<V>>;
/// Representation of a (single, non-partial) value usable for interpretation
type InterpretableVal: From<V> + TryFrom<Sum<Self::InterpretableVal>>;

/// Interpret an (extension) operation given total values for some of the in-ports
/// `ins` will be a non-empty slice with distinct [IncomingPort]s.
fn interpret_leaf_op(
&self,
node: Node,
e: &ExtensionOp,
ins: &[(IncomingPort, Self::InterpretableVal)],
) -> Vec<(OutgoingPort, V)>;
}
Expand All @@ -21,32 +27,33 @@ impl<V: AbstractValue, T: TotalContext<V>> DFContext<V> for T {
fn interpret_leaf_op(
&self,
node: Node,
e: &ExtensionOp,
ins: &[PartialValue<V>],
) -> Option<Vec<PartialValue<V>>> {
outs: &mut [PartialValue<V>],
) {
let op = self.get_optype(node);
let sig = op.dataflow_signature()?;
let Some(sig) = op.dataflow_signature() else {
return;
};
let known_ins = sig
.input_types()
.iter()
.enumerate()
.zip(ins.iter())
.filter_map(|((i, ty), pv)| {
pv.clone()
.try_into_value(ty)
// Discard PVs which don't produce ValueOrSum, i.e. Bottom/Top :-)
.ok()
// And discard any ValueOrSum that don't produce V - this is a bit silent :-(
.and_then(|v_s| T::InterpretableVal::try_from(v_s).ok())
.map(|v| (IncomingPort::from(i), v))
let v = match pv {
PartialValue::Bottom | PartialValue::Top => None,
PartialValue::Value(v) => Some(v.clone().into()),
PartialValue::PartialSum(ps) => T::InterpretableVal::try_from(
ps.clone().try_into_value::<T::InterpretableVal>(ty).ok()?,
)
.ok(),
}?;
Some((IncomingPort::from(i), v))
})
.collect::<Vec<_>>();
let known_outs = self.interpret_leaf_op(node, &known_ins);
(!known_outs.is_empty()).then(|| {
let mut res = vec![PartialValue::Bottom; sig.output_count()];
for (p, v) in known_outs {
res[p.index()] = v.into();
}
res
})
for (p, v) in self.interpret_leaf_op(node, e, &known_ins) {
outs[p.index()] = PartialValue::Value(v);
}
}
}

0 comments on commit 502d4a2

Please sign in to comment.