diff --git a/cirq-core/cirq/transformers/gauge_compiling/cz_gauge.py b/cirq-core/cirq/transformers/gauge_compiling/cz_gauge.py index 226cb1af2cb..02012f1a82d 100644 --- a/cirq-core/cirq/transformers/gauge_compiling/cz_gauge.py +++ b/cirq-core/cirq/transformers/gauge_compiling/cz_gauge.py @@ -24,22 +24,134 @@ CZGaugeSelector = GaugeSelector( gauges=[ - ConstantGauge(two_qubit_gate=CZ, pre_q0=ops.I, pre_q1=ops.I, post_q0=ops.I, post_q1=ops.I), - ConstantGauge(two_qubit_gate=CZ, pre_q0=ops.I, pre_q1=ops.X, post_q0=ops.Z, post_q1=ops.X), - ConstantGauge(two_qubit_gate=CZ, pre_q0=ops.I, pre_q1=ops.Y, post_q0=ops.Z, post_q1=ops.Y), - ConstantGauge(two_qubit_gate=CZ, pre_q0=ops.I, pre_q1=ops.Z, post_q0=ops.I, post_q1=ops.Z), - ConstantGauge(two_qubit_gate=CZ, pre_q0=ops.X, pre_q1=ops.I, post_q0=ops.X, post_q1=ops.Z), - ConstantGauge(two_qubit_gate=CZ, pre_q0=ops.X, pre_q1=ops.X, post_q0=ops.Y, post_q1=ops.Y), - ConstantGauge(two_qubit_gate=CZ, pre_q0=ops.X, pre_q1=ops.Y, post_q0=ops.Y, post_q1=ops.X), - ConstantGauge(two_qubit_gate=CZ, pre_q0=ops.X, pre_q1=ops.Z, post_q0=ops.X, post_q1=ops.I), - ConstantGauge(two_qubit_gate=CZ, pre_q0=ops.Y, pre_q1=ops.I, post_q0=ops.Y, post_q1=ops.Z), - ConstantGauge(two_qubit_gate=CZ, pre_q0=ops.Y, pre_q1=ops.X, post_q0=ops.X, post_q1=ops.Y), - ConstantGauge(two_qubit_gate=CZ, pre_q0=ops.Y, pre_q1=ops.Y, post_q0=ops.X, post_q1=ops.X), - ConstantGauge(two_qubit_gate=CZ, pre_q0=ops.Y, pre_q1=ops.Z, post_q0=ops.Y, post_q1=ops.I), - ConstantGauge(two_qubit_gate=CZ, pre_q0=ops.Z, pre_q1=ops.I, post_q0=ops.Z, post_q1=ops.I), - ConstantGauge(two_qubit_gate=CZ, pre_q0=ops.Z, pre_q1=ops.X, post_q0=ops.I, post_q1=ops.X), - ConstantGauge(two_qubit_gate=CZ, pre_q0=ops.Z, pre_q1=ops.Y, post_q0=ops.I, post_q1=ops.Y), - ConstantGauge(two_qubit_gate=CZ, pre_q0=ops.Z, pre_q1=ops.Z, post_q0=ops.Z, post_q1=ops.Z), + ConstantGauge( + two_qubit_gate=CZ, + pre_q0=ops.I, + pre_q1=ops.I, + post_q0=ops.I, + post_q1=ops.I, + support_sweep=True, + ), + ConstantGauge( + two_qubit_gate=CZ, + pre_q0=ops.I, + pre_q1=ops.X, + post_q0=ops.Z, + post_q1=ops.X, + support_sweep=True, + ), + ConstantGauge( + two_qubit_gate=CZ, + pre_q0=ops.I, + pre_q1=ops.Y, + post_q0=ops.Z, + post_q1=ops.Y, + support_sweep=True, + ), + ConstantGauge( + two_qubit_gate=CZ, + pre_q0=ops.I, + pre_q1=ops.Z, + post_q0=ops.I, + post_q1=ops.Z, + support_sweep=True, + ), + ConstantGauge( + two_qubit_gate=CZ, + pre_q0=ops.X, + pre_q1=ops.I, + post_q0=ops.X, + post_q1=ops.Z, + support_sweep=True, + ), + ConstantGauge( + two_qubit_gate=CZ, + pre_q0=ops.X, + pre_q1=ops.X, + post_q0=ops.Y, + post_q1=ops.Y, + support_sweep=True, + ), + ConstantGauge( + two_qubit_gate=CZ, + pre_q0=ops.X, + pre_q1=ops.Y, + post_q0=ops.Y, + post_q1=ops.X, + support_sweep=True, + ), + ConstantGauge( + two_qubit_gate=CZ, + pre_q0=ops.X, + pre_q1=ops.Z, + post_q0=ops.X, + post_q1=ops.I, + support_sweep=True, + ), + ConstantGauge( + two_qubit_gate=CZ, + pre_q0=ops.Y, + pre_q1=ops.I, + post_q0=ops.Y, + post_q1=ops.Z, + support_sweep=True, + ), + ConstantGauge( + two_qubit_gate=CZ, + pre_q0=ops.Y, + pre_q1=ops.X, + post_q0=ops.X, + post_q1=ops.Y, + support_sweep=True, + ), + ConstantGauge( + two_qubit_gate=CZ, + pre_q0=ops.Y, + pre_q1=ops.Y, + post_q0=ops.X, + post_q1=ops.X, + support_sweep=True, + ), + ConstantGauge( + two_qubit_gate=CZ, + pre_q0=ops.Y, + pre_q1=ops.Z, + post_q0=ops.Y, + post_q1=ops.I, + support_sweep=True, + ), + ConstantGauge( + two_qubit_gate=CZ, + pre_q0=ops.Z, + pre_q1=ops.I, + post_q0=ops.Z, + post_q1=ops.I, + support_sweep=True, + ), + ConstantGauge( + two_qubit_gate=CZ, + pre_q0=ops.Z, + pre_q1=ops.X, + post_q0=ops.I, + post_q1=ops.X, + support_sweep=True, + ), + ConstantGauge( + two_qubit_gate=CZ, + pre_q0=ops.Z, + pre_q1=ops.Y, + post_q0=ops.I, + post_q1=ops.Y, + support_sweep=True, + ), + ConstantGauge( + two_qubit_gate=CZ, + pre_q0=ops.Z, + pre_q1=ops.Z, + post_q0=ops.Z, + post_q1=ops.Z, + support_sweep=True, + ), ] ) diff --git a/cirq-core/cirq/transformers/gauge_compiling/cz_gauge_test.py b/cirq-core/cirq/transformers/gauge_compiling/cz_gauge_test.py index 5f4869f30b2..db4eb236c99 100644 --- a/cirq-core/cirq/transformers/gauge_compiling/cz_gauge_test.py +++ b/cirq-core/cirq/transformers/gauge_compiling/cz_gauge_test.py @@ -21,3 +21,4 @@ class TestCZGauge(GaugeTester): two_qubit_gate = cirq.CZ gauge_transformer = CZGaugeTransformer + sweep_must_pass = True diff --git a/cirq-core/cirq/transformers/gauge_compiling/gauge_compiling.py b/cirq-core/cirq/transformers/gauge_compiling/gauge_compiling.py index 17c3c620a08..2e0d5290f51 100644 --- a/cirq-core/cirq/transformers/gauge_compiling/gauge_compiling.py +++ b/cirq-core/cirq/transformers/gauge_compiling/gauge_compiling.py @@ -14,17 +14,24 @@ """Creates the abstraction for gauge compiling as a cirq transformer.""" -from typing import Callable, Tuple, Optional, Sequence, Union, List +from typing import Callable, Dict, Tuple, Optional, Sequence, Union, List +from itertools import count +from dataclasses import dataclass import abc import itertools import functools +import sympy -from dataclasses import dataclass from attrs import frozen, field import numpy as np from cirq.transformers import transformer_api from cirq import ops, circuits +from cirq.study import sweepable +from cirq.protocols import unitary_protocol +from cirq.protocols.has_unitary_protocol import has_unitary +from cirq.study.sweeps import Points, Zip +from cirq.transformers.analytical_decompositions import single_qubit_decompositions class Gauge(abc.ABC): @@ -72,6 +79,7 @@ class ConstantGauge(Gauge): default=(), converter=lambda g: (g,) if isinstance(g, ops.Gate) else tuple(g) ) swap_qubits: bool = False + support_sweep: bool = False def sample(self, gate: ops.Gate, prng: np.random.Generator) -> "ConstantGauge": return self @@ -201,6 +209,100 @@ def __call__( new_moments.extend(_build_moments(right)) return circuits.Circuit.from_moments(*new_moments) + def as_sweep( + self, + circuit: circuits.AbstractCircuit, + *, + N: int, + context: Optional[transformer_api.TransformerContext] = None, + prng: Optional[np.random.Generator] = None, + ) -> Tuple[circuits.AbstractCircuit, sweepable.Sweepable]: + """Generates a parameterized circuit with *N* sets of sweepable parameters. + + Args: + circuit: The input circuit to be processed by gauge compiling. + N: The number of parameter sets to generate. + context: A `cirq.TransformerContext` storing common configurable options for + the transformers. + prng: A pseudo-random number generator to select a gauge within a gauge cluster. + """ + + rng = np.random.default_rng() if prng is None else prng + if context is None: + context = transformer_api.TransformerContext(deep=False) + if context.deep: + raise ValueError('GaugeTransformer cannot be used with deep=True') + new_moments: List[List[ops.Operation]] = [] # Store parameterized circuits. + values_by_params: Dict[str, List[float]] = {} # map from symbol name to N values. + symbol_count = count() + # Map from "((pre|post),$qid,$moment_id)" to gate parameters. + # E.g. {(post,q1,2): {"x_exponent": "x1", "z_exponent": "z1", "axis_phase": "a1"}} + symbols_by_loc: Dict[Tuple[str, ops.Qid, int], Dict[str, sympy.Symbol]] = {} + + def single_qubit_next_symbol() -> Dict[str, sympy.Symbol]: + sid = next(symbol_count) + return _parameterize(1, sid) + + # Build parameterized circuit. + for moment_id, moment in enumerate(circuit): + center_moment: List[ops.Operation] = [] + left_moment: List[ops.Operation] = [] + right_moment: List[ops.Operation] = [] + for op in moment: + if isinstance(op, ops.TaggedOperation) and set(op.tags).intersection( + context.tags_to_ignore + ): + center_moment.append(op) + continue + if op.gate is not None and op in self.target: + # Build symbols for the gauge, for a 2-qubit gauge, symbols will be built for + # pre/post q0/q1 and the new 2-qubit gate if the 2-qubit gate is updated in + # the gauge compiling. + center_moment.append(op) + for prefix, q in itertools.product(["pre", "post"], op.qubits): + xza_by_symbols = single_qubit_next_symbol() # xza in phased xz gate. + loc = (prefix, q, moment_id) + symbols_by_loc[loc] = xza_by_symbols + new_op = ops.PhasedXZGate(**xza_by_symbols).on(q) + for symbol in xza_by_symbols.values(): + values_by_params.update({str(symbol): []}) + if prefix == "pre": + left_moment.append(new_op) + else: + right_moment.append(new_op) + else: + center_moment.append(op) + new_moments.extend( + [moment for moment in [left_moment, center_moment, right_moment] if moment] + ) + + # Assign values for parameters via randomly chosen GaugeSelector. + for _ in range(N): + for moment_id, moment in enumerate(circuit): + for op in moment: + if isinstance(op, ops.TaggedOperation) and set(op.tags).intersection( + context.tags_to_ignore + ): + continue + if op.gate is not None and len(op.qubits) == 2 and op in self.target: + gauge = self.gauge_selector(rng).sample(op.gate, rng) + if not gauge.support_sweep: + raise NotImplementedError( + f"as_sweep isn't supported for {gauge.two_qubit_gate} gauge" + ) + # Get the params of pre/post q0/q1 gates. + for pre_or_post, idx in itertools.product(["pre", "post"], [0, 1]): + symbols = symbols_by_loc[(pre_or_post, op.qubits[idx], moment_id)] + gates = getattr(gauge, f"{pre_or_post}_q{idx}") + phxz_params = _gate_sequence_to_phxz_params(gates, symbols) + for key, value in phxz_params.items(): + values_by_params[key].append(value) + sweeps: List[Points] = [ + Points(key=key, points=values) for key, values in values_by_params.items() + ] + + return circuits.Circuit.from_moments(*new_moments), Zip(*sweeps) + def _build_moments(operation_by_qubits: List[List[ops.Operation]]) -> List[List[ops.Operation]]: """Builds moments from a list of operations grouped by qubits. @@ -212,3 +314,50 @@ def _build_moments(operation_by_qubits: List[List[ops.Operation]]) -> List[List[ for moment in itertools.zip_longest(*operation_by_qubits): moments.append([op for op in moment if op is not None]) return moments + + +def _parameterize(num_qubits: int, symbol_id: int) -> Dict[str, sympy.Symbol]: + """Returns symbolized parameters for the gate.""" + + if num_qubits == 1: # Convert single qubit gate to parameterized PhasedXZGate. + phased_xz_params = { + "x_exponent": sympy.Symbol(f"x{symbol_id}"), + "z_exponent": sympy.Symbol(f"z{symbol_id}"), + "axis_phase_exponent": sympy.Symbol(f"a{symbol_id}"), + } + return phased_xz_params + raise NotImplementedError("parameterization for non single qubit gates is not supported yet") + + +def _gate_sequence_to_phxz_params( + gates: Tuple[ops.Gate, ...], xza_by_symbols: Dict[str, sympy.Symbol] +) -> Dict[str, float]: + for gate in gates: + if not has_unitary(gate) or gate.num_qubits() != 1: + raise ValueError( + "Invalid gate sequence to be converted to PhasedXZGate." + f"Found incompatiable gate {gate} in sequence." + ) + phxz = ( + single_qubit_decompositions.single_qubit_matrix_to_phxz( + functools.reduce( + np.matmul, [unitary_protocol.unitary(gate) for gate in reversed(gates)] + ) + ) + or ops.I + ) + if phxz is ops.I: # Identity gate + return { + str(xza_by_symbols["x_exponent"]): 0.0, + str(xza_by_symbols["z_exponent"]): 0.0, + str(xza_by_symbols["axis_phase_exponent"]): 0.0, + } + # Check the gate type, needs to be a PhasedXZ gate. + if not isinstance(phxz, ops.PhasedXZGate): + raise ValueError("Failed to convert the gate sequence to a PhasedXZ gate.") + if phxz is not None: + return { + str(xza_by_symbols["x_exponent"]): phxz.x_exponent, + str(xza_by_symbols["z_exponent"]): phxz.z_exponent, + str(xza_by_symbols["axis_phase_exponent"]): phxz.axis_phase_exponent, + } diff --git a/cirq-core/cirq/transformers/gauge_compiling/gauge_compiling_test.py b/cirq-core/cirq/transformers/gauge_compiling/gauge_compiling_test.py index 1453d17291c..da5fd224488 100644 --- a/cirq-core/cirq/transformers/gauge_compiling/gauge_compiling_test.py +++ b/cirq-core/cirq/transformers/gauge_compiling/gauge_compiling_test.py @@ -12,10 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. +import unittest.mock import pytest import numpy as np import cirq -from cirq.transformers.gauge_compiling import GaugeTransformer, CZGaugeTransformer +from cirq.transformers.gauge_compiling import ( + GaugeTransformer, + CZGaugeTransformer, + ConstantGauge, + GaugeSelector, +) +from cirq.transformers.analytical_decompositions import single_qubit_decompositions def test_deep_transformation_not_supported(): @@ -25,10 +32,19 @@ def test_deep_transformation_not_supported(): cirq.Circuit(), context=cirq.TransformerContext(deep=True) ) + with pytest.raises(ValueError, match="cannot be used with deep=True"): + _ = GaugeTransformer(target=cirq.CZ, gauge_selector=lambda _: None).as_sweep( + cirq.Circuit(), context=cirq.TransformerContext(deep=True), N=1 + ) + def test_ignore_tags(): c = cirq.Circuit(cirq.CZ(*cirq.LineQubit.range(2)).with_tags('foo')) assert c == CZGaugeTransformer(c, context=cirq.TransformerContext(tags_to_ignore={"foo"})) + parameterized_circuit, _ = CZGaugeTransformer.as_sweep( + c, context=cirq.TransformerContext(tags_to_ignore={"foo"}), N=1 + ) + assert c == parameterized_circuit def test_target_can_be_gateset(): @@ -39,3 +55,71 @@ def test_target_can_be_gateset(): ) want = cirq.Circuit(cirq.Y.on_each(qs), cirq.CZ(*qs), cirq.X.on_each(qs)) assert transformer(c, prng=np.random.default_rng(0)) == want + + +def test_as_sweep_multi_pre_or_multi_post(): + transformer = GaugeTransformer( + target=cirq.CZ, + gauge_selector=GaugeSelector( + gauges=[ + ConstantGauge( + two_qubit_gate=cirq.CZ, + support_sweep=True, + pre_q0=[cirq.X, cirq.X], + post_q0=[cirq.Z], + pre_q1=[cirq.Y], + post_q1=[cirq.Y, cirq.Y, cirq.Y], + ) + ] + ), + ) + qs = cirq.LineQubit.range(2) + input_circuit = cirq.Circuit(cirq.CZ(*qs)) + parameterized_circuit, sweeps = transformer.as_sweep(input_circuit, N=1) + + for params in sweeps: + compiled_circuit = cirq.resolve_parameters(parameterized_circuit, params) + cirq.testing.assert_circuits_have_same_unitary_given_final_permutation( + input_circuit, compiled_circuit, qubit_map={q: q for q in input_circuit.all_qubits()} + ) + + +def test_as_sweep_invalid_gauge_sequence(): + transfomer = GaugeTransformer( + target=cirq.CZ, + gauge_selector=GaugeSelector( + gauges=[ + ConstantGauge( + two_qubit_gate=cirq.CZ, + support_sweep=True, + pre_q0=[cirq.measure], + post_q0=[cirq.Z], + pre_q1=[cirq.X], + post_q1=[cirq.Z], + ) + ] + ), + ) + qs = cirq.LineQubit.range(2) + c = cirq.Circuit(cirq.CZ(*qs)) + with pytest.raises(ValueError, match="Invalid gate sequence to be converted to PhasedXZGate."): + transfomer.as_sweep(c, N=1) + + +def test_as_sweep_convert_to_phxz_failed(): + qs = cirq.LineQubit.range(2) + c = cirq.Circuit(cirq.CZ(*qs)) + + def mock_single_qubit_matrix_to_phxz(*args, **kwargs): + # Return an non PhasedXZ gate, so we expect errors from as_sweep(). + return cirq.X + + with unittest.mock.patch.object( + single_qubit_decompositions, + "single_qubit_matrix_to_phxz", + new=mock_single_qubit_matrix_to_phxz, + ): + with pytest.raises( + ValueError, match="Failed to convert the gate sequence to a PhasedXZ gate." + ): + _ = CZGaugeTransformer.as_sweep(c, context=cirq.TransformerContext(), N=1) diff --git a/cirq-core/cirq/transformers/gauge_compiling/gauge_compiling_test_utils.py b/cirq-core/cirq/transformers/gauge_compiling/gauge_compiling_test_utils.py index a8e431c9765..449811800e2 100644 --- a/cirq-core/cirq/transformers/gauge_compiling/gauge_compiling_test_utils.py +++ b/cirq-core/cirq/transformers/gauge_compiling/gauge_compiling_test_utils.py @@ -27,6 +27,7 @@ class GaugeTester: two_qubit_gate: cirq.Gate gauge_transformer: GaugeTransformer must_fail: bool = False + sweep_must_pass: bool = False @pytest.mark.parametrize( ['generation_seed', 'transformation_seed'], @@ -73,6 +74,42 @@ def test_all_gauges(self, mock_select, seed): else: _check_equivalent_with_error_message(c, nc, gauge) + def test_sweep(self): + qubits = cirq.LineQubit.range(3) + + if not self.sweep_must_pass: + with pytest.raises(NotImplementedError): + self.gauge_transformer.as_sweep( + cirq.Circuit(cirq.Moment(self.two_qubit_gate(*qubits[:2]))), N=1 + ) + return + + input_circuit = cirq.Circuit( + cirq.Moment(cirq.H(qubits[0])), + cirq.Moment(self.two_qubit_gate(*qubits[:2])), + cirq.Moment(self.two_qubit_gate(*qubits[1:])), + cirq.Moment([cirq.H(q) for q in qubits]), + cirq.Moment([cirq.measure(q) for q in qubits]), + ) + + n_samples = 5 + parameterized_circuit, sweeps = self.gauge_transformer.as_sweep(input_circuit, N=n_samples) + + # Check the parameterized circuit and N set of parameters. + assert cirq.is_parameterized(parameterized_circuit) + simulator = cirq.Simulator() + results = simulator.run_sweep(parameterized_circuit, sweeps) + assert len(results) == n_samples + + # Check compilied circuits have the same unitary as the orig circuit. + for params in sweeps: + compiled_circuit = cirq.resolve_parameters(parameterized_circuit, params) + cirq.testing.assert_circuits_have_same_unitary_given_final_permutation( + input_circuit[:-1], + compiled_circuit[:-1], + qubit_map={q: q for q in input_circuit.all_qubits()}, + ) + def _check_equivalent_with_error_message(c: cirq.AbstractCircuit, nc: cirq.AbstractCircuit, gauge): try: diff --git a/cirq-core/cirq/transformers/gauge_compiling/gauge_compiling_test_utils_test.py b/cirq-core/cirq/transformers/gauge_compiling/gauge_compiling_test_utils_test.py index 4c54d23e0cb..7f79565b546 100644 --- a/cirq-core/cirq/transformers/gauge_compiling/gauge_compiling_test_utils_test.py +++ b/cirq-core/cirq/transformers/gauge_compiling/gauge_compiling_test_utils_test.py @@ -26,7 +26,15 @@ def _unitary_(self) -> np.ndarray: return self.unitary +class ExampleSweepGate(cirq.testing.TwoQubitGate): + unitary = cirq.unitary(cirq.CZ) + + def _unitary_(self) -> np.ndarray: + return self.unitary + + _EXAMPLE_TARGET = ExampleGate() +_EXAMPLE_SWEEP_TARGET = ExampleSweepGate() _GOOD_TRANSFORMER = GaugeTransformer( target=_EXAMPLE_TARGET, @@ -40,6 +48,22 @@ def _unitary_(self) -> np.ndarray: ), ) +_TRANSFORMER_WITH_SWEEP = GaugeTransformer( + target=_EXAMPLE_SWEEP_TARGET, + gauge_selector=GaugeSelector( + gauges=[ + ConstantGauge( + two_qubit_gate=_EXAMPLE_SWEEP_TARGET, + pre_q0=cirq.Z, + pre_q1=cirq.Z, + post_q0=cirq.Z, + post_q1=cirq.Z, + support_sweep=True, + ) + ] + ), +) + class TestValidTransformer(GaugeTester): two_qubit_gate = _EXAMPLE_TARGET @@ -50,3 +74,9 @@ class TestInvalidTransformer(GaugeTester): two_qubit_gate = _EXAMPLE_TARGET gauge_transformer = _BAD_TRANSFORMER must_fail = True + + +class TestSweep(GaugeTester): + two_qubit_gate = _EXAMPLE_SWEEP_TARGET + gauge_transformer = _TRANSFORMER_WITH_SWEEP + sweep_must_pass = True