From 956eb054d3e1cda2ae07fd50fb288042e26ad819 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Borgna?= <121866228+aborgna-q@users.noreply.github.com> Date: Tue, 16 Apr 2024 13:26:47 +0200 Subject: [PATCH] fix: Track input linear units in `Command` (#310) `Command` only stored the linear units assigned to its output ports, so querying the input linear unit to a `QFree` operation panicked. This pr tracks the `input_linear_units` in addition to the output ones. Fixes #309. --- tket2/src/circuit/command.rs | 241 ++++++++++++++++++++++++----------- 1 file changed, 166 insertions(+), 75 deletions(-) diff --git a/tket2/src/circuit/command.rs b/tket2/src/circuit/command.rs index 24ed5ff9..0c369933 100644 --- a/tket2/src/circuit/command.rs +++ b/tket2/src/circuit/command.rs @@ -10,6 +10,7 @@ use hugr::hugr::NodeType; use hugr::ops::{OpTag, OpTrait}; use hugr::{IncomingPort, OutgoingPort}; use itertools::Either::{self, Left, Right}; +use itertools::{EitherOrBoth, Itertools}; use petgraph::visit as pv; use super::units::{filter, DefaultUnitLabeller, LinearUnit, UnitLabeller, Units}; @@ -25,11 +26,10 @@ pub struct Command<'circ, Circ> { circ: &'circ Circ, /// The operation node. node: Node, - /// An assignment of linear units to the node's ports. - // - // We'll need something more complex if `follow_linear_port` stops being a - // direct map from input to output. - linear_units: Vec, + /// An assignment of linear units to the node's input ports. + input_linear_units: Vec, + /// An assignment of linear units to the node's output ports. + output_linear_units: Vec, } impl<'circ, Circ: Circuit> Command<'circ, Circ> { @@ -165,7 +165,11 @@ impl<'circ, Circ: Circuit> Command<'circ, Circ> { impl<'a, 'circ, Circ: Circuit> UnitLabeller for &'a Command<'circ, Circ> { #[inline] fn assign_linear(&self, _: Node, port: Port, _linear_count: usize) -> LinearUnit { - *self.linear_units.get(port.index()).unwrap_or_else(|| { + let units = match port.direction() { + Direction::Incoming => &self.input_linear_units, + Direction::Outgoing => &self.output_linear_units, + }; + *units.get(port.index()).unwrap_or_else(|| { panic!( "Could not assign a linear unit to port {port:?} of node {:?}", self.node @@ -190,14 +194,17 @@ impl<'circ, Circ: Circuit> std::fmt::Debug for Command<'circ, Circ> { f.debug_struct("Command") .field("circuit name", &self.circ.name()) .field("node", &self.node) - .field("linear_units", &self.linear_units) + .field("input_linear_units", &self.input_linear_units) + .field("output_linear_units", &self.output_linear_units) .finish() } } impl<'circ, Circ> PartialEq for Command<'circ, Circ> { fn eq(&self, other: &Self) -> bool { - self.node == other.node && self.linear_units == other.linear_units + self.node == other.node + && self.input_linear_units == other.input_linear_units + && self.output_linear_units == other.output_linear_units } } @@ -208,7 +215,8 @@ impl<'circ, Circ> Clone for Command<'circ, Circ> { Self { circ: self.circ, node: self.node, - linear_units: self.linear_units.clone(), + input_linear_units: self.input_linear_units.clone(), + output_linear_units: self.output_linear_units.clone(), } } } @@ -216,21 +224,8 @@ impl<'circ, Circ> Clone for Command<'circ, Circ> { impl<'circ, Circ> std::hash::Hash for Command<'circ, Circ> { fn hash(&self, state: &mut H) { self.node.hash(state); - self.linear_units.hash(state); - } -} - -impl<'circ, Circ> PartialOrd for Command<'circ, Circ> { - fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(other)) - } -} - -impl<'circ, Circ> Ord for Command<'circ, Circ> { - fn cmp(&self, other: &Self) -> std::cmp::Ordering { - self.node - .cmp(&other.node) - .then(self.linear_units.cmp(&other.linear_units)) + self.input_linear_units.hash(state); + self.output_linear_units.hash(state); } } @@ -356,12 +351,13 @@ where /// Process a new node, updating wires in `unit_wires`. /// - /// Returns the an option with the `linear_units` used to construct a - /// [`Command`], if the node is not an input or output. + /// Returns the an option with the `input_linear_units` and + /// `output_linear_units` needed to construct a [`Command`], if the node is + /// not an input or output. /// /// We don't return the command directly to avoid lifetime issues due to the /// mutable borrow here. - fn process_node(&mut self, node: Node) -> Option> { + fn process_node(&mut self, node: Node) -> Option<(Vec, Vec)> { // The root node is ignored. if node == self.circ.root() { return None; @@ -373,56 +369,59 @@ where return None; } - // Collect the linear units passing through this command into the map + // Collect the linear units passing through this command into the maps // required to construct a `Command`. // + // Linear input ports are matched sequentially against the linear output + // ports, ignoring any non-linear ports when assigning unit ids. That + // is, the nth linear input is matched against the nth linear output, + // independently of whether there are any other ports mixed in. + // // Updates the map tracking the last wire of linear units. - let linear_units: Vec<_> = Units::new_outgoing(self.circ, node, DefaultUnitLabeller) - .filter_map(filter::filter_linear) - .map(|(_, port, _)| { - // Find the linear unit id for this port. - let linear_id = self - .follow_linear_port(node, port) - .and_then(|input_port| { - let input_port = input_port.as_incoming().unwrap(); - self.circ.linked_outputs(node, input_port).next() - }) - .and_then(|(from, from_port)| { - // Remove the old wire from the map (if there was one) - self.wire_unit.remove(&Wire::new(from, from_port)) - }) - .unwrap_or({ - // New linear unit found. Assign it a new id. - self.wire_unit.len() - }); - // Update the map tracking the linear units - let new_wire = Wire::new(node, port); - self.wire_unit.insert(new_wire, linear_id); - LinearUnit::new(linear_id) - }) - .collect(); - - Some(linear_units) - } - - /// Returns the linear port on the node that corresponds to the same linear unit. - /// - /// We assume the linear data uses the same port offsets on both sides of the node. - /// In the future we may want to have a more general mechanism to handle this. - // - // Note that `Command::linear_units` assumes this behaviour. - fn follow_linear_port(&self, node: Node, port: impl Into) -> Option { - let port = port.into(); - let optype = self.circ.get_optype(node); - if !optype.port_kind(port)?.is_linear() { - return None; - } - let other_port = Port::new(port.direction().reverse(), port.index()); - if optype.port_kind(other_port) == optype.port_kind(port) { - Some(other_port) - } else { - None + let mut input_linear_units = Vec::new(); + let mut output_linear_units = Vec::new(); + + let input_units = Units::new_incoming(self.circ, node, DefaultUnitLabeller) + .filter_map(filter::filter_linear); + let output_units = Units::new_outgoing(self.circ, node, DefaultUnitLabeller) + .filter_map(filter::filter_linear); + for ports in input_units.zip_longest(output_units) { + // Terminate the input linear unit. + // Returns the linear id of the terminated unit. + let mut terminate_input = + |port: IncomingPort, wire_unit: &mut HashMap| -> Option { + let linear_id = self.circ.single_linked_output(node, port).and_then( + |(wire_node, wire_port)| wire_unit.remove(&Wire::new(wire_node, wire_port)), + )?; + input_linear_units.push(LinearUnit::new(linear_id)); + Some(linear_id) + }; + + // Add a new linear unit for this output port. + let mut register_output = + |unit: usize, port: OutgoingPort, wire_unit: &mut HashMap| { + let wire = Wire::new(node, port); + wire_unit.insert(wire, unit); + output_linear_units.push(LinearUnit::new(unit)); + }; + + match ports { + EitherOrBoth::Right((_, out_port, _)) => { + let new_id = self.wire_unit.len(); + register_output(new_id, out_port, &mut self.wire_unit); + } + EitherOrBoth::Left((_, in_port, _)) => { + terminate_input(in_port, &mut self.wire_unit); + } + EitherOrBoth::Both((_, in_port, _), (_, out_port, _)) => { + if let Some(linear_id) = terminate_input(in_port, &mut self.wire_unit) { + register_output(linear_id, out_port, &mut self.wire_unit); + } + } + } } + + Some((input_linear_units, output_linear_units)) } } @@ -437,12 +436,13 @@ where loop { let node = self.next_node()?; // Process the node, returning a command if it's not an input or output. - if let Some(linear_units) = self.process_node(node) { + if let Some((input_linear_units, output_linear_units)) = self.process_node(node) { self.remaining -= 1; return Some(Command { circ: self.circ, node, - linear_units, + input_linear_units, + output_linear_units, }); } } @@ -476,7 +476,10 @@ mod test { use hugr::std_extensions::arithmetic::float_types::ConstF64; use hugr::types::FunctionType; use itertools::Itertools; + use std::collections::hash_map::DefaultHasher; + use std::hash::{Hash, Hasher}; + use crate::extension::REGISTRY; use crate::utils::build_simple_circuit; use crate::Tk2Op; @@ -602,4 +605,92 @@ mod test { [CircuitUnit::Linear(0)], ); } + + /// Commands that allocate and free linear units. + /// + /// Creates the following circuit: + /// ```plaintext + /// -------------[ ]---[QFree] + /// [CX] + /// [QAlloc]---[ ]------------- + /// ``` + /// and checks that every command is correctly generated, and correctly + /// computes input/output units. + #[test] + fn alloc_free() -> Result<(), Box> { + let qb_row = vec![QB_T; 1]; + let mut h = DFGBuilder::new(FunctionType::new(qb_row.clone(), qb_row))?; + + let [q_in] = h.input_wires_arr(); + + let alloc = h.add_dataflow_op(Tk2Op::QAlloc, [])?; + let [q_new] = alloc.outputs_arr(); + + let cx = h.add_dataflow_op(Tk2Op::CX, [q_in, q_new])?; + let [q_in, q_new] = cx.outputs_arr(); + + let free = h.add_dataflow_op(Tk2Op::QFree, [q_in])?; + + let circ = h.finish_hugr_with_outputs([q_new], ®ISTRY)?; + + let mut cmds = circ.commands(); + + let alloc_cmd = cmds.next().unwrap(); + assert_eq!(alloc_cmd.node(), alloc.node()); + assert_eq!( + alloc_cmd.inputs().map(|(unit, _, _)| unit).collect_vec(), + [] + ); + assert_eq!( + alloc_cmd.outputs().map(|(unit, _, _)| unit).collect_vec(), + [CircuitUnit::Linear(1)] + ); + + let cx_cmd = cmds.next().unwrap(); + assert_eq!(cx_cmd.node(), cx.node()); + assert_eq!( + cx_cmd.inputs().map(|(unit, _, _)| unit).collect_vec(), + [CircuitUnit::Linear(0), CircuitUnit::Linear(1)] + ); + assert_eq!( + cx_cmd.outputs().map(|(unit, _, _)| unit).collect_vec(), + [CircuitUnit::Linear(0), CircuitUnit::Linear(1)] + ); + + let free_cmd = cmds.next().unwrap(); + assert_eq!(free_cmd.node(), free.node()); + assert_eq!( + free_cmd.inputs().map(|(unit, _, _)| unit).collect_vec(), + [CircuitUnit::Linear(0)] + ); + assert_eq!( + free_cmd.outputs().map(|(unit, _, _)| unit).collect_vec(), + [] + ); + + Ok(()) + } + + /// Test the manual trait implementations of `Command`. + #[test] + fn test_impls() -> Result<(), Box> { + let qb_row = vec![QB_T; 1]; + let mut h = DFGBuilder::new(FunctionType::new(qb_row.clone(), vec![]))?; + let [q_in] = h.input_wires_arr(); + h.add_dataflow_op(Tk2Op::QFree, [q_in])?; + let circ = h.finish_hugr_with_outputs([], ®ISTRY)?; + + let cmd1 = circ.commands().next().unwrap(); + let cmd2 = circ.commands().next().unwrap(); + + assert_eq!(cmd1, cmd2); + + let mut hasher1 = DefaultHasher::new(); + cmd1.hash(&mut hasher1); + let mut hasher2 = DefaultHasher::new(); + cmd2.hash(&mut hasher2); + assert_eq!(hasher1.finish(), hasher2.finish()); + + Ok(()) + } }