Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: avoid hugr clone in simple replace #1724

Merged
merged 3 commits into from
Nov 27, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
163 changes: 92 additions & 71 deletions hugr-core/src/hugr/rewrite/simple_replace.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
//! Implementation of the `SimpleReplace` operation.

use std::collections::{HashMap, HashSet};
use std::collections::HashMap;

use crate::hugr::hugrmut::InsertionResult;
pub use crate::hugr::internal::HugrMutInternals;
use crate::hugr::views::SiblingSubgraph;
use crate::hugr::{HugrMut, HugrView, Rewrite};
use crate::ops::{OpTag, OpTrait, OpType};
use crate::{Hugr, IncomingPort, Node, OutgoingPort};
use crate::{Hugr, IncomingPort, Node};
use thiserror::Error;

use super::inline_dfg::InlineDFGError;
Expand Down Expand Up @@ -67,23 +67,84 @@ impl Rewrite for SimpleReplacement {
}

fn apply(self, h: &mut impl HugrMut) -> Result<Self::ApplyResult, Self::Error> {
let parent = self.subgraph.get_parent(h);
let Self {
subgraph,
replacement,
nu_inp,
nu_out,
} = self;
let parent = subgraph.get_parent(h);
// 1. Check the parent node exists and is a DataflowParent.
if !OpTag::DataflowParent.is_superset(h.get_optype(parent).tag()) {
return Err(SimpleReplacementError::InvalidParentNode());
}
// 2. Check that all the to-be-removed nodes are children of it and are leaves.
for node in self.subgraph.nodes() {
for node in subgraph.nodes() {
if h.get_parent(*node) != Some(parent) || h.children(*node).next().is_some() {
return Err(SimpleReplacementError::InvalidRemovedNode());
}
}

let replacement_output_node = replacement
.get_io(replacement.root())
.expect("parent already checked.")[1];

// 3. Do the replacement.
// 3.1. Insert the replacement as a whole.
// Now we proceed to connect the edges between the newly inserted
// replacement and the rest of the graph.
//
// Existing connections to the removed subgraph will be automatically
// removed when the nodes are removed.

// 3.1. For each p = self.nu_inp[q] such that q is not an Output port, add an edge from the
// predecessor of p to (the new copy of) q.
let nu_inp_connects: Vec<_> = nu_inp
.iter()
.filter(|&((rep_inp_node, _), _)| {
replacement.get_optype(*rep_inp_node).tag() != OpTag::Output
})
.map(
|((rep_inp_node, rep_inp_port), (rem_inp_node, rem_inp_port))| {
// add edge from predecessor of (s_inp_node, s_inp_port) to (new_inp_node, n_inp_port)
let (rem_inp_pred_node, rem_inp_pred_port) = h
.single_linked_output(*rem_inp_node, *rem_inp_port)
.unwrap();
(
rem_inp_pred_node,
rem_inp_pred_port,
// the new input node will be updated after insertion
rep_inp_node,
rep_inp_port,
)
},
)
.collect();

// 3.2. For each q = self.nu_out[p] such that the predecessor of q is not an Input port, add an
// edge from (the new copy of) the predecessor of q to p.
let nu_out_connects: Vec<_> = nu_out
.iter()
.filter_map(|((rem_out_node, rem_out_port), rep_out_port)| {
let (rep_out_pred_node, rep_out_pred_port) = replacement
.single_linked_output(replacement_output_node, *rep_out_port)
.unwrap();
(replacement.get_optype(rep_out_pred_node).tag() != OpTag::Input).then_some({
(
// the new output node will be updated after insertion
rep_out_pred_node,
rep_out_pred_port,
rem_out_node,
rem_out_port,
)
})
})
.collect();

// 3.3. Insert the replacement as a whole.
let InsertionResult {
new_root,
node_map: index_map,
} = h.insert_hugr(parent, self.replacement.clone());
} = h.insert_hugr(parent, replacement);

// remove the Input and Output nodes from the replacement graph
let replace_children = h.children(new_root).collect::<Vec<Node>>();
Expand All @@ -97,87 +158,47 @@ impl Rewrite for SimpleReplacement {
// remove the replacement root (which now has no children and no edges)
h.remove_node(new_root);

// Now we proceed to connect the edges between the newly inserted
// replacement and the rest of the graph.
//
// We delay creating these connections to avoid them getting mixed with
// the pre-existing ones in the following logic.
//
// Existing connections to the removed subgraph will be automatically
// removed when the nodes are removed.
let mut connect: HashSet<(Node, OutgoingPort, Node, IncomingPort)> = HashSet::new();

// 3.2. For each p = self.nu_inp[q] such that q is not an Output port, add an edge from the
// predecessor of p to (the new copy of) q.
for ((rep_inp_node, rep_inp_port), (rem_inp_node, rem_inp_port)) in &self.nu_inp {
if self.replacement.get_optype(*rep_inp_node).tag() != OpTag::Output {
// add edge from predecessor of (s_inp_node, s_inp_port) to (new_inp_node, n_inp_port)
let (rem_inp_pred_node, rem_inp_pred_port) = h
.single_linked_output(*rem_inp_node, *rem_inp_port)
.unwrap();
let new_inp_node = index_map.get(rep_inp_node).unwrap();
connect.insert((
rem_inp_pred_node,
rem_inp_pred_port,
*new_inp_node,
*rep_inp_port,
));
}
// 3.4. Update replacement nodes according to insertion mapping and connect
for (src_node, src_port, tgt_node, tgt_port) in nu_inp_connects {
h.connect(
src_node,
src_port,
*index_map.get(tgt_node).unwrap(),
*tgt_port,
)
}
let replacement_output_node = self
.replacement
.get_io(self.replacement.root())
.expect("parent already checked.")[1];
// 3.3. For each q = self.nu_out[p] such that the predecessor of q is not an Input port, add an
// edge from (the new copy of) the predecessor of q to p.
for ((rem_out_node, rem_out_port), rep_out_port) in &self.nu_out {
let (rep_out_pred_node, rep_out_pred_port) = self
.replacement
.single_linked_output(replacement_output_node, *rep_out_port)
.unwrap();
if self.replacement.get_optype(rep_out_pred_node).tag() != OpTag::Input {
let new_out_node = index_map.get(&rep_out_pred_node).unwrap();
connect.insert((
*new_out_node,
rep_out_pred_port,
*rem_out_node,
*rem_out_port,
));
}

for (src_node, src_port, tgt_node, tgt_port) in nu_out_connects {
h.connect(
*index_map.get(&src_node).unwrap(),
src_port,
*tgt_node,
*tgt_port,
)
}
// 3.4. For each q = self.nu_out[p1], p0 = self.nu_inp[q], add an edge from the predecessor of p0
// 3.5. For each q = self.nu_out[p1], p0 = self.nu_inp[q], add an edge from the predecessor of p0
// to p1.
//
// i.e. the replacement graph has direct edges between the input and output nodes.
for ((rem_out_node, rem_out_port), &rep_out_port) in &self.nu_out {
let rem_inp_nodeport = self.nu_inp.get(&(replacement_output_node, rep_out_port));
for ((rem_out_node, rem_out_port), &rep_out_port) in &nu_out {
let rem_inp_nodeport = nu_inp.get(&(replacement_output_node, rep_out_port));
if let Some((rem_inp_node, rem_inp_port)) = rem_inp_nodeport {
// add edge from predecessor of (rem_inp_node, rem_inp_port) to (rem_out_node, rem_out_port):
let (rem_inp_pred_node, rem_inp_pred_port) = h
.single_linked_output(*rem_inp_node, *rem_inp_port)
.unwrap();
// Delay connecting the nodes until after processing all nu_out
// entries.
//
// Otherwise, we might disconnect other wires in `rem_inp_node`
// that are needed for the following iterations.
connect.insert((

h.connect(
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the comment above is also no longer true - but it is not clear to me how any disconnects might happen, and all tests (including those involving direct wires from input to output in the replacement) work

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was meant to fix #1323. See #1324 / test_half_nots.
I'm not sure why that it's not a problem anymore, I'd have to look at the logic here for a bit.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe that that PR did two things: stop disconnecting eagerly and stop connecting eagerly.

I have re-added connecting eagerly, but not disconnecting. Based on the tests, that seems fine.

rem_inp_pred_node,
rem_inp_pred_port,
*rem_out_node,
*rem_out_port,
));
);
}
}
connect
.into_iter()
.for_each(|(src_node, src_port, tgt_node, tgt_port)| {
h.connect(src_node, src_port, tgt_node, tgt_port);
});

// 3.5. Remove all nodes in self.removal and edges between them.
Ok(self
.subgraph

// 3.6. Remove all nodes in subgraph and edges between them.
Ok(subgraph
.nodes()
.iter()
.map(|&node| (node, h.remove_node(node)))
Expand Down
Loading