Skip to content

Commit

Permalink
fix: Track input linear units in Command (#310)
Browse files Browse the repository at this point in the history
`Command` only stored the linear units assigned to its output ports, so
querying the input linear unit to a `QFree` operation panicked.

This pr tracks the `input_linear_units` in addition to the output ones.

Fixes #309.
  • Loading branch information
aborgna-q authored Apr 16, 2024
1 parent 8576c49 commit 956eb05
Showing 1 changed file with 166 additions and 75 deletions.
241 changes: 166 additions & 75 deletions tket2/src/circuit/command.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use hugr::hugr::NodeType;
use hugr::ops::{OpTag, OpTrait};
use hugr::{IncomingPort, OutgoingPort};
use itertools::Either::{self, Left, Right};
use itertools::{EitherOrBoth, Itertools};
use petgraph::visit as pv;

use super::units::{filter, DefaultUnitLabeller, LinearUnit, UnitLabeller, Units};
Expand All @@ -25,11 +26,10 @@ pub struct Command<'circ, Circ> {
circ: &'circ Circ,
/// The operation node.
node: Node,
/// An assignment of linear units to the node's ports.
//
// We'll need something more complex if `follow_linear_port` stops being a
// direct map from input to output.
linear_units: Vec<LinearUnit>,
/// An assignment of linear units to the node's input ports.
input_linear_units: Vec<LinearUnit>,
/// An assignment of linear units to the node's output ports.
output_linear_units: Vec<LinearUnit>,
}

impl<'circ, Circ: Circuit> Command<'circ, Circ> {
Expand Down Expand Up @@ -165,7 +165,11 @@ impl<'circ, Circ: Circuit> Command<'circ, Circ> {
impl<'a, 'circ, Circ: Circuit> UnitLabeller for &'a Command<'circ, Circ> {
#[inline]
fn assign_linear(&self, _: Node, port: Port, _linear_count: usize) -> LinearUnit {
*self.linear_units.get(port.index()).unwrap_or_else(|| {
let units = match port.direction() {
Direction::Incoming => &self.input_linear_units,
Direction::Outgoing => &self.output_linear_units,
};
*units.get(port.index()).unwrap_or_else(|| {
panic!(
"Could not assign a linear unit to port {port:?} of node {:?}",
self.node
Expand All @@ -190,14 +194,17 @@ impl<'circ, Circ: Circuit> std::fmt::Debug for Command<'circ, Circ> {
f.debug_struct("Command")
.field("circuit name", &self.circ.name())
.field("node", &self.node)
.field("linear_units", &self.linear_units)
.field("input_linear_units", &self.input_linear_units)
.field("output_linear_units", &self.output_linear_units)
.finish()
}
}

impl<'circ, Circ> PartialEq for Command<'circ, Circ> {
fn eq(&self, other: &Self) -> bool {
self.node == other.node && self.linear_units == other.linear_units
self.node == other.node
&& self.input_linear_units == other.input_linear_units
&& self.output_linear_units == other.output_linear_units
}
}

Expand All @@ -208,29 +215,17 @@ impl<'circ, Circ> Clone for Command<'circ, Circ> {
Self {
circ: self.circ,
node: self.node,
linear_units: self.linear_units.clone(),
input_linear_units: self.input_linear_units.clone(),
output_linear_units: self.output_linear_units.clone(),
}
}
}

impl<'circ, Circ> std::hash::Hash for Command<'circ, Circ> {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.node.hash(state);
self.linear_units.hash(state);
}
}

impl<'circ, Circ> PartialOrd for Command<'circ, Circ> {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
Some(self.cmp(other))
}
}

impl<'circ, Circ> Ord for Command<'circ, Circ> {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
self.node
.cmp(&other.node)
.then(self.linear_units.cmp(&other.linear_units))
self.input_linear_units.hash(state);
self.output_linear_units.hash(state);
}
}

Expand Down Expand Up @@ -356,12 +351,13 @@ where

/// Process a new node, updating wires in `unit_wires`.
///
/// Returns the an option with the `linear_units` used to construct a
/// [`Command`], if the node is not an input or output.
/// Returns the an option with the `input_linear_units` and
/// `output_linear_units` needed to construct a [`Command`], if the node is
/// not an input or output.
///
/// We don't return the command directly to avoid lifetime issues due to the
/// mutable borrow here.
fn process_node(&mut self, node: Node) -> Option<Vec<LinearUnit>> {
fn process_node(&mut self, node: Node) -> Option<(Vec<LinearUnit>, Vec<LinearUnit>)> {
// The root node is ignored.
if node == self.circ.root() {
return None;
Expand All @@ -373,56 +369,59 @@ where
return None;
}

// Collect the linear units passing through this command into the map
// Collect the linear units passing through this command into the maps
// required to construct a `Command`.
//
// Linear input ports are matched sequentially against the linear output
// ports, ignoring any non-linear ports when assigning unit ids. That
// is, the nth linear input is matched against the nth linear output,
// independently of whether there are any other ports mixed in.
//
// Updates the map tracking the last wire of linear units.
let linear_units: Vec<_> = Units::new_outgoing(self.circ, node, DefaultUnitLabeller)
.filter_map(filter::filter_linear)
.map(|(_, port, _)| {
// Find the linear unit id for this port.
let linear_id = self
.follow_linear_port(node, port)
.and_then(|input_port| {
let input_port = input_port.as_incoming().unwrap();
self.circ.linked_outputs(node, input_port).next()
})
.and_then(|(from, from_port)| {
// Remove the old wire from the map (if there was one)
self.wire_unit.remove(&Wire::new(from, from_port))
})
.unwrap_or({
// New linear unit found. Assign it a new id.
self.wire_unit.len()
});
// Update the map tracking the linear units
let new_wire = Wire::new(node, port);
self.wire_unit.insert(new_wire, linear_id);
LinearUnit::new(linear_id)
})
.collect();

Some(linear_units)
}

/// Returns the linear port on the node that corresponds to the same linear unit.
///
/// We assume the linear data uses the same port offsets on both sides of the node.
/// In the future we may want to have a more general mechanism to handle this.
//
// Note that `Command::linear_units` assumes this behaviour.
fn follow_linear_port(&self, node: Node, port: impl Into<Port>) -> Option<Port> {
let port = port.into();
let optype = self.circ.get_optype(node);
if !optype.port_kind(port)?.is_linear() {
return None;
}
let other_port = Port::new(port.direction().reverse(), port.index());
if optype.port_kind(other_port) == optype.port_kind(port) {
Some(other_port)
} else {
None
let mut input_linear_units = Vec::new();
let mut output_linear_units = Vec::new();

let input_units = Units::new_incoming(self.circ, node, DefaultUnitLabeller)
.filter_map(filter::filter_linear);
let output_units = Units::new_outgoing(self.circ, node, DefaultUnitLabeller)
.filter_map(filter::filter_linear);
for ports in input_units.zip_longest(output_units) {
// Terminate the input linear unit.
// Returns the linear id of the terminated unit.
let mut terminate_input =
|port: IncomingPort, wire_unit: &mut HashMap<Wire, usize>| -> Option<usize> {
let linear_id = self.circ.single_linked_output(node, port).and_then(
|(wire_node, wire_port)| wire_unit.remove(&Wire::new(wire_node, wire_port)),
)?;
input_linear_units.push(LinearUnit::new(linear_id));
Some(linear_id)
};

// Add a new linear unit for this output port.
let mut register_output =
|unit: usize, port: OutgoingPort, wire_unit: &mut HashMap<Wire, usize>| {
let wire = Wire::new(node, port);
wire_unit.insert(wire, unit);
output_linear_units.push(LinearUnit::new(unit));
};

match ports {
EitherOrBoth::Right((_, out_port, _)) => {
let new_id = self.wire_unit.len();
register_output(new_id, out_port, &mut self.wire_unit);
}
EitherOrBoth::Left((_, in_port, _)) => {
terminate_input(in_port, &mut self.wire_unit);
}
EitherOrBoth::Both((_, in_port, _), (_, out_port, _)) => {
if let Some(linear_id) = terminate_input(in_port, &mut self.wire_unit) {
register_output(linear_id, out_port, &mut self.wire_unit);
}
}
}
}

Some((input_linear_units, output_linear_units))
}
}

Expand All @@ -437,12 +436,13 @@ where
loop {
let node = self.next_node()?;
// Process the node, returning a command if it's not an input or output.
if let Some(linear_units) = self.process_node(node) {
if let Some((input_linear_units, output_linear_units)) = self.process_node(node) {
self.remaining -= 1;
return Some(Command {
circ: self.circ,
node,
linear_units,
input_linear_units,
output_linear_units,
});
}
}
Expand Down Expand Up @@ -476,7 +476,10 @@ mod test {
use hugr::std_extensions::arithmetic::float_types::ConstF64;
use hugr::types::FunctionType;
use itertools::Itertools;
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};

use crate::extension::REGISTRY;
use crate::utils::build_simple_circuit;
use crate::Tk2Op;

Expand Down Expand Up @@ -602,4 +605,92 @@ mod test {
[CircuitUnit::Linear(0)],
);
}

/// Commands that allocate and free linear units.
///
/// Creates the following circuit:
/// ```plaintext
/// -------------[ ]---[QFree]
/// [CX]
/// [QAlloc]---[ ]-------------
/// ```
/// and checks that every command is correctly generated, and correctly
/// computes input/output units.
#[test]
fn alloc_free() -> Result<(), Box<dyn std::error::Error>> {
let qb_row = vec![QB_T; 1];
let mut h = DFGBuilder::new(FunctionType::new(qb_row.clone(), qb_row))?;

let [q_in] = h.input_wires_arr();

let alloc = h.add_dataflow_op(Tk2Op::QAlloc, [])?;
let [q_new] = alloc.outputs_arr();

let cx = h.add_dataflow_op(Tk2Op::CX, [q_in, q_new])?;
let [q_in, q_new] = cx.outputs_arr();

let free = h.add_dataflow_op(Tk2Op::QFree, [q_in])?;

let circ = h.finish_hugr_with_outputs([q_new], &REGISTRY)?;

let mut cmds = circ.commands();

let alloc_cmd = cmds.next().unwrap();
assert_eq!(alloc_cmd.node(), alloc.node());
assert_eq!(
alloc_cmd.inputs().map(|(unit, _, _)| unit).collect_vec(),
[]
);
assert_eq!(
alloc_cmd.outputs().map(|(unit, _, _)| unit).collect_vec(),
[CircuitUnit::Linear(1)]
);

let cx_cmd = cmds.next().unwrap();
assert_eq!(cx_cmd.node(), cx.node());
assert_eq!(
cx_cmd.inputs().map(|(unit, _, _)| unit).collect_vec(),
[CircuitUnit::Linear(0), CircuitUnit::Linear(1)]
);
assert_eq!(
cx_cmd.outputs().map(|(unit, _, _)| unit).collect_vec(),
[CircuitUnit::Linear(0), CircuitUnit::Linear(1)]
);

let free_cmd = cmds.next().unwrap();
assert_eq!(free_cmd.node(), free.node());
assert_eq!(
free_cmd.inputs().map(|(unit, _, _)| unit).collect_vec(),
[CircuitUnit::Linear(0)]
);
assert_eq!(
free_cmd.outputs().map(|(unit, _, _)| unit).collect_vec(),
[]
);

Ok(())
}

/// Test the manual trait implementations of `Command`.
#[test]
fn test_impls() -> Result<(), Box<dyn std::error::Error>> {
let qb_row = vec![QB_T; 1];
let mut h = DFGBuilder::new(FunctionType::new(qb_row.clone(), vec![]))?;
let [q_in] = h.input_wires_arr();
h.add_dataflow_op(Tk2Op::QFree, [q_in])?;
let circ = h.finish_hugr_with_outputs([], &REGISTRY)?;

let cmd1 = circ.commands().next().unwrap();
let cmd2 = circ.commands().next().unwrap();

assert_eq!(cmd1, cmd2);

let mut hasher1 = DefaultHasher::new();
cmd1.hash(&mut hasher1);
let mut hasher2 = DefaultHasher::new();
cmd2.hash(&mut hasher2);
assert_eq!(hasher1.finish(), hasher2.finish());

Ok(())
}
}

0 comments on commit 956eb05

Please sign in to comment.