Skip to content

Commit

Permalink
Support run_sweep in GaugeTransformer. Pioneered in the CZ Gauge. (#6852
Browse files Browse the repository at this point in the history
)

Support as_sweep() in GaugeTransformer.

Enables output parameterized circuit with sweepable parameters for CZ Gauge.
  • Loading branch information
babacry authored Dec 21, 2024
1 parent 29d99d3 commit 686766f
Show file tree
Hide file tree
Showing 6 changed files with 432 additions and 19 deletions.
144 changes: 128 additions & 16 deletions cirq-core/cirq/transformers/gauge_compiling/cz_gauge.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,22 +24,134 @@

CZGaugeSelector = GaugeSelector(
gauges=[
ConstantGauge(two_qubit_gate=CZ, pre_q0=ops.I, pre_q1=ops.I, post_q0=ops.I, post_q1=ops.I),
ConstantGauge(two_qubit_gate=CZ, pre_q0=ops.I, pre_q1=ops.X, post_q0=ops.Z, post_q1=ops.X),
ConstantGauge(two_qubit_gate=CZ, pre_q0=ops.I, pre_q1=ops.Y, post_q0=ops.Z, post_q1=ops.Y),
ConstantGauge(two_qubit_gate=CZ, pre_q0=ops.I, pre_q1=ops.Z, post_q0=ops.I, post_q1=ops.Z),
ConstantGauge(two_qubit_gate=CZ, pre_q0=ops.X, pre_q1=ops.I, post_q0=ops.X, post_q1=ops.Z),
ConstantGauge(two_qubit_gate=CZ, pre_q0=ops.X, pre_q1=ops.X, post_q0=ops.Y, post_q1=ops.Y),
ConstantGauge(two_qubit_gate=CZ, pre_q0=ops.X, pre_q1=ops.Y, post_q0=ops.Y, post_q1=ops.X),
ConstantGauge(two_qubit_gate=CZ, pre_q0=ops.X, pre_q1=ops.Z, post_q0=ops.X, post_q1=ops.I),
ConstantGauge(two_qubit_gate=CZ, pre_q0=ops.Y, pre_q1=ops.I, post_q0=ops.Y, post_q1=ops.Z),
ConstantGauge(two_qubit_gate=CZ, pre_q0=ops.Y, pre_q1=ops.X, post_q0=ops.X, post_q1=ops.Y),
ConstantGauge(two_qubit_gate=CZ, pre_q0=ops.Y, pre_q1=ops.Y, post_q0=ops.X, post_q1=ops.X),
ConstantGauge(two_qubit_gate=CZ, pre_q0=ops.Y, pre_q1=ops.Z, post_q0=ops.Y, post_q1=ops.I),
ConstantGauge(two_qubit_gate=CZ, pre_q0=ops.Z, pre_q1=ops.I, post_q0=ops.Z, post_q1=ops.I),
ConstantGauge(two_qubit_gate=CZ, pre_q0=ops.Z, pre_q1=ops.X, post_q0=ops.I, post_q1=ops.X),
ConstantGauge(two_qubit_gate=CZ, pre_q0=ops.Z, pre_q1=ops.Y, post_q0=ops.I, post_q1=ops.Y),
ConstantGauge(two_qubit_gate=CZ, pre_q0=ops.Z, pre_q1=ops.Z, post_q0=ops.Z, post_q1=ops.Z),
ConstantGauge(
two_qubit_gate=CZ,
pre_q0=ops.I,
pre_q1=ops.I,
post_q0=ops.I,
post_q1=ops.I,
support_sweep=True,
),
ConstantGauge(
two_qubit_gate=CZ,
pre_q0=ops.I,
pre_q1=ops.X,
post_q0=ops.Z,
post_q1=ops.X,
support_sweep=True,
),
ConstantGauge(
two_qubit_gate=CZ,
pre_q0=ops.I,
pre_q1=ops.Y,
post_q0=ops.Z,
post_q1=ops.Y,
support_sweep=True,
),
ConstantGauge(
two_qubit_gate=CZ,
pre_q0=ops.I,
pre_q1=ops.Z,
post_q0=ops.I,
post_q1=ops.Z,
support_sweep=True,
),
ConstantGauge(
two_qubit_gate=CZ,
pre_q0=ops.X,
pre_q1=ops.I,
post_q0=ops.X,
post_q1=ops.Z,
support_sweep=True,
),
ConstantGauge(
two_qubit_gate=CZ,
pre_q0=ops.X,
pre_q1=ops.X,
post_q0=ops.Y,
post_q1=ops.Y,
support_sweep=True,
),
ConstantGauge(
two_qubit_gate=CZ,
pre_q0=ops.X,
pre_q1=ops.Y,
post_q0=ops.Y,
post_q1=ops.X,
support_sweep=True,
),
ConstantGauge(
two_qubit_gate=CZ,
pre_q0=ops.X,
pre_q1=ops.Z,
post_q0=ops.X,
post_q1=ops.I,
support_sweep=True,
),
ConstantGauge(
two_qubit_gate=CZ,
pre_q0=ops.Y,
pre_q1=ops.I,
post_q0=ops.Y,
post_q1=ops.Z,
support_sweep=True,
),
ConstantGauge(
two_qubit_gate=CZ,
pre_q0=ops.Y,
pre_q1=ops.X,
post_q0=ops.X,
post_q1=ops.Y,
support_sweep=True,
),
ConstantGauge(
two_qubit_gate=CZ,
pre_q0=ops.Y,
pre_q1=ops.Y,
post_q0=ops.X,
post_q1=ops.X,
support_sweep=True,
),
ConstantGauge(
two_qubit_gate=CZ,
pre_q0=ops.Y,
pre_q1=ops.Z,
post_q0=ops.Y,
post_q1=ops.I,
support_sweep=True,
),
ConstantGauge(
two_qubit_gate=CZ,
pre_q0=ops.Z,
pre_q1=ops.I,
post_q0=ops.Z,
post_q1=ops.I,
support_sweep=True,
),
ConstantGauge(
two_qubit_gate=CZ,
pre_q0=ops.Z,
pre_q1=ops.X,
post_q0=ops.I,
post_q1=ops.X,
support_sweep=True,
),
ConstantGauge(
two_qubit_gate=CZ,
pre_q0=ops.Z,
pre_q1=ops.Y,
post_q0=ops.I,
post_q1=ops.Y,
support_sweep=True,
),
ConstantGauge(
two_qubit_gate=CZ,
pre_q0=ops.Z,
pre_q1=ops.Z,
post_q0=ops.Z,
post_q1=ops.Z,
support_sweep=True,
),
]
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,4 @@
class TestCZGauge(GaugeTester):
two_qubit_gate = cirq.CZ
gauge_transformer = CZGaugeTransformer
sweep_must_pass = True
153 changes: 151 additions & 2 deletions cirq-core/cirq/transformers/gauge_compiling/gauge_compiling.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,24 @@

"""Creates the abstraction for gauge compiling as a cirq transformer."""

from typing import Callable, Tuple, Optional, Sequence, Union, List
from typing import Callable, Dict, Tuple, Optional, Sequence, Union, List
from itertools import count
from dataclasses import dataclass
import abc
import itertools
import functools
import sympy

from dataclasses import dataclass
from attrs import frozen, field
import numpy as np

from cirq.transformers import transformer_api
from cirq import ops, circuits
from cirq.study import sweepable
from cirq.protocols import unitary_protocol
from cirq.protocols.has_unitary_protocol import has_unitary
from cirq.study.sweeps import Points, Zip
from cirq.transformers.analytical_decompositions import single_qubit_decompositions


class Gauge(abc.ABC):
Expand Down Expand Up @@ -72,6 +79,7 @@ class ConstantGauge(Gauge):
default=(), converter=lambda g: (g,) if isinstance(g, ops.Gate) else tuple(g)
)
swap_qubits: bool = False
support_sweep: bool = False

def sample(self, gate: ops.Gate, prng: np.random.Generator) -> "ConstantGauge":
return self
Expand Down Expand Up @@ -201,6 +209,100 @@ def __call__(
new_moments.extend(_build_moments(right))
return circuits.Circuit.from_moments(*new_moments)

def as_sweep(
self,
circuit: circuits.AbstractCircuit,
*,
N: int,
context: Optional[transformer_api.TransformerContext] = None,
prng: Optional[np.random.Generator] = None,
) -> Tuple[circuits.AbstractCircuit, sweepable.Sweepable]:
"""Generates a parameterized circuit with *N* sets of sweepable parameters.
Args:
circuit: The input circuit to be processed by gauge compiling.
N: The number of parameter sets to generate.
context: A `cirq.TransformerContext` storing common configurable options for
the transformers.
prng: A pseudo-random number generator to select a gauge within a gauge cluster.
"""

rng = np.random.default_rng() if prng is None else prng
if context is None:
context = transformer_api.TransformerContext(deep=False)
if context.deep:
raise ValueError('GaugeTransformer cannot be used with deep=True')
new_moments: List[List[ops.Operation]] = [] # Store parameterized circuits.
values_by_params: Dict[str, List[float]] = {} # map from symbol name to N values.
symbol_count = count()
# Map from "((pre|post),$qid,$moment_id)" to gate parameters.
# E.g. {(post,q1,2): {"x_exponent": "x1", "z_exponent": "z1", "axis_phase": "a1"}}
symbols_by_loc: Dict[Tuple[str, ops.Qid, int], Dict[str, sympy.Symbol]] = {}

def single_qubit_next_symbol() -> Dict[str, sympy.Symbol]:
sid = next(symbol_count)
return _parameterize(1, sid)

# Build parameterized circuit.
for moment_id, moment in enumerate(circuit):
center_moment: List[ops.Operation] = []
left_moment: List[ops.Operation] = []
right_moment: List[ops.Operation] = []
for op in moment:
if isinstance(op, ops.TaggedOperation) and set(op.tags).intersection(
context.tags_to_ignore
):
center_moment.append(op)
continue
if op.gate is not None and op in self.target:
# Build symbols for the gauge, for a 2-qubit gauge, symbols will be built for
# pre/post q0/q1 and the new 2-qubit gate if the 2-qubit gate is updated in
# the gauge compiling.
center_moment.append(op)
for prefix, q in itertools.product(["pre", "post"], op.qubits):
xza_by_symbols = single_qubit_next_symbol() # xza in phased xz gate.
loc = (prefix, q, moment_id)
symbols_by_loc[loc] = xza_by_symbols
new_op = ops.PhasedXZGate(**xza_by_symbols).on(q)
for symbol in xza_by_symbols.values():
values_by_params.update({str(symbol): []})
if prefix == "pre":
left_moment.append(new_op)
else:
right_moment.append(new_op)
else:
center_moment.append(op)
new_moments.extend(
[moment for moment in [left_moment, center_moment, right_moment] if moment]
)

# Assign values for parameters via randomly chosen GaugeSelector.
for _ in range(N):
for moment_id, moment in enumerate(circuit):
for op in moment:
if isinstance(op, ops.TaggedOperation) and set(op.tags).intersection(
context.tags_to_ignore
):
continue
if op.gate is not None and len(op.qubits) == 2 and op in self.target:
gauge = self.gauge_selector(rng).sample(op.gate, rng)
if not gauge.support_sweep:
raise NotImplementedError(
f"as_sweep isn't supported for {gauge.two_qubit_gate} gauge"
)
# Get the params of pre/post q0/q1 gates.
for pre_or_post, idx in itertools.product(["pre", "post"], [0, 1]):
symbols = symbols_by_loc[(pre_or_post, op.qubits[idx], moment_id)]
gates = getattr(gauge, f"{pre_or_post}_q{idx}")
phxz_params = _gate_sequence_to_phxz_params(gates, symbols)
for key, value in phxz_params.items():
values_by_params[key].append(value)
sweeps: List[Points] = [
Points(key=key, points=values) for key, values in values_by_params.items()
]

return circuits.Circuit.from_moments(*new_moments), Zip(*sweeps)


def _build_moments(operation_by_qubits: List[List[ops.Operation]]) -> List[List[ops.Operation]]:
"""Builds moments from a list of operations grouped by qubits.
Expand All @@ -212,3 +314,50 @@ def _build_moments(operation_by_qubits: List[List[ops.Operation]]) -> List[List[
for moment in itertools.zip_longest(*operation_by_qubits):
moments.append([op for op in moment if op is not None])
return moments


def _parameterize(num_qubits: int, symbol_id: int) -> Dict[str, sympy.Symbol]:
"""Returns symbolized parameters for the gate."""

if num_qubits == 1: # Convert single qubit gate to parameterized PhasedXZGate.
phased_xz_params = {
"x_exponent": sympy.Symbol(f"x{symbol_id}"),
"z_exponent": sympy.Symbol(f"z{symbol_id}"),
"axis_phase_exponent": sympy.Symbol(f"a{symbol_id}"),
}
return phased_xz_params
raise NotImplementedError("parameterization for non single qubit gates is not supported yet")


def _gate_sequence_to_phxz_params(
gates: Tuple[ops.Gate, ...], xza_by_symbols: Dict[str, sympy.Symbol]
) -> Dict[str, float]:
for gate in gates:
if not has_unitary(gate) or gate.num_qubits() != 1:
raise ValueError(
"Invalid gate sequence to be converted to PhasedXZGate."
f"Found incompatiable gate {gate} in sequence."
)
phxz = (
single_qubit_decompositions.single_qubit_matrix_to_phxz(
functools.reduce(
np.matmul, [unitary_protocol.unitary(gate) for gate in reversed(gates)]
)
)
or ops.I
)
if phxz is ops.I: # Identity gate
return {
str(xza_by_symbols["x_exponent"]): 0.0,
str(xza_by_symbols["z_exponent"]): 0.0,
str(xza_by_symbols["axis_phase_exponent"]): 0.0,
}
# Check the gate type, needs to be a PhasedXZ gate.
if not isinstance(phxz, ops.PhasedXZGate):
raise ValueError("Failed to convert the gate sequence to a PhasedXZ gate.")
if phxz is not None:
return {
str(xza_by_symbols["x_exponent"]): phxz.x_exponent,
str(xza_by_symbols["z_exponent"]): phxz.z_exponent,
str(xza_by_symbols["axis_phase_exponent"]): phxz.axis_phase_exponent,
}
Loading

0 comments on commit 686766f

Please sign in to comment.