From a314f3f4969406d468021ebf055d6fc6df165cce Mon Sep 17 00:00:00 2001 From: Agustin Borgna Date: Wed, 8 Nov 2023 10:42:59 +0000 Subject: [PATCH 1/4] feat: Rewrite tracing --- badger-optimiser/Cargo.toml | 2 +- badger-optimiser/src/main.rs | 12 ++- tket2/Cargo.toml | 10 ++- tket2/src/optimiser/badger.rs | 19 +++-- tket2/src/optimiser/badger/log.rs | 13 ++- tket2/src/rewrite.rs | 1 + tket2/src/rewrite/strategy.rs | 42 +++++++++- tket2/src/rewrite/trace.rs | 126 ++++++++++++++++++++++++++++++ 8 files changed, 212 insertions(+), 13 deletions(-) create mode 100644 tket2/src/rewrite/trace.rs diff --git a/badger-optimiser/Cargo.toml b/badger-optimiser/Cargo.toml index 45a50b44..b60ff20a 100644 --- a/badger-optimiser/Cargo.toml +++ b/badger-optimiser/Cargo.toml @@ -11,7 +11,7 @@ license-file = { workspace = true } [dependencies] clap = { version = "4.4.2", features = ["derive"] } serde_json = "1.0" -tket2 = { workspace = true, features = ["portmatching"] } +tket2 = { workspace = true, features = ["portmatching", "rewrite-tracing"] } quantinuum-hugr = { workspace = true } itertools = { workspace = true } tket-json-rs = { workspace = true } diff --git a/badger-optimiser/src/main.rs b/badger-optimiser/src/main.rs index 861f94af..b87ecaef 100644 --- a/badger-optimiser/src/main.rs +++ b/badger-optimiser/src/main.rs @@ -15,6 +15,7 @@ use tket2::json::{load_tk1_json_file, save_tk1_json_file}; use tket2::optimiser::badger::log::BadgerLogger; use tket2::optimiser::badger::BadgerOptions; use tket2::optimiser::{BadgerOptimiser, DefaultBadgerOptimiser}; +use tket2::rewrite::trace::RewriteTracer; #[cfg(feature = "peak_alloc")] #[global_allocator] @@ -104,6 +105,12 @@ struct CmdLineArgs { help = "The priority queue size. Defaults to 100." )] queue_size: usize, + /// Trace each rewrite applied to the circuit. + #[arg( + long = "rewrite-tracing", + help = "Trace each rewrite applied to the circuit. Prints statistics for the best circuit at the end of the optimisation." + )] + rewrite_tracing: bool, } fn main() -> Result<(), Box> { @@ -129,7 +136,10 @@ fn main() -> Result<(), Box> { let badger_logger = BadgerLogger::new(circ_candidates_csv); - let circ = load_tk1_json_file(input_path)?; + let mut circ = load_tk1_json_file(input_path)?; + if opts.rewrite_tracing { + circ.enable_rewrite_tracing(); + } println!("Loading optimiser..."); let Ok(optimiser) = load_optimiser(ecc_path) else { diff --git a/tket2/Cargo.toml b/tket2/Cargo.toml index 31f17864..fdc0189c 100644 --- a/tket2/Cargo.toml +++ b/tket2/Cargo.toml @@ -15,9 +15,17 @@ name = "tket2" path = "src/lib.rs" [features] +# Enables some python bindings pyo3 = ["dep:pyo3"] + +# Enables search and replace optimisation passes using the `portmatching` crate. portmatching = ["dep:portmatching", "dep:rmp-serde"] +# Stores a trace of the applied rewrites +rewrite-tracing = [] + +default = [] + [dependencies] lazy_static = "1.4.0" cgmath = "0.18.0" @@ -44,7 +52,7 @@ strum_macros = "0.25.2" strum = "0.25.0" fxhash = "0.2.1" rmp-serde = { version = "1.1.2", optional = true } -delegate = "0.10.0" +delegate = "0.11.0" csv = { version = "1.2.2" } chrono = { version = "0.4.30" } bytemuck = "1.14.0" diff --git a/tket2/src/optimiser/badger.rs b/tket2/src/optimiser/badger.rs index 687d2813..789237b9 100644 --- a/tket2/src/optimiser/badger.rs +++ b/tket2/src/optimiser/badger.rs @@ -37,6 +37,7 @@ use crate::optimiser::badger::hugr_pqueue::{Entry, HugrPQ}; use crate::optimiser::badger::worker::BadgerWorker; use crate::passes::CircuitChunks; use crate::rewrite::strategy::RewriteStrategy; +use crate::rewrite::trace::RewriteTracer; use crate::rewrite::Rewriter; use crate::Circuit; @@ -158,7 +159,8 @@ where let mut best_circ = circ.clone(); let mut best_circ_cost = self.cost(circ); - logger.log_best(&best_circ_cost); + let num_rewrites = best_circ.rewrite_trace().map(|rs| rs.len()); + logger.log_best(&best_circ_cost, num_rewrites); // Hash of seen circuits. Dot not store circuits as this map gets huge let hash = circ.circuit_hash().unwrap(); @@ -181,7 +183,8 @@ where if cost < best_circ_cost { best_circ = circ.clone(); best_circ_cost = cost.clone(); - logger.log_best(&best_circ_cost); + let num_rewrites = best_circ.rewrite_trace().map(|rs| rs.len()); + logger.log_best(&best_circ_cost, num_rewrites); last_best_time = Instant::now(); } circ_cnt += 1; @@ -297,7 +300,8 @@ where if cost < best_circ_cost { best_circ = circ; best_circ_cost = cost; - logger.log_best(&best_circ_cost); + let num_rewrites = best_circ.rewrite_trace().map(|rs| rs.len()); + logger.log_best(&best_circ_cost, num_rewrites); if let Some(t) = opt.progress_timeout { progress_timeout_event = crossbeam_channel::at(Instant::now() + Duration::from_secs(t)); } @@ -337,7 +341,8 @@ where if cost < best_circ_cost { best_circ = circ; best_circ_cost = cost; - logger.log_best(&best_circ_cost); + let num_rewrites = best_circ.rewrite_trace().map(|rs| rs.len()); + logger.log_best(&best_circ_cost, num_rewrites); } } PriorityChannelLog::CircuitCount { @@ -381,7 +386,8 @@ where let mut chunks = CircuitChunks::split_with_cost(circ, max_chunk_cost, |op| self.strategy.op_cost(op)); - logger.log_best(circ_cost.clone()); + let num_rewrites = circ.rewrite_trace().map(|rs| rs.len()); + logger.log_best(circ_cost.clone(), num_rewrites); let (joins, rx_work): (Vec<_>, Vec<_>) = chunks .iter_mut() @@ -420,7 +426,8 @@ where let best_circ = chunks.reassemble()?; let best_circ_cost = self.cost(&best_circ); if best_circ_cost.clone() < circ_cost { - logger.log_best(best_circ_cost.clone()); + let num_rewrites = best_circ.rewrite_trace().map(|rs| rs.len()); + logger.log_best(best_circ_cost.clone(), num_rewrites); } logger.log_processing_end(opt.n_threads.get(), None, best_circ_cost, true, false); diff --git a/tket2/src/optimiser/badger/log.rs b/tket2/src/optimiser/badger/log.rs index 3c01fdf6..e0fb5eea 100644 --- a/tket2/src/optimiser/badger/log.rs +++ b/tket2/src/optimiser/badger/log.rs @@ -49,8 +49,17 @@ impl<'w> BadgerLogger<'w> { /// Log a new best candidate #[inline] - pub fn log_best(&mut self, best_cost: C) { - self.log(format!("new best of size {:?}", best_cost)); + pub fn log_best( + &mut self, + best_cost: C, + num_rewrites: Option, + ) { + match num_rewrites { + Some(rs) => self.log(format!( + "new best of size {best_cost:?} after {rs} rewrites" + )), + None => self.log(format!("new best of size {:?}", best_cost)), + } if let Some(csv_writer) = self.circ_candidates_csv.as_mut() { csv_writer.serialize(BestCircSer::new(best_cost)).unwrap(); csv_writer.flush().unwrap(); diff --git a/tket2/src/rewrite.rs b/tket2/src/rewrite.rs index c8e44e85..6d51cc07 100644 --- a/tket2/src/rewrite.rs +++ b/tket2/src/rewrite.rs @@ -3,6 +3,7 @@ #[cfg(feature = "portmatching")] pub mod ecc_rewriter; pub mod strategy; +pub mod trace; use bytemuck::TransparentWrapper; #[cfg(feature = "portmatching")] diff --git a/tket2/src/rewrite/strategy.rs b/tket2/src/rewrite/strategy.rs index 456e21bf..de5796e8 100644 --- a/tket2/src/rewrite/strategy.rs +++ b/tket2/src/rewrite/strategy.rs @@ -28,6 +28,7 @@ use itertools::Itertools; use crate::circuit::cost::{is_cx, is_quantum, CircuitCost, CostDelta, MajorMinorCost}; use crate::Circuit; +use super::trace::{RewriteTrace, RewriteTracer}; use super::CircuitRewrite; /// Rewriting strategies for circuit optimisation. @@ -144,6 +145,7 @@ impl RewriteStrategy for GreedyRewriteStrategy { } changed_nodes.extend(rewrite.subcircuit().nodes().iter().copied()); cost_delta += rewrite.node_count_delta(); + circ.add_rewrite_trace(RewriteTrace::new(1)); rewrite .apply(&mut circ) .expect("Could not perform rewrite in greedy strategy"); @@ -219,6 +221,7 @@ impl RewriteStrategy for ExhaustiveGreedyStrategy { let mut curr_circ = circ.clone(); let mut changed_nodes = HashSet::new(); let mut cost_delta = Default::default(); + let mut composed_rewrite_count = 0; for (rewrite, delta) in &rewrites[i..] { if !changed_nodes.is_empty() && rewrite @@ -230,11 +233,15 @@ impl RewriteStrategy for ExhaustiveGreedyStrategy { changed_nodes.extend(rewrite.invalidation_set()); cost_delta += delta.clone(); + composed_rewrite_count += 1; + rewrite .clone() .apply(&mut curr_circ) .expect("Could not perform rewrite in exhaustive greedy strategy"); } + + curr_circ.add_rewrite_trace(RewriteTrace::new(composed_rewrite_count)); rewrite_sets.circs.push(curr_circ); rewrite_sets.cost_deltas.push(cost_delta); } @@ -285,6 +292,7 @@ impl RewriteStrategy for ExhaustiveThresholdStrategy { return None; } let mut circ = circ.clone(); + circ.add_rewrite_trace(RewriteTrace::new(1)); rw.apply(&mut circ).expect("invalid pattern match"); Some((circ, target_cost.sub_cost(&pattern_cost))) }) @@ -462,6 +470,7 @@ mod tests { use hugr::{Hugr, Node}; use itertools::Itertools; + use crate::rewrite::trace::REWRITE_TRACING_ENABLED; use crate::{ circuit::Circuit, rewrite::{CircuitRewrite, Subcircuit}, @@ -494,9 +503,16 @@ mod tests { #[test] fn test_greedy_strategy() { - let circ = n_cx(10); + let mut circ = n_cx(10); let cx_gates = circ.commands().map(|cmd| cmd.node()).collect_vec(); + assert_eq!(circ.rewrite_trace(), None); + circ.enable_rewrite_tracing(); + match REWRITE_TRACING_ENABLED { + true => assert_eq!(circ.rewrite_trace(), Some(vec![])), + false => assert_eq!(circ.rewrite_trace(), None), + } + let rws = [ rw_to_empty(&circ, cx_gates[0..2].to_vec()), rw_to_full(&circ, cx_gates[4..7].to_vec()), @@ -508,12 +524,17 @@ mod tests { let rewritten = strategy.apply_rewrites(rws, &circ); assert_eq!(rewritten.len(), 1); assert_eq!(rewritten.circs[0].num_gates(), 5); + + if REWRITE_TRACING_ENABLED { + assert_eq!(rewritten.circs[0].rewrite_trace().unwrap().len(), 3); + } } #[test] fn test_exhaustive_default_strategy() { - let circ = n_cx(10); + let mut circ = n_cx(10); let cx_gates = circ.commands().map(|cmd| cmd.node()).collect_vec(); + circ.enable_rewrite_tracing(); let rws = [ rw_to_empty(&circ, cx_gates[0..2].to_vec()), @@ -527,6 +548,23 @@ mod tests { let exp_circ_lens = HashSet::from_iter([3, 7, 9]); let circ_lens: HashSet<_> = rewritten.circs.iter().map(|c| c.num_gates()).collect(); assert_eq!(circ_lens, exp_circ_lens); + + if REWRITE_TRACING_ENABLED { + // Each strategy branch applies a single rewrite, composed of + // multiple individual elements from `rws`. + assert_eq!( + rewritten.circs[0].rewrite_trace().unwrap(), + vec![RewriteTrace::new(3)] + ); + assert_eq!( + rewritten.circs[1].rewrite_trace().unwrap(), + vec![RewriteTrace::new(2)] + ); + assert_eq!( + rewritten.circs[2].rewrite_trace().unwrap(), + vec![RewriteTrace::new(1)] + ); + } } #[test] diff --git a/tket2/src/rewrite/trace.rs b/tket2/src/rewrite/trace.rs new file mode 100644 index 00000000..5cd5cb90 --- /dev/null +++ b/tket2/src/rewrite/trace.rs @@ -0,0 +1,126 @@ +//! Utilities for tracing the rewrites applied to a circuit. +//! +//! This is only tracked if the `rewrite-tracing` feature is enabled. + +use hugr::hugr::hugrmut::HugrMut; +use hugr::hugr::NodeMetadata; +use itertools::Itertools; + +use crate::Circuit; + +use super::CircuitRewrite; + +/// Metadata key for the circuit rewrite trace. +pub const METADATA_REWRITES: &str = "TKET2.rewrites"; + +/// Global read-only flag for enabling rewrite tracing. +/// Enable it by setting the `rewrite-tracing` feature. +/// +/// Note that circuits must be explicitly enabled for rewrite tracing by calling +/// [`RewriteTracer::enable_rewrite_tracing`]. +pub const REWRITE_TRACING_ENABLED: bool = cfg!(feature = "rewrite-tracing"); + +/// The trace of a rewrite applied to a circuit. +/// +/// Traces are only enabled if the `rewrite-tracing` feature is enabled and +/// [`RewriteTracer::enable_rewrite_tracing`] is called on the circuit. +#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[repr(transparent)] +pub struct RewriteTrace { + /// A count of the number of individual patterns matched for this rewrite step. + /// + /// This is relevant when using a greedy rewrite strategy. + individual_matches: u16, +} + +impl From<&CircuitRewrite> for RewriteTrace { + #[inline] + fn from(_rewrite: &CircuitRewrite) -> Self { + // NOTE: We don't currently track any actual information about the rewrite. + Self { + individual_matches: 1, + } + } +} + +impl RewriteTrace { + /// Create a new trace. + #[inline] + pub fn new(individual_matches: u16) -> Self { + Self { individual_matches } + } +} + +impl From<&serde_json::Value> for RewriteTrace { + #[inline] + fn from(value: &serde_json::Value) -> Self { + Self { + individual_matches: value.as_u64().unwrap() as u16, + } + } +} + +impl From for serde_json::Value { + #[inline] + fn from(trace: RewriteTrace) -> Self { + serde_json::Value::from(trace.individual_matches) + } +} + +/// Extension trait for circuits that can trace rewrites applied to them. +/// +/// This is only tracked if the `rewrite-tracing` feature is enabled and +/// `enable_rewrite_tracing` is called on the circuit before. +pub trait RewriteTracer: Circuit + HugrMut + Sized { + /// Enable rewrite tracing for the circuit. + #[inline] + fn enable_rewrite_tracing(&mut self) { + if !REWRITE_TRACING_ENABLED { + return; + } + let meta = self + .get_metadata_mut(self.root(), METADATA_REWRITES) + .unwrap(); + if *meta == NodeMetadata::Null { + *meta = NodeMetadata::Array(vec![]); + } + } + + /// Register a rewrite applied to the circuit. + /// + /// Returns `true` if the rewrite was successfully registered, or `false` if it was ignored. + #[inline] + fn add_rewrite_trace(&mut self, rewrite: impl Into) -> bool { + if !REWRITE_TRACING_ENABLED { + return false; + } + match self + .get_metadata_mut(self.root(), METADATA_REWRITES) + .ok() + .and_then(|m| m.as_array_mut()) + { + Some(meta) => { + let rewrite = rewrite.into(); + meta.push(rewrite.into()); + true + } + // Tracing was not enable for this circuit. + None => false, + } + } + + /// Returns the traces of rewrites applied to the circuit. + /// + /// Returns `None` if rewrite tracing is not enabled for this circuit. + #[inline] + fn rewrite_trace(&self) -> Option> { + if !REWRITE_TRACING_ENABLED { + return None; + } + let meta = self.get_metadata(self.root(), METADATA_REWRITES)?; + let rewrites = meta.as_array()?; + Some(rewrites.iter().map_into().collect_vec()) + } +} + +impl RewriteTracer for T {} From 6c9516c52fa82c30d4b371b98fb87dc68c075520 Mon Sep 17 00:00:00 2001 From: Agustin Borgna Date: Thu, 14 Dec 2023 10:43:01 +0000 Subject: [PATCH 2/4] Add new feature to the README --- README.md | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index eff080f0..91375482 100644 --- a/README.md +++ b/README.md @@ -13,10 +13,13 @@ Version 2 of the TKET compiler. ## Features - `pyo3` -This optional feature enables some python bindings via pyo3. See the `tket2-py` folder for more. + Enables some python bindings via pyo3. See the `tket2-py` folder for more. - `portmatching` - This enables pattern matching using the `portmatching` crate. + Enables pattern matching using the `portmatching` crate. + +- `rewrite-tracing` + Adds opt-in tracking of the rewrites applied to a circuit. ## Developing TKET2 From 3092bc22f76806344d01e4c655a4cecac9930bf8 Mon Sep 17 00:00:00 2001 From: Agustin Borgna Date: Thu, 14 Dec 2023 10:50:48 +0000 Subject: [PATCH 3/4] Trace rewrites in `::apply` --- tket2/src/rewrite.rs | 19 +++++++++++++------ tket2/src/rewrite/strategy.rs | 4 +--- 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/tket2/src/rewrite.rs b/tket2/src/rewrite.rs index 6d51cc07..45dfafec 100644 --- a/tket2/src/rewrite.rs +++ b/tket2/src/rewrite.rs @@ -9,7 +9,6 @@ use bytemuck::TransparentWrapper; #[cfg(feature = "portmatching")] pub use ecc_rewriter::ECCRewriter; -use delegate::delegate; use derive_more::{From, Into}; use hugr::hugr::views::sibling_subgraph::{InvalidReplacement, InvalidSubgraph}; use hugr::Node; @@ -20,6 +19,8 @@ use hugr::{ use crate::circuit::Circuit; +use self::trace::RewriteTracer; + /// A subcircuit of a circuit. #[derive(Debug, Clone, From, Into)] #[repr(transparent)] @@ -108,11 +109,17 @@ impl CircuitRewrite { self.0.invalidation_set() } - delegate! { - to self.0 { - /// Apply the rewrite rule to a circuit. - pub fn apply(self, circ: &mut impl HugrMut) -> Result<(), SimpleReplacementError>; - } + /// Apply the rewrite rule to a circuit. + #[inline] + pub fn apply(self, circ: &mut impl HugrMut) -> Result<(), SimpleReplacementError> { + circ.add_rewrite_trace(&self); + self.0.apply(circ) + } + + /// Apply the rewrite rule to a circuit, without registering it in the rewrite trace. + #[inline] + pub fn apply_notrace(self, circ: &mut impl HugrMut) -> Result<(), SimpleReplacementError> { + self.0.apply(circ) } } diff --git a/tket2/src/rewrite/strategy.rs b/tket2/src/rewrite/strategy.rs index de5796e8..899b7909 100644 --- a/tket2/src/rewrite/strategy.rs +++ b/tket2/src/rewrite/strategy.rs @@ -145,7 +145,6 @@ impl RewriteStrategy for GreedyRewriteStrategy { } changed_nodes.extend(rewrite.subcircuit().nodes().iter().copied()); cost_delta += rewrite.node_count_delta(); - circ.add_rewrite_trace(RewriteTrace::new(1)); rewrite .apply(&mut circ) .expect("Could not perform rewrite in greedy strategy"); @@ -237,7 +236,7 @@ impl RewriteStrategy for ExhaustiveGreedyStrategy { rewrite .clone() - .apply(&mut curr_circ) + .apply_notrace(&mut curr_circ) .expect("Could not perform rewrite in exhaustive greedy strategy"); } @@ -292,7 +291,6 @@ impl RewriteStrategy for ExhaustiveThresholdStrategy { return None; } let mut circ = circ.clone(); - circ.add_rewrite_trace(RewriteTrace::new(1)); rw.apply(&mut circ).expect("invalid pattern match"); Some((circ, target_cost.sub_cost(&pattern_cost))) }) From e64af8328d294fec01510e2109c7e43795e90664 Mon Sep 17 00:00:00 2001 From: Agustin Borgna Date: Thu, 14 Dec 2023 17:17:53 +0000 Subject: [PATCH 4/4] Add rpitit TODO --- tket2/src/rewrite/trace.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tket2/src/rewrite/trace.rs b/tket2/src/rewrite/trace.rs index 5cd5cb90..4d0a4770 100644 --- a/tket2/src/rewrite/trace.rs +++ b/tket2/src/rewrite/trace.rs @@ -112,6 +112,8 @@ pub trait RewriteTracer: Circuit + HugrMut + Sized { /// Returns the traces of rewrites applied to the circuit. /// /// Returns `None` if rewrite tracing is not enabled for this circuit. + // + // TODO return an `impl Iterator` once rust 1.75 lands. #[inline] fn rewrite_trace(&self) -> Option> { if !REWRITE_TRACING_ENABLED {