Skip to content

Commit

Permalink
feat: implement dead code elimination
Browse files Browse the repository at this point in the history
  • Loading branch information
plafer committed Aug 15, 2024
1 parent 3b0314e commit f9c14cc
Show file tree
Hide file tree
Showing 4 changed files with 183 additions and 11 deletions.
147 changes: 147 additions & 0 deletions assembly/src/assembler/dead_code_elimination.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
use alloc::{
collections::{BTreeMap, BTreeSet},
vec::Vec,
};

use vm_core::mast::{MastForest, MastNode, MastNodeId};

/// Returns a `MastForest` where all nodes that are unreachable from all procedures are removed. It
/// also returns the map from old node IDs to new node IDs. Any [`MastNodeId`] used in reference to
/// the old [`MastForest`] should be remapped using this map.
pub fn dead_code_elimination(
mast_forest: MastForest,
) -> (MastForest, BTreeMap<MastNodeId, MastNodeId>) {
let live_ids = compute_live_ids(&mast_forest);

let (old_nodes, old_roots) = mast_forest.into_parts();
let (live_nodes, id_remappings) = prune_nodes(old_nodes, live_ids);

(build_pruned_mast_forest(live_nodes, old_roots, &id_remappings), id_remappings)
}

/// Compute all [`MastNodeId`]s that are "live"; that is, accessed by at least one procedure in the
/// MAST forest.
fn compute_live_ids(mast_forest: &MastForest) -> BTreeSet<MastNodeId> {
let mut live_ids = BTreeSet::new();

for &procedure_root in mast_forest.procedure_roots() {
compute_live_ids_for_node(procedure_root, mast_forest, &mut live_ids);
}

live_ids
}

fn compute_live_ids_for_node(
mast_node_id: MastNodeId,
mast_forest: &MastForest,
live_ids: &mut BTreeSet<MastNodeId>,
) {
live_ids.insert(mast_node_id);

match &mast_forest[mast_node_id] {
MastNode::Join(node) => {
compute_live_ids_for_node(node.first(), mast_forest, live_ids);
compute_live_ids_for_node(node.second(), mast_forest, live_ids);
},
MastNode::Split(node) => {
compute_live_ids_for_node(node.on_true(), mast_forest, live_ids);
compute_live_ids_for_node(node.on_false(), mast_forest, live_ids);
},
MastNode::Loop(node) => {
compute_live_ids_for_node(node.body(), mast_forest, live_ids);
},
MastNode::Call(node) => {
compute_live_ids_for_node(node.callee(), mast_forest, live_ids);
},
MastNode::Block(_) | MastNode::Dyn | MastNode::External(_) => (),
}
}

/// Returns the set of nodes that are live, as well as the mapping from "old ID" to "new ID" for all
/// live nodes.
fn prune_nodes(
mast_nodes: Vec<MastNode>,
live_ids: BTreeSet<MastNodeId>,
) -> (Vec<MastNode>, BTreeMap<MastNodeId, MastNodeId>) {
// Note: this allows us to safely use `usize as u32`, guaranteeing that it won't wrap around.
assert!(mast_nodes.len() < u32::MAX as usize);

let mut pruned_nodes = Vec::with_capacity(mast_nodes.len());
let mut id_remappings = BTreeMap::new();

for (old_node_index, old_node) in mast_nodes.into_iter().enumerate() {
let old_node_id: MastNodeId = (old_node_index as u32).into();

if live_ids.contains(&old_node_id) {
let new_node_id: MastNodeId = (pruned_nodes.len() as u32).into();
id_remappings.insert(old_node_id, new_node_id);

pruned_nodes.push(old_node);
}
}

(pruned_nodes, id_remappings)
}

/// Rewrites all [`MastNodeId`]s in the live nodes to the correct updated IDs using `id_remappings`,
/// which maps all old node IDs to new IDs.
fn build_pruned_mast_forest(
live_nodes: Vec<MastNode>,
old_root_ids: Vec<MastNodeId>,
id_remappings: &BTreeMap<MastNodeId, MastNodeId>,
) -> MastForest {
let mut pruned_mast_forest = MastForest::new();

// Add each live node to the new MAST forest, making sure to rewrite any outdated internal
// `MastNodeId`s
for live_node in live_nodes {
match &live_node {
MastNode::Join(join_node) => {
let first_child =
id_remappings.get(&join_node.first()).copied().unwrap_or(join_node.first());
let second_child =
id_remappings.get(&join_node.second()).copied().unwrap_or(join_node.second());

pruned_mast_forest.add_join(first_child, second_child).unwrap();
},
MastNode::Split(split_node) => {
let on_true_child = id_remappings
.get(&split_node.on_true())
.copied()
.unwrap_or(split_node.on_true());
let on_false_child = id_remappings
.get(&split_node.on_false())
.copied()
.unwrap_or(split_node.on_false());

pruned_mast_forest.add_split(on_true_child, on_false_child).unwrap();
},
MastNode::Loop(loop_node) => {
let body_id =
id_remappings.get(&loop_node.body()).copied().unwrap_or(loop_node.body());

pruned_mast_forest.add_loop(body_id).unwrap();
},
MastNode::Call(call_node) => {
let callee_id =
id_remappings.get(&call_node.callee()).copied().unwrap_or(call_node.callee());

if call_node.is_syscall() {
pruned_mast_forest.add_syscall(callee_id).unwrap();
} else {
pruned_mast_forest.add_call(callee_id).unwrap();
}
},
MastNode::Block(_) | MastNode::Dyn | MastNode::External(_) => {
pruned_mast_forest.add_node(live_node).unwrap();
},
}
}

for old_root_id in old_root_ids {
let new_root_id = id_remappings.get(&old_root_id).copied().unwrap_or(old_root_id);
pruned_mast_forest.make_root(new_root_id);
}

pruned_mast_forest
}
24 changes: 13 additions & 11 deletions assembly/src/assembler/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use alloc::{collections::BTreeMap, sync::Arc, vec::Vec};

use dead_code_elimination::dead_code_elimination;
use mast_forest_builder::MastForestBuilder;
use module_graph::{ProcedureWrapper, WrappedModule};
use vm_core::{mast::MastNodeId, DecoratorList, Felt, Kernel, Operation, Program};
Expand All @@ -14,11 +15,13 @@ use crate::{
};

mod basic_block_builder;
mod dead_code_elimination;
mod id;
mod instruction;
mod mast_forest_builder;
mod module_graph;
mod procedure;

#[cfg(test)]
mod tests;

Expand Down Expand Up @@ -299,8 +302,8 @@ impl Assembler {
};

// TODO: show a warning if library exports are empty?

Ok(Library::new(mast_forest_builder.build(), exports))
let (mast_forest, _) = dead_code_elimination(mast_forest_builder.build());
Ok(Library::new(mast_forest, exports))
}

/// Assembles the provided module into a [KernelLibrary] intended to be used as a Kernel.
Expand Down Expand Up @@ -341,7 +344,8 @@ impl Assembler {

// TODO: show a warning if library exports are empty?

let library = Library::new(mast_forest_builder.build(), exports);
let (mast_forest, _) = dead_code_elimination(mast_forest_builder.build());
let library = Library::new(mast_forest, exports);
Ok(library.try_into()?)
}

Expand Down Expand Up @@ -381,9 +385,11 @@ impl Assembler {
.get_procedure(entrypoint)
.expect("compilation succeeded but root not found in cache");

let (mast_forest, id_remappings) = dead_code_elimination(mast_forest_builder.build());

Ok(Program::with_kernel(
mast_forest_builder.build(),
entry_procedure.body_node_id(),
mast_forest,
id_remappings[&entry_procedure.body_node_id()],
self.module_graph.kernel().clone(),
))
}
Expand Down Expand Up @@ -708,7 +714,7 @@ fn merge_contiguous_basic_blocks(
let mut contiguous_basic_block_ids: Vec<MastNodeId> = Vec::new();

for mast_node_id in mast_node_ids {
if mast_forest_builder.get_mast_node(mast_node_id).unwrap().is_basic_block() {
if mast_forest_builder[mast_node_id].is_basic_block() {
contiguous_basic_block_ids.push(mast_node_id);
} else {
if let Some(merged_basic_block_id) =
Expand Down Expand Up @@ -748,11 +754,7 @@ fn merge_basic_blocks(
for &basic_block_node_id in contiguous_basic_block_ids {
// It is safe to unwrap here, since we already checked that all IDs in
// `contiguous_basic_block_ids` are `BasicBlockNode`s
let basic_block_node = mast_forest_builder
.get_mast_node(basic_block_node_id)
.unwrap()
.get_basic_block()
.unwrap();
let basic_block_node = mast_forest_builder[basic_block_node_id].get_basic_block().unwrap();

for (op_idx, decorator) in basic_block_node.decorators() {
decorators.push((*op_idx + operations.len(), decorator.clone()));
Expand Down
3 changes: 3 additions & 0 deletions assembly/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,9 @@ begin
basic_block mul add add add add add end
end";
assert_str_eq!(format!("{}", program), expected);

// Also ensure that dead code elimination works properly
assert_eq!(program.mast_forest().num_nodes(), 1);
Ok(())
}

Expand Down
20 changes: 20 additions & 0 deletions core/src/mast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,20 @@ impl MastForest {
.try_into()
.expect("MAST forest contains more than 2^32 procedures.")
}

/// Returns the number of nodes in this MAST forest.
pub fn num_nodes(&self) -> u32 {
self.nodes.len() as u32
}
}

/// Destructors
impl MastForest {
pub fn into_parts(self) -> (Vec<MastNode>, Vec<MastNodeId>) {
let Self { nodes, roots } = self;

(nodes, roots)
}
}

impl Index<MastNodeId> for MastForest {
Expand Down Expand Up @@ -252,6 +266,12 @@ impl From<&MastNodeId> for u32 {
}
}

impl From<u32> for MastNodeId {
fn from(value: u32) -> Self {
Self(value)
}
}

impl fmt::Display for MastNodeId {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "MastNodeId({})", self.0)
Expand Down

0 comments on commit f9c14cc

Please sign in to comment.