Skip to content

Commit

Permalink
Address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
lmondada committed Jul 28, 2023
1 parent bdd6e00 commit ad0f464
Showing 1 changed file with 64 additions and 24 deletions.
88 changes: 64 additions & 24 deletions src/algorithms/convex.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,31 +4,25 @@
//! (linear in the size of the graph), but can be reused to check multiple
//! subgraphs for convexity quickly.
use std::collections::{BTreeMap, BTreeSet};
use std::collections::BTreeSet;

use bitvec::bitvec;
use bitvec::vec::BitVec;

use crate::algorithms::toposort;
use crate::{Direction, LinkView, NodeIndex, PortIndex};
use crate::{Direction, LinkView, NodeIndex, PortIndex, SecondaryMap, UnmanagedDenseMap};

use super::TopoSort;

#[derive(Default, Clone, Debug, PartialEq, Eq)]
enum Causal {
#[default]
P, // in the past
F, // in the future
}

/// A pre-computed datastructure for fast convexity checking.
///
/// TODO: implement for graph traits?
pub struct ConvexChecker<G> {
graph: G,
// The nodes in topological order
topsort_nodes: Vec<NodeIndex>,
// The index of a node in the topological order (the inverse of topsort_nodes)
topsort_ind: BTreeMap<NodeIndex, usize>,
topsort_ind: UnmanagedDenseMap<NodeIndex, usize>,
// A temporary datastructure used during `is_convex`
causal: Vec<Causal>,
causal: CausalVec,
}

impl<G> ConvexChecker<G>
Expand All @@ -40,9 +34,11 @@ where
let inputs = graph.nodes_iter().filter(|&n| graph.num_inputs(n) == 0);
let topsort: TopoSort<_> = toposort(graph, inputs, Direction::Outgoing);
let topsort_nodes: Vec<_> = topsort.collect();
let flip = |(i, &n)| (n, i);
let topsort_ind = topsort_nodes.iter().enumerate().map(flip).collect();
let causal = vec![Causal::default(); topsort_nodes.len()];
let mut topsort_ind = UnmanagedDenseMap::with_capacity(graph.node_count());
for (i, &n) in topsort_nodes.iter().enumerate() {
topsort_ind.set(n, i);
}
let causal = CausalVec::new(topsort_nodes.len());
Self {
graph,
topsort_nodes,
Expand Down Expand Up @@ -80,30 +76,31 @@ where
/// in some topological order. In the worst case this will traverse every
/// node in the graph and can be improved on in the future.
pub fn is_node_convex(&mut self, nodes: impl IntoIterator<Item = NodeIndex>) -> bool {
let nodes: BTreeSet<_> = nodes.into_iter().map(|n| self.topsort_ind[&n]).collect();
let nodes: BTreeSet<_> = nodes.into_iter().map(|n| self.topsort_ind[n]).collect();
let min_ind = *nodes.first().unwrap();
let max_ind = *nodes.last().unwrap();
for ind in min_ind..=max_ind {
let n = self.topsort_nodes[ind];
let mut in_inds = {
let in_neighs = self.graph.input_neighbours(n);
in_neighs
.map(|n| self.topsort_ind[&n])
.map(|n| self.topsort_ind[n])
.filter(|&ind| ind >= min_ind)
};
if nodes.contains(&ind) {
if in_inds.any(|ind| self.causal[ind] == Causal::F) {
if in_inds.any(|ind| self.causal.get(ind) == Causal::Future) {
// There is a node in the past that is also in the future!
return false;
}
self.causal[ind] = Causal::P;
self.causal.set(ind, Causal::Past);
} else {
self.causal[ind] = match in_inds
.any(|ind| nodes.contains(&ind) || self.causal[ind] == Causal::F)
let ind_causal = match in_inds
.any(|ind| nodes.contains(&ind) || self.causal.get(ind) == Causal::Future)
{
true => Causal::F,
false => Causal::P,
true => Causal::Future,
false => Causal::Past,
};
self.causal.set(ind, ind_causal);
}
}
true
Expand Down Expand Up @@ -149,6 +146,49 @@ where
}
}

/// Whether a node is in the past or in the future of a subgraph.
#[derive(Default, Clone, Debug, PartialEq, Eq)]
enum Causal {
#[default]
Past,
Future,
}

/// A memory-efficient substitute for `Vec<Causal>`.
struct CausalVec(BitVec);

impl From<bool> for Causal {
fn from(b: bool) -> Self {
match b {
true => Self::Future,
false => Self::Past,
}
}
}

impl From<Causal> for bool {
fn from(c: Causal) -> Self {
match c {
Causal::Past => false,
Causal::Future => true,
}
}
}

impl CausalVec {
fn new(len: usize) -> Self {
Self(bitvec![0; len])
}

fn set(&mut self, index: usize, causal: Causal) {
self.0.set(index, causal.into());
}

fn get(&self, index: usize) -> Causal {
self.0[index].into()
}
}

#[cfg(test)]
mod tests {
use crate::{LinkMut, NodeIndex, PortGraph, PortMut, PortView};
Expand Down

0 comments on commit ad0f464

Please sign in to comment.