Skip to content

Commit

Permalink
feat: Add CircuitHistory
Browse files Browse the repository at this point in the history
  • Loading branch information
lmondada committed Feb 18, 2025
1 parent 6bf8e4d commit 8478612
Show file tree
Hide file tree
Showing 7 changed files with 1,052 additions and 36 deletions.
10 changes: 5 additions & 5 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 5 additions & 5 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,11 @@ missing_docs = "warn"
[patch.crates-io]

# Uncomment to use unreleased versions of hugr
hugr = { git = "https://github.com/CQCL/hugr", rev = "0efc806f" }
hugr-core = { git = "https://github.com/CQCL/hugr", rev = "0efc806f" }
hugr-passes = { git = "https://github.com/CQCL/hugr", rev = "0efc806f" }
hugr-cli = { git = "https://github.com/CQCL/hugr", rev = "0efc806f" }
hugr-model = { git = "https://github.com/CQCL/hugr", rev = "0efc806f" }
hugr = { git = "https://github.com/CQCL/hugr", rev = "739cf944" }
hugr-core = { git = "https://github.com/CQCL/hugr", rev = "739cf944" }
hugr-passes = { git = "https://github.com/CQCL/hugr", rev = "739cf944" }
hugr-cli = { git = "https://github.com/CQCL/hugr", rev = "739cf944" }
hugr-model = { git = "https://github.com/CQCL/hugr", rev = "739cf944" }
# portgraph = { git = "https://github.com/CQCL/portgraph", rev = "68b96ac737e0c285d8c543b2d74a7aa80a18202c" }

[workspace.dependencies]
Expand Down
178 changes: 176 additions & 2 deletions tket2/src/diff.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,12 @@
//! Diffs can be created on top of existing diffs, resulting in an acyclic
//! history of circuit transformations.
pub mod experimental;
mod history;
mod replacement;

pub use history::CircuitHistory;

use std::{
cell::RefCell,
cmp::Ordering,
Expand All @@ -19,7 +23,9 @@ use std::{

use derive_more::{Display, Error, From};
use derive_where::derive_where;
use hugr::{hugr::SimpleReplacementError, Hugr, HugrView, IncomingPort, Node, Wire};
use hugr::{
hugr::SimpleReplacementError, Direction, Hugr, HugrView, IncomingPort, Node, Port, Wire,
};
use itertools::Itertools;
use relrc::RelRc;

Expand All @@ -38,7 +44,8 @@ use crate::{
/// Use [`CircuitDiff::try_from_circuit`] to create a new "root" diff, i.e.
/// without any parents, and use [`CircuitDiff::apply_replacement`] to create
/// new diffs as children of the current diff.
#[derive(Clone)]
#[derive(From)]
#[derive_where(Clone)]
pub struct CircuitDiff<H = Hugr>(RelRc<CircuitDiffData<H>, InvalidNodes>);

#[derive(Clone)]
Expand Down Expand Up @@ -152,6 +159,15 @@ pub struct WireEquivalence<H = Hugr> {
wire_to_children: RefCell<BTreeMap<Wire, BTreeSet<ChildWire<H>>>>,
}

type CircuitDiffPtr<H> = *const relrc::node::InnerData<CircuitDiffData<H>, InvalidNodes>;

impl<H> CircuitDiff<H> {
/// Get the pointer to the inner data of the diff
fn as_ptr(&self) -> CircuitDiffPtr<H> {
self.0.as_ptr()
}
}

impl<H: HugrView> CircuitDiff<H> {
/// Create a new circuit diff from a circuit
pub fn try_from_circuit(circuit: Circuit<H>) -> Result<Self, HashError> {
Expand All @@ -167,6 +183,11 @@ impl<H: HugrView> CircuitDiff<H> {
self.0.value().circuit.circuit()
}

/// Get the io nodes of the diff
pub fn io_nodes(&self) -> [Node; 2] {
self.as_circuit().io_nodes()
}

/// Get the diff circuit as a hugr
pub fn as_hugr(&self) -> &H {
self.as_circuit().hugr()
Expand Down Expand Up @@ -211,6 +232,106 @@ impl<H: HugrView> CircuitDiff<H> {

new
}

fn wire_to_children(&self, wire: Wire) -> Vec<(Self, Wire)> {
let mut w = self
.0
.value()
.equivalent_wires
.wire_to_children
.borrow_mut();
let Some(wire_to_children_mut) = w.get_mut(&wire) else {
return vec![];
};
let mut wire_to_children = Vec::new();
wire_to_children_mut.retain(|child_wire| {
// remove edges to no longer existing nodes
let Some(target) = child_wire.edge.target().upgrade() else {
return false;
};
wire_to_children.push((CircuitDiff(target), child_wire.wire));
true
});

wire_to_children
}

fn input_to_parent(&self, wire: Wire) -> Option<ParentWire> {
self.0
.value()
.equivalent_wires
.input_to_parent
.get(&wire)
.copied()
}

fn output_to_parent(&self, wire: Wire) -> Option<&BTreeSet<ParentWire>> {
self.0.value().equivalent_wires.output_to_parent.get(&wire)
}

fn get_parent(&self, parent_wire: &ParentWire) -> Self {
let edge = self
.0
.incoming(parent_wire.incoming_index)
.expect("invalid parent index");
CircuitDiff(edge.source().clone())
}

fn all_parents(&self) -> impl ExactSizeIterator<Item = Self> + '_ {
self.0.all_parents().cloned().map_into()
}

/// Get equivalent ports in children of a given port using the wire equivalences
fn equivalent_children_ports<'a>(
&'a self,
node: Node,
port: Port,
) -> impl Iterator<Item = Owned<H, (Node, Port)>> + 'a {
let Ok(wire) = port_to_wire(node, port, self.as_hugr()) else {
return None.into_iter().flatten();
};
let iter = self
.wire_to_children(wire)
.into_iter()
.flat_map(move |(child, wire)| {
let to_owned = |data| Owned {
owner: child.clone(),
data,
};
wire_to_ports(wire, port.direction(), child.as_hugr())
.map(to_owned)
.collect_vec()
});
Some(iter).into_iter().flatten()
}

/// Get equivalent ports in parents of a given port using the wire equivalences
// TODO: make stronger assumptions on the kinds of wires in `input_to_parent`
// and `output_to_parent` to make this more efficient
fn equivalent_parent_ports<'a>(
&'a self,
node: Node,
port: Port,
) -> impl Iterator<Item = Owned<H, (Node, Port)>> + 'a {
let Ok(wire) = port_to_wire(node, port, self.as_hugr()) else {
return None.into_iter().flatten();
};
let inputs = self.input_to_parent(wire).into_iter();
let outputs = self.output_to_parent(wire).into_iter().flatten().copied();
let iter = inputs.chain(outputs).flat_map(move |parent_wire| {
let parent = self.get_parent(&parent_wire);
let to_owned = |data| Owned {
owner: parent.clone(),
data,
};
wire_to_ports(parent_wire.wire, port.direction(), parent.as_hugr())
.map(to_owned)
.collect_vec()
});
Some(iter.unique_by(|o| (o.owner.as_ptr(), o.data)))
.into_iter()
.flatten()
}
}

impl<H> WireEquivalence<H> {
Expand Down Expand Up @@ -251,4 +372,57 @@ pub enum CircuitDiffError {
/// Error when a cycle is detected in the dfg
#[display("cycle detected in dfg")]
Cycle,
/// Error when merging two diffs
#[display("conflicting diffs")]
ConflictingDiffs,
/// Error when a history is empty
#[display("empty history")]
EmptyHistory,
/// Error when merging two diffs with different roots
#[display("distinct roots")]
DistinctRoots,
}

fn port_to_wire(
node: Node,
port: impl Into<Port>,
hugr: &impl HugrView,
) -> Result<Wire, CircuitDiffError> {
let port: Port = port.into();

use itertools::Either::{Left, Right};
match port.as_directed() {
Left(incoming) => {
let (node, outgoing) = hugr
.single_linked_output(node, incoming)
.ok_or(CircuitDiffError::NoUniqueOutput(node, incoming))?;
Ok(Wire::new(node, outgoing))
}
Right(outgoing) => Ok(Wire::new(node, outgoing)),
}
}

fn wire_to_ports(
wire: Wire,
dir: Direction,
hugr: &impl HugrView,
) -> impl Iterator<Item = (Node, Port)> + '_ {
use itertools::Either::{Left, Right};
let iter = match dir {
Direction::Incoming => Left(
hugr.linked_inputs(wire.node(), wire.source())
.map(|(node, port)| (node, port.into())),
),
Direction::Outgoing => Right([(wire.node(), wire.source().into())]),
};
iter.into_iter()
}

/// Data in a circuit diff, along with its owner [`CircuitDiff`]
#[derive_where(Clone; D)]
pub struct Owned<H, D> {
/// The owner of the data
pub owner: CircuitDiff<H>,
/// The data
pub data: D,
}
Loading

0 comments on commit 8478612

Please sign in to comment.