Skip to content

Commit

Permalink
feat: Add tket2.rotation.from_halfturns_unchecked op (#640)
Browse files Browse the repository at this point in the history
This is an operation that we are able to process while encoding pytket
circuits.
We expect guppy to emit this instead of adding explicit panics to unwrap
the fallible result.

The pytket encoder should now support float wires.

Closes #630
I'll open a followup issue on guppy
  • Loading branch information
aborgna-q authored Oct 8, 2024
1 parent 2cf79fd commit 86ffe64
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 26 deletions.
30 changes: 30 additions & 0 deletions tket2-py/tket2/extensions/_json_defs/tket2/rotation.json
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,36 @@
},
"binary": false
},
"from_halfturns_unchecked": {
"extension": "tket2.rotation",
"name": "from_halfturns_unchecked",
"description": "Construct rotation from number of half-turns (would be multiples of π in radians). Panics if the float is non-finite.",
"signature": {
"params": [],
"body": {
"input": [
{
"t": "Opaque",
"extension": "arithmetic.float.types",
"id": "float64",
"args": [],
"bound": "C"
}
],
"output": [
{
"t": "Opaque",
"extension": "tket2.rotation",
"id": "rotation",
"args": [],
"bound": "C"
}
],
"extension_reqs": []
}
},
"binary": false
},
"radd": {
"extension": "tket2.rotation",
"name": "radd",
Expand Down
28 changes: 23 additions & 5 deletions tket2/src/extension/rotation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,12 @@ impl CustomConst for ConstRotation {
#[non_exhaustive]
/// Rotation operations
pub enum RotationOp {
/// Construct rotation from number of half-turns (would be multiples of π in radians).
/// Construct rotation from a floating point number of half-turns (would be multiples of π in radians).
/// Returns an Option, failing when the input is NaN or infinite.
from_halfturns,
/// Construct rotation from a floating point number of half-turns (would be multiples of π in radians).
/// Panics if the input is NaN or infinite.
from_halfturns_unchecked,
/// Convert rotation to number of half-turns (would be multiples of π in radians).
to_halfturns,
/// Add two angles together (experimental, may be removed, use float addition
Expand All @@ -135,6 +139,9 @@ impl MakeOpDef for RotationOp {
type_row![FLOAT64_TYPE],
Type::from(option_type(type_row![ROTATION_TYPE])),
),
RotationOp::from_halfturns_unchecked => {
Signature::new(type_row![FLOAT64_TYPE], type_row![ROTATION_TYPE])
}
RotationOp::to_halfturns => {
Signature::new(type_row![ROTATION_TYPE], type_row![FLOAT64_TYPE])
}
Expand All @@ -151,6 +158,9 @@ impl MakeOpDef for RotationOp {
RotationOp::from_halfturns => {
"Construct rotation from number of half-turns (would be multiples of π in radians). Returns None if the float is non-finite."
}
RotationOp::from_halfturns_unchecked => {
"Construct rotation from number of half-turns (would be multiples of π in radians). Panics if the float is non-finite."
}
RotationOp::to_halfturns => {
"Convert rotation to number of half-turns (would be multiples of π in radians)."
}
Expand Down Expand Up @@ -193,14 +203,21 @@ pub(super) fn add_to_extension(extension: &mut Extension) {
/// An extension trait for [Dataflow] providing methods to add
/// "tket2.rotation" operations.
pub trait RotationOpBuilder: Dataflow {
/// Add a "tket2.rotation.fromturns" op.
/// Add a "tket2.rotation.from_halfturns" op.
fn add_from_halfturns(&mut self, turns: Wire) -> Result<Wire, BuildError> {
Ok(self
.add_dataflow_op(RotationOp::from_halfturns, [turns])?
.out_wire(0))
}

/// Add a "tket2.rotation.toturns" op.
/// Add a "tket2.rotation.from_halfturns_unchecked" op.
fn add_from_halfturns_unchecked(&mut self, turns: Wire) -> Result<Wire, BuildError> {
Ok(self
.add_dataflow_op(RotationOp::from_halfturns_unchecked, [turns])?
.out_wire(0))
}

/// Add a "tket2.rotation.to_halfturns" op.
fn add_to_halfturns(&mut self, rotation: Wire) -> Result<Wire, BuildError> {
Ok(self
.add_dataflow_op(RotationOp::to_halfturns, [rotation])?
Expand Down Expand Up @@ -262,15 +279,16 @@ mod test {
fn test_builder() {
let mut builder = DFGBuilder::new(Signature::new(
ROTATION_TYPE,
Type::from(option_type(ROTATION_TYPE)),
vec![Type::from(option_type(ROTATION_TYPE)), ROTATION_TYPE],
))
.unwrap();

let [rotation] = builder.input_wires_arr();
let turns = builder.add_to_halfturns(rotation).unwrap();
let mb_rotation = builder.add_from_halfturns(turns).unwrap();
let unwrapped_rotation = builder.add_from_halfturns_unchecked(turns).unwrap();
let _hugr = builder
.finish_hugr_with_outputs([mb_rotation], &REGISTRY)
.finish_hugr_with_outputs([mb_rotation, unwrapped_rotation], &REGISTRY)
.unwrap();
}

Expand Down
6 changes: 3 additions & 3 deletions tket2/src/serialize/pytket/encoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -579,8 +579,6 @@ impl ParameterTracker {
let input_count = if let Some(signature) = optype.dataflow_signature() {
// Only consider commands where all inputs and some outputs are
// parameters that we can track.
//
// TODO: We should track Option<T> parameters too, `RotationOp::from_halfturns` returns options.
const TRACKED_PARAMS: [Type; 2] = [ROTATION_TYPE, FLOAT64_TYPE];
let all_inputs = signature
.input()
Expand Down Expand Up @@ -691,7 +689,9 @@ impl ParameterTracker {
// Note that the tracked parameter strings are always written in half-turns,
// so the conversion here is a no-op.
RotationOp::to_halfturns => inputs[0].clone(),
RotationOp::from_halfturns => inputs[0].clone(),
RotationOp::from_halfturns_unchecked => inputs[0].clone(),
// The checked conversion returns an option, which we do not support.
RotationOp::from_halfturns => return None,
};
Some(s)
}
Expand Down
65 changes: 47 additions & 18 deletions tket2/src/serialize/pytket/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use hugr::builder::{DFGBuilder, Dataflow, DataflowHugr};
use hugr::extension::prelude::{BOOL_T, QB_T};

use hugr::hugr::hugrmut::HugrMut;
use hugr::std_extensions::arithmetic::float_ops::FloatOps;
use hugr::types::Signature;
use hugr::HugrView;
use rstest::{fixture, rstest};
Expand Down Expand Up @@ -188,7 +189,7 @@ fn circ_measure_ancilla() -> Circuit {
}

#[fixture]
fn circ_add_angles_symbolic() -> Circuit {
fn circ_add_angles_symbolic() -> (Circuit, String) {
let input_t = vec![QB_T, ROTATION_TYPE, ROTATION_TYPE];
let output_t = vec![QB_T];
let mut h = DFGBuilder::new(Signature::new(input_t, output_t)).unwrap();
Expand All @@ -203,11 +204,12 @@ fn circ_add_angles_symbolic() -> Circuit {
.unwrap()
.outputs_arr();

h.finish_hugr_with_outputs([qb], &REGISTRY).unwrap().into()
let circ = h.finish_hugr_with_outputs([qb], &REGISTRY).unwrap().into();
(circ, "(f0 + f1)".to_string())
}

#[fixture]
fn circ_add_angles_constants() -> Circuit {
fn circ_add_angles_constants() -> (Circuit, String) {
let qb_row = vec![QB_T];
let mut h = DFGBuilder::new(Signature::new(qb_row.clone(), qb_row)).unwrap();

Expand All @@ -224,35 +226,60 @@ fn circ_add_angles_constants() -> Circuit {
.add_dataflow_op(Tk2Op::Rx, [qb, point5])
.unwrap()
.outputs();
h.finish_hugr_with_outputs(qbs, &REGISTRY).unwrap().into()
let circ = h.finish_hugr_with_outputs(qbs, &REGISTRY).unwrap().into();
(circ, "(0.2 + 0.3)".to_string())
}

#[fixture]
/// An Rx operation using some complex ops to compute its angle `cos(pi) + 1`.
fn circ_complex_angle_computation() -> Circuit {
let qb_row = vec![QB_T];
let mut h = DFGBuilder::new(Signature::new(qb_row.clone(), qb_row)).unwrap();
/// An Rx operation using some complex ops to compute its angle.
fn circ_complex_angle_computation() -> (Circuit, String) {
let input_t = vec![QB_T, ROTATION_TYPE, ROTATION_TYPE];
let output_t = vec![QB_T];
let mut h = DFGBuilder::new(Signature::new(input_t, output_t)).unwrap();

let qb = h.input_wires().next().unwrap();
let [qb, r0, r1] = h.input_wires_arr();

// Loading rotations and sympy expressions
let point2 = h.add_load_value(ConstRotation::new(0.2).unwrap());
let sympy = h
.add_dataflow_op(SympyOpDef.with_expr("cos(pi)".to_string()), [])
.unwrap()
.out_wire(0);
let final_rot = h
let added_rot = h
.add_dataflow_op(RotationOp::radd, [sympy, point2])
.unwrap()
.out_wire(0);

// TODO: Mix in some float ops. This requires unwrapping the result of `RotationOp::from_halfturns`.
// Float operations and conversions
let f0 = h
.add_dataflow_op(RotationOp::to_halfturns, [r0])
.unwrap()
.out_wire(0);
let f1 = h
.add_dataflow_op(RotationOp::to_halfturns, [r1])
.unwrap()
.out_wire(0);
let fpow = h
.add_dataflow_op(FloatOps::fpow, [f0, f1])
.unwrap()
.out_wire(0);
let rpow = h
.add_dataflow_op(RotationOp::from_halfturns_unchecked, [fpow])
.unwrap()
.out_wire(0);

let final_rot = h
.add_dataflow_op(RotationOp::radd, [rpow, added_rot])
.unwrap()
.out_wire(0);

let qbs = h
.add_dataflow_op(Tk2Op::Rx, [qb, final_rot])
.unwrap()
.outputs();

h.finish_hugr_with_outputs(qbs, &REGISTRY).unwrap().into()
let circ = h.finish_hugr_with_outputs(qbs, &REGISTRY).unwrap().into();
(circ, "((f0 ** f1) + (cos(pi) + 0.2))".to_string())
}

#[rstest]
Expand Down Expand Up @@ -318,14 +345,16 @@ fn circuit_roundtrip(#[case] circ: Circuit, #[case] decoded_sig: Signature) {
/// converted back to circuit inputs. This would require parsing symbolic
/// expressions.
#[rstest]
#[case::symbolic(circ_add_angles_symbolic(), "(f0 + f1)")]
#[case::constants(circ_add_angles_constants(), "(0.2 + 0.3)")]
#[case::complex(circ_complex_angle_computation(), "(cos(pi) + 0.2)")]
fn test_add_angle_serialise(#[case] circ_add_angles: Circuit, #[case] param_str: &str) {
let ser: SerialCircuit = SerialCircuit::encode(&circ_add_angles).unwrap();
#[case::symbolic(circ_add_angles_symbolic())]
#[case::constants(circ_add_angles_constants())]
#[case::complex(circ_complex_angle_computation())]
fn test_add_angle_serialise(#[case] circ_add_angles: (Circuit, String)) {
let (circ, expected) = circ_add_angles;

let ser: SerialCircuit = SerialCircuit::encode(&circ).unwrap();
assert_eq!(ser.commands.len(), 1);
assert_eq!(ser.commands[0].op.op_type, optype::OpType::Rx);
assert_eq!(ser.commands[0].op.params, Some(vec![param_str.into()]));
assert_eq!(ser.commands[0].op.params, Some(vec![expected]));

let deser: Circuit = ser.clone().decode().unwrap();
let reser = SerialCircuit::encode(&deser).unwrap();
Expand Down

0 comments on commit 86ffe64

Please sign in to comment.