Skip to content

Commit

Permalink
Datalog works on any AbstractValue; impl'd by PartialValue for a Base…
Browse files Browse the repository at this point in the history
…Value
  • Loading branch information
acl-cqc committed Oct 7, 2024
1 parent a139f9e commit 7f2a91a
Show file tree
Hide file tree
Showing 7 changed files with 114 additions and 104 deletions.
4 changes: 2 additions & 2 deletions hugr-passes/src/const_fold2/value_handle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use hugr_core::ops::Value;
use hugr_core::types::Type;
use hugr_core::Node;

use crate::dataflow::AbstractValue;
use crate::dataflow::BaseValue;

#[derive(Clone, Debug)]
pub struct HashedConst {
Expand Down Expand Up @@ -85,7 +85,7 @@ impl ValueHandle {
}
}

impl AbstractValue for ValueHandle {
impl BaseValue for ValueHandle {
fn as_sum(&self) -> Option<(usize, impl Iterator<Item = Self> + '_)> {
match self.value() {
Value::Sum(Sum { tag, values, .. }) => Some((
Expand Down
18 changes: 2 additions & 16 deletions hugr-passes/src/dataflow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,30 +2,16 @@
//! Dataflow analysis of Hugrs.
mod datalog;
pub use datalog::{AbstractValue, DFContext};

mod machine;
pub use machine::Machine;

mod partial_value;
pub use partial_value::{AbstractValue, PVEnum, PartialSum, PartialValue, Sum};
pub use partial_value::{BaseValue, PVEnum, PartialSum, PartialValue, Sum};

mod total_context;
pub use total_context::TotalContext;

use hugr_core::{Hugr, Node};
use std::hash::Hash;

/// Clients of the dataflow framework (particular analyses, such as constant folding)
/// must implement this trait (including providing an appropriate domain type `V`).
pub trait DFContext<V>: Clone + Eq + Hash + std::ops::Deref<Target = Hugr> {
/// Given lattice values for each input, produce lattice values for (what we know of)
/// the outputs. Returning `None` indicates nothing can be deduced.
fn interpret_leaf_op(
&self,
node: Node,
ins: &[PartialValue<V>],
) -> Option<Vec<PartialValue<V>>>;
}

#[cfg(test)]
mod test;
90 changes: 58 additions & 32 deletions hugr-passes/src/dataflow/datalog.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,32 +16,58 @@ use std::ops::{Index, IndexMut};
use hugr_core::extension::prelude::{MakeTuple, UnpackTuple};
use hugr_core::ops::OpType;
use hugr_core::types::Signature;
use hugr_core::{HugrView, IncomingPort, Node, OutgoingPort, PortIndex as _};

use super::partial_value::{AbstractValue, PartialValue};
use super::DFContext;

type PV<V> = super::partial_value::PartialValue<V>;
use hugr_core::{Hugr, HugrView, IncomingPort, Node, OutgoingPort, PortIndex as _};

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum IO {
Input,
Output,
}

/// Clients of the dataflow framework (particular analyses, such as constant folding)
/// must implement this trait (including providing an appropriate domain type `PV`).
pub trait DFContext<PV: AbstractValue>: Clone + Eq + Hash + std::ops::Deref<Target = Hugr> {
/// Given lattice values for each input, produce lattice values for (what we know of)
/// the outputs. Returning `None` indicates nothing can be deduced.
fn interpret_leaf_op(&self, node: Node, ins: &[PV]) -> Option<Vec<PV>>;
}

/// Values which can be the domain for dataflow analysis. Must be able to deconstructed
/// into (and constructed from) Sums as these determine control flow.
pub trait AbstractValue: BoundedLattice + Clone + Eq + Hash + std::fmt::Debug {
/// Create a new instance representing a Sum with a single known tag
/// and (recursive) representations of the elements within that tag.
fn new_variant(tag: usize, values: impl IntoIterator<Item = Self>) -> Self;

/// New instance of unit type (i.e. the only possible value, with no contents)
fn new_unit() -> Self {
Self::new_variant(0, [])
}

/// Test whether this value *might* be a Sum with the specified tag.
fn supports_tag(&self, tag: usize) -> bool;

/// If this value might be a Sum with the specified tag, return values
/// describing the elements of the Sum, otherwise `None`.
///
/// Implementations must hold the invariant that for all `x`, `tag` and `len`:
/// `x.variant_values(tag, len).is_some() == x.supports_tag(tag)`
fn variant_values(&self, tag: usize, len: usize) -> Option<Vec<Self>>;
}

ascent::ascent! {
pub(super) struct AscentProgram<V: AbstractValue, C: DFContext<V>>;
pub(super) struct AscentProgram<PV: AbstractValue, C: DFContext<PV>>;
relation context(C);
relation out_wire_value_proto(Node, OutgoingPort, PV<V>);
relation out_wire_value_proto(Node, OutgoingPort, PV);

relation node(C, Node);
relation in_wire(C, Node, IncomingPort);
relation out_wire(C, Node, OutgoingPort);
relation parent_of_node(C, Node, Node);
relation io_node(C, Node, Node, IO);
lattice out_wire_value(C, Node, OutgoingPort, PV<V>);
lattice node_in_value_row(C, Node, ValueRow<V>);
lattice in_wire_value(C, Node, IncomingPort, PV<V>);
lattice out_wire_value(C, Node, OutgoingPort, PV);
lattice node_in_value_row(C, Node, ValueRow<PV>);
lattice in_wire_value(C, Node, IncomingPort, PV);

node(c, n) <-- context(c), for n in c.nodes();

Expand Down Expand Up @@ -144,11 +170,11 @@ ascent::ascent! {

}

fn propagate_leaf_op<V: AbstractValue>(
c: &impl DFContext<V>,
fn propagate_leaf_op<PV: AbstractValue>(
c: &impl DFContext<PV>,
n: Node,
ins: &[PV<V>],
) -> Option<ValueRow<V>> {
ins: &[PV],
) -> Option<ValueRow<PV>> {
match c.get_optype(n) {
// Handle basics here. I guess (given the current interface) we could allow
// DFContext to handle these but at the least we'd want these impls to be
Expand Down Expand Up @@ -192,29 +218,29 @@ fn value_outputs(h: &impl HugrView, n: Node) -> impl Iterator<Item = OutgoingPor
// Wrap a (known-length) row of values into a lattice. Perhaps could be part of partial_value.rs?

#[derive(PartialEq, Clone, Eq, Hash)]
struct ValueRow<V>(Vec<PartialValue<V>>);
struct ValueRow<PV>(Vec<PV>);

impl<V: AbstractValue> ValueRow<V> {
impl<PV: AbstractValue> ValueRow<PV> {
pub fn new(len: usize) -> Self {
Self(vec![PartialValue::bottom(); len])
Self(vec![PV::bottom(); len])
}

pub fn single_known(len: usize, idx: usize, v: PartialValue<V>) -> Self {
pub fn single_known(len: usize, idx: usize, v: PV) -> Self {
assert!(idx < len);
let mut r = Self::new(len);
r.0[idx] = v;
r
}

pub fn iter(&self) -> impl Iterator<Item = &PartialValue<V>> {
pub fn iter(&self) -> impl Iterator<Item = &PV> {
self.0.iter()
}

pub fn unpack_first(
&self,
variant: usize,
len: usize,
) -> Option<impl Iterator<Item = PartialValue<V>> + '_> {
) -> Option<impl Iterator<Item = PV> + '_> {
self[0]
.variant_values(variant, len)
.map(|vals| vals.into_iter().chain(self.iter().skip(1).cloned()))
Expand All @@ -225,13 +251,13 @@ impl<V: AbstractValue> ValueRow<V> {
// }
}

impl<V> FromIterator<PartialValue<V>> for ValueRow<V> {
fn from_iter<T: IntoIterator<Item = PartialValue<V>>>(iter: T) -> Self {
impl<PV> FromIterator<PV> for ValueRow<PV> {
fn from_iter<T: IntoIterator<Item = PV>>(iter: T) -> Self {
Self(iter.into_iter().collect())
}
}

impl<V: PartialEq> PartialOrd for ValueRow<V> {
impl<V: PartialEq + PartialOrd> PartialOrd for ValueRow<V> {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
self.0.partial_cmp(&other.0)
}
Expand Down Expand Up @@ -267,30 +293,30 @@ impl<V: AbstractValue> Lattice for ValueRow<V> {
}
}

impl<V> IntoIterator for ValueRow<V> {
type Item = PartialValue<V>;
impl<PV> IntoIterator for ValueRow<PV> {
type Item = PV;

type IntoIter = <Vec<PartialValue<V>> as IntoIterator>::IntoIter;
type IntoIter = <Vec<PV> as IntoIterator>::IntoIter;

fn into_iter(self) -> Self::IntoIter {
self.0.into_iter()
}
}

impl<V, Idx> Index<Idx> for ValueRow<V>
impl<PV, Idx> Index<Idx> for ValueRow<PV>
where
Vec<PartialValue<V>>: Index<Idx>,
Vec<PV>: Index<Idx>,
{
type Output = <Vec<PartialValue<V>> as Index<Idx>>::Output;
type Output = <Vec<PV> as Index<Idx>>::Output;

fn index(&self, index: Idx) -> &Self::Output {
self.0.index(index)
}
}

impl<V, Idx> IndexMut<Idx> for ValueRow<V>
impl<PV, Idx> IndexMut<Idx> for ValueRow<PV>
where
Vec<PartialValue<V>>: IndexMut<Idx>,
Vec<PV>: IndexMut<Idx>,
{
fn index_mut(&mut self, index: Idx) -> &mut Self::Output {
self.0.index_mut(index)
Expand Down
19 changes: 8 additions & 11 deletions hugr-passes/src/dataflow/machine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,16 @@ use std::collections::HashMap;

use hugr_core::{HugrView, Node, PortIndex, Wire};

use super::{datalog::AscentProgram, AbstractValue, DFContext, PartialValue};
use super::{datalog::AscentProgram, AbstractValue, DFContext};

/// Basic structure for performing an analysis. Usage:
/// 1. Get a new instance via [Self::default()]
/// 2. Zero or more [Self::propolutate_out_wires] with initial values
/// 3. Exactly one [Self::run] to do the analysis
/// 4. Results then available via [Self::read_out_wire]
pub struct Machine<V: AbstractValue, C: DFContext<V>>(
AscentProgram<V, C>,
Option<HashMap<Wire, PartialValue<V>>>,
pub struct Machine<PV: AbstractValue, C: DFContext<PV>>(
AscentProgram<PV, C>,
Option<HashMap<Wire, PV>>,
);

/// derived-Default requires the context to be Defaultable, which is unnecessary
Expand All @@ -21,13 +21,10 @@ impl<V: AbstractValue, C: DFContext<V>> Default for Machine<V, C> {
}
}

impl<V: AbstractValue, C: DFContext<V>> Machine<V, C> {
impl<PV: AbstractValue, C: DFContext<PV>> Machine<PV, C> {
/// Provide initial values for some wires.
/// (For example, if some properties of the Hugr's inputs are known.)
pub fn propolutate_out_wires(
&mut self,
wires: impl IntoIterator<Item = (Wire, PartialValue<V>)>,
) {
pub fn propolutate_out_wires(&mut self, wires: impl IntoIterator<Item = (Wire, PV)>) {
assert!(self.1.is_none());
self.0
.out_wire_value_proto
Expand Down Expand Up @@ -55,7 +52,7 @@ impl<V: AbstractValue, C: DFContext<V>> Machine<V, C> {
}

/// Gets the lattice value computed by [Self::run] for the given wire
pub fn read_out_wire(&self, w: Wire) -> Option<PartialValue<V>> {
pub fn read_out_wire(&self, w: Wire) -> Option<PV> {
self.1.as_ref().unwrap().get(&w).cloned()
}

Expand Down Expand Up @@ -109,7 +106,7 @@ pub enum TailLoopTermination {
}

impl TailLoopTermination {
pub fn from_control_value<V: AbstractValue>(v: &PartialValue<V>) -> Self {
pub fn from_control_value(v: &impl AbstractValue) -> Self {
let (may_continue, may_break) = (v.supports_tag(0), v.supports_tag(1));
if may_break && !may_continue {
Self::ExactlyZeroContinues
Expand Down
Loading

0 comments on commit 7f2a91a

Please sign in to comment.