From 567bb9b5186e078c532a85dde2e4c3b5a1b861f3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Borgna?= Date: Mon, 14 Oct 2024 15:33:15 +0100 Subject: [PATCH 1/2] feat: Extract pytket parameters to input wires --- Cargo.lock | 1 + Cargo.toml | 1 + tket2/Cargo.toml | 1 + tket2/src/serialize/pytket.rs | 2 ++ tket2/src/serialize/pytket/decoder.rs | 32 ++++++++++++++++++--------- tket2/src/serialize/pytket/encoder.rs | 16 +++++++++++--- 6 files changed, 40 insertions(+), 13 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index cd2c934d..8835bea8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2054,6 +2054,7 @@ dependencies = [ "hugr", "hugr-cli", "hugr-core", + "indexmap", "itertools 0.13.0", "lazy_static", "num-complex", diff --git a/Cargo.toml b/Cargo.toml index 1df56277..3691d43d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -46,6 +46,7 @@ delegate = "0.13.1" derive_more = "1.0.0" downcast-rs = "1.2.0" fxhash = "0.2.1" +indexmap = "2.6.0" lazy_static = "1.5.0" num-complex = "0.4" num-rational = "0.4" diff --git a/tket2/Cargo.toml b/tket2/Cargo.toml index 4b1364d4..3274d210 100644 --- a/tket2/Cargo.toml +++ b/tket2/Cargo.toml @@ -61,6 +61,7 @@ portgraph = { workspace = true, features = ["serde"] } strum_macros = { workspace = true } strum = { workspace = true } fxhash = { workspace = true } +indexmap = { workspace = true } rmp-serde = { workspace = true, optional = true } delegate = { workspace = true } csv = { workspace = true } diff --git a/tket2/src/serialize/pytket.rs b/tket2/src/serialize/pytket.rs index 44e7e685..9013f77c 100644 --- a/tket2/src/serialize/pytket.rs +++ b/tket2/src/serialize/pytket.rs @@ -47,6 +47,8 @@ const METADATA_B_REGISTERS: &str = "TKET1.bit_registers"; const METADATA_B_OUTPUT_REGISTERS: &str = "TKET1.bit_output_registers"; /// A tket1 operation "opgroup" field. const METADATA_OPGROUP: &str = "TKET1.opgroup"; +/// Explicit names for the input parameter wires. +const METADATA_INPUT_PARAMETERS: &str = "TKET1.input_parameters"; /// A serialized representation of a [`Circuit`]. /// diff --git a/tket2/src/serialize/pytket/decoder.rs b/tket2/src/serialize/pytket/decoder.rs index 377c45ea..97cd7fc7 100644 --- a/tket2/src/serialize/pytket/decoder.rs +++ b/tket2/src/serialize/pytket/decoder.rs @@ -12,6 +12,7 @@ use hugr::types::Signature; use hugr::{Hugr, Wire}; use derive_more::Display; +use indexmap::IndexMap; use itertools::{EitherOrBoth, Itertools}; use serde_json::json; use tket_json_rs::circuit_json; @@ -24,14 +25,15 @@ use super::{ METADATA_B_REGISTERS, METADATA_OPGROUP, METADATA_PHASE, METADATA_Q_OUTPUT_REGISTERS, METADATA_Q_REGISTERS, }; -use crate::extension::rotation::RotationOp; +use crate::extension::rotation::{RotationOp, ROTATION_TYPE}; use crate::extension::{REGISTRY, TKET1_EXTENSION_ID}; +use crate::serialize::pytket::METADATA_INPUT_PARAMETERS; use crate::symbolic_constant_op; /// The state of an in-progress [`FunctionBuilder`] being built from a [`SerialCircuit`]. /// /// Mostly used to define helper internal methods. -#[derive(Debug, PartialEq)] +#[derive(Debug, Clone)] pub(super) struct Tk1Decoder { /// The Hugr being built. pub hugr: FunctionBuilder, @@ -41,6 +43,8 @@ pub(super) struct Tk1Decoder { ordered_registers: Vec, /// A set of registers that encode qubits. qubit_registers: HashSet, + /// An ordered set of parameters found in operation arguments, and added as inputs. + parameters: IndexMap, } impl Tk1Decoder { @@ -116,6 +120,7 @@ impl Tk1Decoder { register_wires, ordered_registers, qubit_registers, + parameters: IndexMap::new(), }) } @@ -132,6 +137,13 @@ impl Tk1Decoder { "Some output wires were not associated with a register." ); + // Store the name for the input parameter wires + if !self.parameters.is_empty() { + let params = self.parameters.keys().cloned().collect_vec(); + self.hugr + .set_metadata(METADATA_INPUT_PARAMETERS, json!(params)); + } + self.hugr .finish_hugr_with_outputs(outputs, ®ISTRY) .unwrap() @@ -264,6 +276,7 @@ impl Tk1Decoder { fn load_parameter(&mut self, param: String) -> Wire { fn process( hugr: &mut FunctionBuilder, + input_params: &mut IndexMap, parsed: PytketParam, param: &str, ) -> LoadedParameter { @@ -286,18 +299,17 @@ impl Tk1Decoder { let wire = hugr.add_load_const(value); return LoadedParameter::float(wire); } - - // TODO: We need to add a `FunctionBuilder::add_input` function on the hugr side. - // Here we just add an opaque sympy box instead. - // https://github.com/CQCL/hugr/issues/1562 - // https://github.com/CQCL/tket2/issues/628 - process(hugr, PytketParam::Sympy(name), param) + // Look it up in the input parameters to the circuit, and add a new wire if needed. + *input_params.entry(name.to_string()).or_insert_with(|| { + let wire = hugr.add_input(ROTATION_TYPE); + LoadedParameter::rotation(wire) + }) } PytketParam::Operation { op, args } => { // We assume all operations take float inputs. let input_wires = args .into_iter() - .map(|arg| process(hugr, arg, param).as_float(hugr).wire) + .map(|arg| process(hugr, input_params, arg, param).as_float(hugr).wire) .collect_vec(); let res = hugr.add_dataflow_op(op, input_wires).unwrap_or_else(|e| { panic!("Error while decoding pytket operation parameter \"{param}\". {e}",) @@ -309,7 +321,7 @@ impl Tk1Decoder { } let parsed = parse_pytket_param(¶m); - process(&mut self.hugr, parsed, ¶m) + process(&mut self.hugr, &mut self.parameters, parsed, ¶m) .as_rotation(&mut self.hugr) .wire } diff --git a/tket2/src/serialize/pytket/encoder.rs b/tket2/src/serialize/pytket/encoder.rs index e50c3315..56133112 100644 --- a/tket2/src/serialize/pytket/encoder.rs +++ b/tket2/src/serialize/pytket/encoder.rs @@ -22,7 +22,8 @@ use super::op::Tk1Op; use super::param::encode::fold_param_op; use super::{ OpConvertError, TK1ConvertError, METADATA_B_OUTPUT_REGISTERS, METADATA_B_REGISTERS, - METADATA_OPGROUP, METADATA_PHASE, METADATA_Q_OUTPUT_REGISTERS, METADATA_Q_REGISTERS, + METADATA_INPUT_PARAMETERS, METADATA_OPGROUP, METADATA_PHASE, METADATA_Q_OUTPUT_REGISTERS, + METADATA_Q_REGISTERS, }; /// The state of an in-progress [`SerialCircuit`] being built from a [`Circuit`]. @@ -558,8 +559,17 @@ impl ParameterTracker { _ => None, }); - for (i, wire) in angle_input_wires.enumerate() { - tracker.add_parameter(wire, format!("f{i}")); + // The input parameter names may be specified in the metadata. + let fixed_input_names: Vec = circ + .hugr() + .get_metadata(circ.parent(), METADATA_INPUT_PARAMETERS) + .and_then(|params| serde_json::from_value(params.clone()).ok()) + .unwrap_or_default(); + let extra_names = (fixed_input_names.len()..).map(|i| format!("f{i}")); + let mut param_name = fixed_input_names.into_iter().chain(extra_names); + + for wire in angle_input_wires { + tracker.add_parameter(wire, param_name.next().unwrap()); } tracker From af2c458d4620aa7a204e24fe9e21f8a0512f6bcc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Borgna?= Date: Mon, 14 Oct 2024 16:55:25 +0100 Subject: [PATCH 2/2] test roundtrip with only some input parameter names --- tket2/src/serialize/pytket/tests.rs | 28 +++++++++++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/tket2/src/serialize/pytket/tests.rs b/tket2/src/serialize/pytket/tests.rs index 65a35460..37d37f58 100644 --- a/tket2/src/serialize/pytket/tests.rs +++ b/tket2/src/serialize/pytket/tests.rs @@ -14,7 +14,7 @@ use rstest::{fixture, rstest}; use tket_json_rs::circuit_json::{self, SerialCircuit}; use tket_json_rs::optype; -use super::{TKETDecode, METADATA_Q_OUTPUT_REGISTERS}; +use super::{TKETDecode, METADATA_INPUT_PARAMETERS, METADATA_Q_OUTPUT_REGISTERS}; use crate::circuit::Circuit; use crate::extension::rotation::{ConstRotation, RotationOp, ROTATION_TYPE}; use crate::extension::sympy::SympyOpDef; @@ -158,6 +158,31 @@ fn circ_preset_qubits() -> Circuit { hugr.into() } +/// A simple circuit with some input parameters +#[fixture] +fn circ_parameterized() -> Circuit { + let input_t = vec![QB_T, ROTATION_TYPE, ROTATION_TYPE, ROTATION_TYPE]; + let output_t = vec![QB_T]; + let mut h = DFGBuilder::new(Signature::new(input_t, output_t)).unwrap(); + + let [q, r0, r1, r2] = h.input_wires_arr(); + + let [q] = h.add_dataflow_op(Tk2Op::Rx, [q, r0]).unwrap().outputs_arr(); + let [q] = h.add_dataflow_op(Tk2Op::Ry, [q, r1]).unwrap().outputs_arr(); + let [q] = h.add_dataflow_op(Tk2Op::Rz, [q, r2]).unwrap().outputs_arr(); + + let mut hugr = h.finish_hugr_with_outputs([q], ®ISTRY).unwrap(); + + // Preset names for some of the inputs + hugr.set_metadata( + hugr.root(), + METADATA_INPUT_PARAMETERS, + serde_json::json!(["alpha", "beta"]), + ); + + hugr.into() +} + /// A simple circuit with ancillae #[fixture] fn circ_measure_ancilla() -> Circuit { @@ -318,6 +343,7 @@ fn json_file_roundtrip(#[case] circ: impl AsRef) { #[rstest] #[case::meas_ancilla(circ_measure_ancilla(), Signature::new_endo(vec![QB_T, QB_T, BOOL_T, BOOL_T]))] #[case::preset_qubits(circ_preset_qubits(), Signature::new_endo(vec![QB_T, QB_T, QB_T]))] +#[case::preset_parameterized(circ_parameterized(), Signature::new(vec![QB_T, ROTATION_TYPE, ROTATION_TYPE, ROTATION_TYPE], vec![QB_T]))] fn circuit_roundtrip(#[case] circ: Circuit, #[case] decoded_sig: Signature) { let ser: SerialCircuit = SerialCircuit::encode(&circ).unwrap(); let deser: Circuit = ser.clone().decode().unwrap();