Skip to content

Commit

Permalink
feat: Extract pytket parameters to input wires (#661)
Browse files Browse the repository at this point in the history
Depends on #660 
Closes #628
  • Loading branch information
aborgna-q authored Oct 15, 2024
1 parent d42842d commit f1d68bc
Show file tree
Hide file tree
Showing 7 changed files with 67 additions and 14 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ delegate = "0.13.1"
derive_more = "1.0.0"
downcast-rs = "1.2.0"
fxhash = "0.2.1"
indexmap = "2.6.0"
lazy_static = "1.5.0"
num-complex = "0.4"
num-rational = "0.4"
Expand Down
1 change: 1 addition & 0 deletions tket2/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ portgraph = { workspace = true, features = ["serde"] }
strum_macros = { workspace = true }
strum = { workspace = true }
fxhash = { workspace = true }
indexmap = { workspace = true }
rmp-serde = { workspace = true, optional = true }
delegate = { workspace = true }
csv = { workspace = true }
Expand Down
2 changes: 2 additions & 0 deletions tket2/src/serialize/pytket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ const METADATA_B_REGISTERS: &str = "TKET1.bit_registers";
const METADATA_B_OUTPUT_REGISTERS: &str = "TKET1.bit_output_registers";
/// A tket1 operation "opgroup" field.
const METADATA_OPGROUP: &str = "TKET1.opgroup";
/// Explicit names for the input parameter wires.
const METADATA_INPUT_PARAMETERS: &str = "TKET1.input_parameters";

/// A serialized representation of a [`Circuit`].
///
Expand Down
32 changes: 22 additions & 10 deletions tket2/src/serialize/pytket/decoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use hugr::types::Signature;
use hugr::{Hugr, Wire};

use derive_more::Display;
use indexmap::IndexMap;
use itertools::{EitherOrBoth, Itertools};
use serde_json::json;
use tket_json_rs::circuit_json;
Expand All @@ -24,14 +25,15 @@ use super::{
METADATA_B_REGISTERS, METADATA_OPGROUP, METADATA_PHASE, METADATA_Q_OUTPUT_REGISTERS,
METADATA_Q_REGISTERS,
};
use crate::extension::rotation::RotationOp;
use crate::extension::rotation::{RotationOp, ROTATION_TYPE};
use crate::extension::{REGISTRY, TKET1_EXTENSION_ID};
use crate::serialize::pytket::METADATA_INPUT_PARAMETERS;
use crate::symbolic_constant_op;

/// The state of an in-progress [`FunctionBuilder`] being built from a [`SerialCircuit`].
///
/// Mostly used to define helper internal methods.
#[derive(Debug, PartialEq)]
#[derive(Debug, Clone)]
pub(super) struct Tk1Decoder {
/// The Hugr being built.
pub hugr: FunctionBuilder<Hugr>,
Expand All @@ -41,6 +43,8 @@ pub(super) struct Tk1Decoder {
ordered_registers: Vec<RegisterHash>,
/// A set of registers that encode qubits.
qubit_registers: HashSet<RegisterHash>,
/// An ordered set of parameters found in operation arguments, and added as inputs.
parameters: IndexMap<String, LoadedParameter>,
}

impl Tk1Decoder {
Expand Down Expand Up @@ -116,6 +120,7 @@ impl Tk1Decoder {
register_wires,
ordered_registers,
qubit_registers,
parameters: IndexMap::new(),
})
}

Expand All @@ -132,6 +137,13 @@ impl Tk1Decoder {
"Some output wires were not associated with a register."
);

// Store the name for the input parameter wires
if !self.parameters.is_empty() {
let params = self.parameters.keys().cloned().collect_vec();
self.hugr
.set_metadata(METADATA_INPUT_PARAMETERS, json!(params));
}

self.hugr
.finish_hugr_with_outputs(outputs, &REGISTRY)
.unwrap()
Expand Down Expand Up @@ -264,6 +276,7 @@ impl Tk1Decoder {
fn load_parameter(&mut self, param: String) -> Wire {
fn process(
hugr: &mut FunctionBuilder<Hugr>,
input_params: &mut IndexMap<String, LoadedParameter>,
parsed: PytketParam,
param: &str,
) -> LoadedParameter {
Expand All @@ -286,18 +299,17 @@ impl Tk1Decoder {
let wire = hugr.add_load_const(value);
return LoadedParameter::float(wire);
}

// TODO: We need to add a `FunctionBuilder::add_input` function on the hugr side.
// Here we just add an opaque sympy box instead.
// https://github.com/CQCL/hugr/issues/1562
// https://github.com/CQCL/tket2/issues/628
process(hugr, PytketParam::Sympy(name), param)
// Look it up in the input parameters to the circuit, and add a new wire if needed.
*input_params.entry(name.to_string()).or_insert_with(|| {
let wire = hugr.add_input(ROTATION_TYPE);
LoadedParameter::rotation(wire)
})
}
PytketParam::Operation { op, args } => {
// We assume all operations take float inputs.
let input_wires = args
.into_iter()
.map(|arg| process(hugr, arg, param).as_float(hugr).wire)
.map(|arg| process(hugr, input_params, arg, param).as_float(hugr).wire)
.collect_vec();
let res = hugr.add_dataflow_op(op, input_wires).unwrap_or_else(|e| {
panic!("Error while decoding pytket operation parameter \"{param}\". {e}",)
Expand All @@ -309,7 +321,7 @@ impl Tk1Decoder {
}

let parsed = parse_pytket_param(&param);
process(&mut self.hugr, parsed, &param)
process(&mut self.hugr, &mut self.parameters, parsed, &param)
.as_rotation(&mut self.hugr)
.wire
}
Expand Down
16 changes: 13 additions & 3 deletions tket2/src/serialize/pytket/encoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ use super::op::Tk1Op;
use super::param::encode::fold_param_op;
use super::{
OpConvertError, TK1ConvertError, METADATA_B_OUTPUT_REGISTERS, METADATA_B_REGISTERS,
METADATA_OPGROUP, METADATA_PHASE, METADATA_Q_OUTPUT_REGISTERS, METADATA_Q_REGISTERS,
METADATA_INPUT_PARAMETERS, METADATA_OPGROUP, METADATA_PHASE, METADATA_Q_OUTPUT_REGISTERS,
METADATA_Q_REGISTERS,
};

/// The state of an in-progress [`SerialCircuit`] being built from a [`Circuit`].
Expand Down Expand Up @@ -558,8 +559,17 @@ impl ParameterTracker {
_ => None,
});

for (i, wire) in angle_input_wires.enumerate() {
tracker.add_parameter(wire, format!("f{i}"));
// The input parameter names may be specified in the metadata.
let fixed_input_names: Vec<String> = circ
.hugr()
.get_metadata(circ.parent(), METADATA_INPUT_PARAMETERS)
.and_then(|params| serde_json::from_value(params.clone()).ok())
.unwrap_or_default();
let extra_names = (fixed_input_names.len()..).map(|i| format!("f{i}"));
let mut param_name = fixed_input_names.into_iter().chain(extra_names);

for wire in angle_input_wires {
tracker.add_parameter(wire, param_name.next().unwrap());
}

tracker
Expand Down
28 changes: 27 additions & 1 deletion tket2/src/serialize/pytket/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use rstest::{fixture, rstest};
use tket_json_rs::circuit_json::{self, SerialCircuit};
use tket_json_rs::optype;

use super::{TKETDecode, METADATA_Q_OUTPUT_REGISTERS};
use super::{TKETDecode, METADATA_INPUT_PARAMETERS, METADATA_Q_OUTPUT_REGISTERS};
use crate::circuit::Circuit;
use crate::extension::rotation::{ConstRotation, RotationOp, ROTATION_TYPE};
use crate::extension::sympy::SympyOpDef;
Expand Down Expand Up @@ -158,6 +158,31 @@ fn circ_preset_qubits() -> Circuit {
hugr.into()
}

/// A simple circuit with some input parameters
#[fixture]
fn circ_parameterized() -> Circuit {
let input_t = vec![QB_T, ROTATION_TYPE, ROTATION_TYPE, ROTATION_TYPE];
let output_t = vec![QB_T];
let mut h = DFGBuilder::new(Signature::new(input_t, output_t)).unwrap();

let [q, r0, r1, r2] = h.input_wires_arr();

let [q] = h.add_dataflow_op(Tk2Op::Rx, [q, r0]).unwrap().outputs_arr();
let [q] = h.add_dataflow_op(Tk2Op::Ry, [q, r1]).unwrap().outputs_arr();
let [q] = h.add_dataflow_op(Tk2Op::Rz, [q, r2]).unwrap().outputs_arr();

let mut hugr = h.finish_hugr_with_outputs([q], &REGISTRY).unwrap();

// Preset names for some of the inputs
hugr.set_metadata(
hugr.root(),
METADATA_INPUT_PARAMETERS,
serde_json::json!(["alpha", "beta"]),
);

hugr.into()
}

/// A simple circuit with ancillae
#[fixture]
fn circ_measure_ancilla() -> Circuit {
Expand Down Expand Up @@ -318,6 +343,7 @@ fn json_file_roundtrip(#[case] circ: impl AsRef<std::path::Path>) {
#[rstest]
#[case::meas_ancilla(circ_measure_ancilla(), Signature::new_endo(vec![QB_T, QB_T, BOOL_T, BOOL_T]))]
#[case::preset_qubits(circ_preset_qubits(), Signature::new_endo(vec![QB_T, QB_T, QB_T]))]
#[case::preset_parameterized(circ_parameterized(), Signature::new(vec![QB_T, ROTATION_TYPE, ROTATION_TYPE, ROTATION_TYPE], vec![QB_T]))]
fn circuit_roundtrip(#[case] circ: Circuit, #[case] decoded_sig: Signature) {
let ser: SerialCircuit = SerialCircuit::encode(&circ).unwrap();
let deser: Circuit = ser.clone().decode().unwrap();
Expand Down

0 comments on commit f1d68bc

Please sign in to comment.