Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Extract pytket parameters to input wires #661

Merged
merged 2 commits into from
Oct 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ delegate = "0.13.1"
derive_more = "1.0.0"
downcast-rs = "1.2.0"
fxhash = "0.2.1"
indexmap = "2.6.0"
lazy_static = "1.5.0"
num-complex = "0.4"
num-rational = "0.4"
Expand Down
1 change: 1 addition & 0 deletions tket2/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ portgraph = { workspace = true, features = ["serde"] }
strum_macros = { workspace = true }
strum = { workspace = true }
fxhash = { workspace = true }
indexmap = { workspace = true }
rmp-serde = { workspace = true, optional = true }
delegate = { workspace = true }
csv = { workspace = true }
Expand Down
2 changes: 2 additions & 0 deletions tket2/src/serialize/pytket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ const METADATA_B_REGISTERS: &str = "TKET1.bit_registers";
const METADATA_B_OUTPUT_REGISTERS: &str = "TKET1.bit_output_registers";
/// A tket1 operation "opgroup" field.
const METADATA_OPGROUP: &str = "TKET1.opgroup";
/// Explicit names for the input parameter wires.
const METADATA_INPUT_PARAMETERS: &str = "TKET1.input_parameters";

/// A serialized representation of a [`Circuit`].
///
Expand Down
32 changes: 22 additions & 10 deletions tket2/src/serialize/pytket/decoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use hugr::types::Signature;
use hugr::{Hugr, Wire};

use derive_more::Display;
use indexmap::IndexMap;
use itertools::{EitherOrBoth, Itertools};
use serde_json::json;
use tket_json_rs::circuit_json;
Expand All @@ -24,14 +25,15 @@ use super::{
METADATA_B_REGISTERS, METADATA_OPGROUP, METADATA_PHASE, METADATA_Q_OUTPUT_REGISTERS,
METADATA_Q_REGISTERS,
};
use crate::extension::rotation::RotationOp;
use crate::extension::rotation::{RotationOp, ROTATION_TYPE};
use crate::extension::{REGISTRY, TKET1_EXTENSION_ID};
use crate::serialize::pytket::METADATA_INPUT_PARAMETERS;
use crate::symbolic_constant_op;

/// The state of an in-progress [`FunctionBuilder`] being built from a [`SerialCircuit`].
///
/// Mostly used to define helper internal methods.
#[derive(Debug, PartialEq)]
#[derive(Debug, Clone)]
pub(super) struct Tk1Decoder {
/// The Hugr being built.
pub hugr: FunctionBuilder<Hugr>,
Expand All @@ -41,6 +43,8 @@ pub(super) struct Tk1Decoder {
ordered_registers: Vec<RegisterHash>,
/// A set of registers that encode qubits.
qubit_registers: HashSet<RegisterHash>,
/// An ordered set of parameters found in operation arguments, and added as inputs.
parameters: IndexMap<String, LoadedParameter>,
}

impl Tk1Decoder {
Expand Down Expand Up @@ -116,6 +120,7 @@ impl Tk1Decoder {
register_wires,
ordered_registers,
qubit_registers,
parameters: IndexMap::new(),
})
}

Expand All @@ -132,6 +137,13 @@ impl Tk1Decoder {
"Some output wires were not associated with a register."
);

// Store the name for the input parameter wires
if !self.parameters.is_empty() {
let params = self.parameters.keys().cloned().collect_vec();
self.hugr
.set_metadata(METADATA_INPUT_PARAMETERS, json!(params));
}

self.hugr
.finish_hugr_with_outputs(outputs, &REGISTRY)
.unwrap()
Expand Down Expand Up @@ -264,6 +276,7 @@ impl Tk1Decoder {
fn load_parameter(&mut self, param: String) -> Wire {
fn process(
hugr: &mut FunctionBuilder<Hugr>,
input_params: &mut IndexMap<String, LoadedParameter>,
parsed: PytketParam,
param: &str,
) -> LoadedParameter {
Expand All @@ -286,18 +299,17 @@ impl Tk1Decoder {
let wire = hugr.add_load_const(value);
return LoadedParameter::float(wire);
}

// TODO: We need to add a `FunctionBuilder::add_input` function on the hugr side.
// Here we just add an opaque sympy box instead.
// https://github.com/CQCL/hugr/issues/1562
// https://github.com/CQCL/tket2/issues/628
process(hugr, PytketParam::Sympy(name), param)
// Look it up in the input parameters to the circuit, and add a new wire if needed.
*input_params.entry(name.to_string()).or_insert_with(|| {
let wire = hugr.add_input(ROTATION_TYPE);
LoadedParameter::rotation(wire)
})
}
PytketParam::Operation { op, args } => {
// We assume all operations take float inputs.
let input_wires = args
.into_iter()
.map(|arg| process(hugr, arg, param).as_float(hugr).wire)
.map(|arg| process(hugr, input_params, arg, param).as_float(hugr).wire)
.collect_vec();
let res = hugr.add_dataflow_op(op, input_wires).unwrap_or_else(|e| {
panic!("Error while decoding pytket operation parameter \"{param}\". {e}",)
Expand All @@ -309,7 +321,7 @@ impl Tk1Decoder {
}

let parsed = parse_pytket_param(&param);
process(&mut self.hugr, parsed, &param)
process(&mut self.hugr, &mut self.parameters, parsed, &param)
.as_rotation(&mut self.hugr)
.wire
}
Expand Down
16 changes: 13 additions & 3 deletions tket2/src/serialize/pytket/encoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ use super::op::Tk1Op;
use super::param::encode::fold_param_op;
use super::{
OpConvertError, TK1ConvertError, METADATA_B_OUTPUT_REGISTERS, METADATA_B_REGISTERS,
METADATA_OPGROUP, METADATA_PHASE, METADATA_Q_OUTPUT_REGISTERS, METADATA_Q_REGISTERS,
METADATA_INPUT_PARAMETERS, METADATA_OPGROUP, METADATA_PHASE, METADATA_Q_OUTPUT_REGISTERS,
METADATA_Q_REGISTERS,
};

/// The state of an in-progress [`SerialCircuit`] being built from a [`Circuit`].
Expand Down Expand Up @@ -558,8 +559,17 @@ impl ParameterTracker {
_ => None,
});

for (i, wire) in angle_input_wires.enumerate() {
tracker.add_parameter(wire, format!("f{i}"));
// The input parameter names may be specified in the metadata.
let fixed_input_names: Vec<String> = circ
.hugr()
.get_metadata(circ.parent(), METADATA_INPUT_PARAMETERS)
.and_then(|params| serde_json::from_value(params.clone()).ok())
.unwrap_or_default();
let extra_names = (fixed_input_names.len()..).map(|i| format!("f{i}"));
let mut param_name = fixed_input_names.into_iter().chain(extra_names);

for wire in angle_input_wires {
tracker.add_parameter(wire, param_name.next().unwrap());
}

tracker
Expand Down
28 changes: 27 additions & 1 deletion tket2/src/serialize/pytket/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use rstest::{fixture, rstest};
use tket_json_rs::circuit_json::{self, SerialCircuit};
use tket_json_rs::optype;

use super::{TKETDecode, METADATA_Q_OUTPUT_REGISTERS};
use super::{TKETDecode, METADATA_INPUT_PARAMETERS, METADATA_Q_OUTPUT_REGISTERS};
use crate::circuit::Circuit;
use crate::extension::rotation::{ConstRotation, RotationOp, ROTATION_TYPE};
use crate::extension::sympy::SympyOpDef;
Expand Down Expand Up @@ -158,6 +158,31 @@ fn circ_preset_qubits() -> Circuit {
hugr.into()
}

/// A simple circuit with some input parameters
#[fixture]
fn circ_parameterized() -> Circuit {
let input_t = vec![QB_T, ROTATION_TYPE, ROTATION_TYPE, ROTATION_TYPE];
let output_t = vec![QB_T];
let mut h = DFGBuilder::new(Signature::new(input_t, output_t)).unwrap();

let [q, r0, r1, r2] = h.input_wires_arr();

let [q] = h.add_dataflow_op(Tk2Op::Rx, [q, r0]).unwrap().outputs_arr();
let [q] = h.add_dataflow_op(Tk2Op::Ry, [q, r1]).unwrap().outputs_arr();
let [q] = h.add_dataflow_op(Tk2Op::Rz, [q, r2]).unwrap().outputs_arr();

let mut hugr = h.finish_hugr_with_outputs([q], &REGISTRY).unwrap();

// Preset names for some of the inputs
hugr.set_metadata(
hugr.root(),
METADATA_INPUT_PARAMETERS,
serde_json::json!(["alpha", "beta"]),
);

hugr.into()
}

/// A simple circuit with ancillae
#[fixture]
fn circ_measure_ancilla() -> Circuit {
Expand Down Expand Up @@ -318,6 +343,7 @@ fn json_file_roundtrip(#[case] circ: impl AsRef<std::path::Path>) {
#[rstest]
#[case::meas_ancilla(circ_measure_ancilla(), Signature::new_endo(vec![QB_T, QB_T, BOOL_T, BOOL_T]))]
#[case::preset_qubits(circ_preset_qubits(), Signature::new_endo(vec![QB_T, QB_T, QB_T]))]
#[case::preset_parameterized(circ_parameterized(), Signature::new(vec![QB_T, ROTATION_TYPE, ROTATION_TYPE, ROTATION_TYPE], vec![QB_T]))]
fn circuit_roundtrip(#[case] circ: Circuit, #[case] decoded_sig: Signature) {
let ser: SerialCircuit = SerialCircuit::encode(&circ).unwrap();
let deser: Circuit = ser.clone().decode().unwrap();
Expand Down
Loading