diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index f86a46130c9..83d74e1f55c 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -9,9 +9,6 @@ * Add a decomposition for multi-controlled global phases into a one-less-controlled phase shift. [(#6936)](https://github.com/PennyLaneAI/pennylane/pull/6936) -* Add a `qml.capture.pause()` context manager for pausing program capture in an error-safe way. - [(#6911)](https://github.com/PennyLaneAI/pennylane/pull/6911) - * `qml.StatePrep` now accepts sparse state vectors. Users can create `StatePrep` using `scipy.sparse.csr_matrix`. Note that non-zero `pad_with` is forbidden. [(#6863)](https://github.com/PennyLaneAI/pennylane/pull/6863) @@ -32,37 +29,6 @@ is greater than `0.4.28`. [(#6864)](https://github.com/PennyLaneAI/pennylane/pull/6864) -* Python control flow (`if/else`, `for`, `while`) is now supported when program capture is enabled by setting - `autograph=True` at the QNode level. - [(#6837)](https://github.com/PennyLaneAI/pennylane/pull/6837) - - ```python - qml.capture.enable() - - dev = qml.device("default.qubit", wires=[0, 1, 2]) - - @qml.qnode(dev, autograph=True) - def circuit(num_loops: int): - for i in range(num_loops): - if i % 2 == 0: - qml.H(i) - else: - qml.RX(1,i) - return qml.state() - ``` - - ```pycon - >>> print(qml.draw(circuit)(num_loops=3)) - 0: ──H────────┤ State - 1: ──RX(1.00)─┤ State - 2: ──H────────┤ State - >>> circuit(3) - Array([0.43879125+0.j , 0.43879125+0.j , - 0. -0.23971277j, 0. -0.23971277j, - 0.43879125+0.j , 0.43879125+0.j , - 0. -0.23971277j, 0. -0.23971277j], dtype=complex64) - ``` - * Added the `qml.workflow.construct_execution_config(qnode)(*args,**kwargs)` helper function. Users can now construct the execution configuration from a particular `QNode` instance. [(#6901)](https://github.com/PennyLaneAI/pennylane/pull/6901) @@ -91,13 +57,6 @@ convert_to_numpy=True) ``` -* The higher order primitives in program capture can now accept inputs with abstract shapes. - [(#6786)](https://github.com/PennyLaneAI/pennylane/pull/6786) - -* The `PlxprInterpreter` classes can now handle creating dynamic arrays via `jnp.ones`, `jnp.zeros`, - `jnp.arange`, and `jnp.full`. - [#6865)](https://github.com/PennyLaneAI/pennylane/pull/6865) - * `QNode` objects now have an `update` method that allows for re-configuring settings like `diff_method`, `mcm_method`, and more. This allows for easier on-the-fly adjustments to workflows. Any arguments not specified will retain their original value. [(#6803)](https://github.com/PennyLaneAI/pennylane/pull/6803) @@ -160,9 +119,56 @@ * `null.qubit` can now execute jaxpr. [(#6924)](https://github.com/PennyLaneAI/pennylane/pull/6924) -* Autograph can now be used with custom operations defined outside of the pennylane namespace. +

Capturing and representing hybrid programs

+ +* Implemented a `compute_plxpr_decomposition` method in the `qml.operation.Operator` class to apply dynamic decompositions + with program capture enabled. + [(#6859)](https://github.com/PennyLaneAI/pennylane/pull/6859) + + * Autograph can now be used with custom operations defined outside of the pennylane namespace. [(#6931)](https://github.com/PennyLaneAI/pennylane/pull/6931) + * Add a `qml.capture.pause()` context manager for pausing program capture in an error-safe way. + [(#6911)](https://github.com/PennyLaneAI/pennylane/pull/6911) + +* Python control flow (`if/else`, `for`, `while`) is now supported when program capture is enabled by setting + `autograph=True` at the QNode level. + [(#6837)](https://github.com/PennyLaneAI/pennylane/pull/6837) + + ```python + qml.capture.enable() + + dev = qml.device("default.qubit", wires=[0, 1, 2]) + + @qml.qnode(dev, autograph=True) + def circuit(num_loops: int): + for i in range(num_loops): + if i % 2 == 0: + qml.H(i) + else: + qml.RX(1,i) + return qml.state() + ``` + + ```pycon + >>> print(qml.draw(circuit)(num_loops=3)) + 0: ──H────────┤ State + 1: ──RX(1.00)─┤ State + 2: ──H────────┤ State + >>> circuit(3) + Array([0.43879125+0.j , 0.43879125+0.j , + 0. -0.23971277j, 0. -0.23971277j, + 0.43879125+0.j , 0.43879125+0.j , + 0. -0.23971277j, 0. -0.23971277j], dtype=complex64) + ``` + +* The higher order primitives in program capture can now accept inputs with abstract shapes. + [(#6786)](https://github.com/PennyLaneAI/pennylane/pull/6786) + +* The `PlxprInterpreter` classes can now handle creating dynamic arrays via `jnp.ones`, `jnp.zeros`, + `jnp.arange`, and `jnp.full`. + [#6865)](https://github.com/PennyLaneAI/pennylane/pull/6865) +

Breaking changes 💔

* `MultiControlledX` no longer accepts strings as control values. diff --git a/pennylane/operation.py b/pennylane/operation.py index 17546747dc4..f22f658ee6e 100644 --- a/pennylane/operation.py +++ b/pennylane/operation.py @@ -1349,6 +1349,33 @@ def compute_decomposition( """ raise DecompositionUndefinedError + @classproperty + def has_plxpr_decomposition(cls) -> bool: + """Whether or not the Operator returns a defined plxpr decomposition.""" + return cls.compute_plxpr_decomposition != Operator.compute_plxpr_decomposition + + @staticmethod + def compute_plxpr_decomposition(*args, **hyperparameters) -> None: + r"""Experimental method to compute the dynamic decomposition of the operator with program capture enabled. + + When the program capture feature is enabled with ``qml.capture.enable()``, the decomposition of the operator + is computed with this method if it is defined. Otherwise, the :meth:`~.Operator.compute_decomposition` method is used. + + If this method is defined, the control flow operations within the method are recorded in the JAX representation + of the operator's decomposition. + + This method is experimental and subject to change. + + .. seealso:: :meth:`~.Operator.compute_decomposition`. + + Args: + *args (list): positional arguments passed to the operator, including trainable parameters and wires + **hyperparameters (dict): non-trainable hyperparameters of the operator, as stored in the ``hyperparameters`` attribute + + """ + + raise DecompositionUndefinedError + # pylint: disable=no-self-argument, comparison-with-callable @classproperty def has_diagonalizing_gates(cls) -> bool: diff --git a/pennylane/transforms/decompose.py b/pennylane/transforms/decompose.py index bf5fae0a32c..e73384bdc0b 100644 --- a/pennylane/transforms/decompose.py +++ b/pennylane/transforms/decompose.py @@ -157,7 +157,18 @@ def interpret_operation_eqn(self, eqn): with qml.QueuingManager.stop_recording(): op = eqn.primitive.impl(*invals, **eqn.params) if eqn.outvars[0].__class__.__name__ == "DropVar": - return self.decompose_operation(op) + + if op.has_plxpr_decomposition: + + args = (*op.parameters, *op.wires) + qml.capture.run_autograph(op.compute_plxpr_decomposition)( + *args, **op.hyperparameters + ) + + else: + + return self.decompose_operation(op) + return op # pylint: disable=unused-variable,missing-function-docstring diff --git a/tests/capture/transforms/test_capture_dynamic_decompositions.py b/tests/capture/transforms/test_capture_dynamic_decompositions.py new file mode 100644 index 00000000000..84a7fc93ad2 --- /dev/null +++ b/tests/capture/transforms/test_capture_dynamic_decompositions.py @@ -0,0 +1,666 @@ +# Copyright 2024 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Unit tests for the ``DecomposeInterpreter`` class with dynamic decompositions.""" +# pylint:disable=protected-access,unused-argument, wrong-import-position, no-value-for-parameter, too-few-public-methods, wrong-import-order +import pytest + +import pennylane as qml + +jax = pytest.importorskip("jax") + +from functools import partial + +from pennylane.capture import expand_plxpr_transforms +from pennylane.capture.primitives import cond_prim, for_loop_prim, qnode_prim, while_loop_prim +from pennylane.operation import Operation +from pennylane.transforms.decompose import DecomposeInterpreter + +pytestmark = [pytest.mark.jax, pytest.mark.usefixtures("enable_disable_plxpr")] + + +class SimpleCustomOp(Operation): + """Simple custom operation that contains a single gate in its decomposition""" + + num_wires = 1 + num_params = 0 + + def _init__(self, wires, id=None): + super().__init__(wires=wires, id=id) + + @staticmethod + def compute_plxpr_decomposition(wires): + + return qml.Hadamard(wires=wires) + + +const = jax.numpy.array(0.1) + + +class CustomOpConstHyperparams(Operation): + """Custom operation that contains constants and hyperparameters in its decomposition""" + + num_wires = 4 + num_params = 1 + + def __init__(self, phi, wires, id=None): + + self._hyperparameters = { + "key": const, + "CNOT": qml.CNOT, + "RX": qml.RX, + "phi": phi, + } + + super().__init__(phi, wires=wires, id=id) + + @staticmethod + def compute_plxpr_decomposition(*args, **hyperparameters): + + phi = args[0] + wires = args[1:] + + hyperparameters["CNOT"](wires=[wires[0], wires[1]]) + hyperparameters["RX"](phi, wires=wires[2]) + hyperparameters["RX"](hyperparameters["key"], wires=wires[0]) + hyperparameters["RX"](const, wires=wires[3]) + + qml.RY(hyperparameters["key"], wires[0]) + qml.RZ(hyperparameters["phi"], wires[2]) + + +class CustomOpMultiWire(Operation): + """Custom operation that acts on multiple wires""" + + num_wires = 4 + num_params = 1 + + def __init__(self, phi, wires, id=None): + + self.hyperparameters["key_1"] = 0.1 + self.hyperparameters["key_2"] = 0.2 + + super().__init__(phi, wires=wires, id=id) + + @staticmethod + def compute_plxpr_decomposition(*args, **hyperparameters): + + phi = args[0] + wires = args[1:] + + qml.CNOT([wires[0], wires[1]]) + qml.DoubleExcitation(phi, wires) + qml.CNOT([wires[0], wires[1]]) + qml.RX(hyperparameters["key_1"], wires[0]) + qml.RY(phi, wires[1]) + qml.RZ(phi, wires[2]) + qml.RX(hyperparameters["key_2"], wires[3]) + + +class CustomOpCond(Operation): + """Custom operation that contains a conditional in its decomposition""" + + num_wires = 1 + num_params = 1 + + def __init__(self, phi, wires, id=None): + super().__init__(phi, wires=wires, id=id) + + @staticmethod + def compute_plxpr_decomposition(phi, wires): + + def true_fn(phi, wires): + qml.RX(phi, wires=wires) + + def false_fn(phi, wires): + qml.RY(phi, wires=wires) + + qml.cond(phi > 0.5, true_fn, false_fn)(phi, wires) + + +class CustomOpForLoop(Operation): + """Custom operation that contains a for loop in its decomposition""" + + num_wires = 1 + num_params = 1 + + def __init__(self, phi, wires, id=None): + super().__init__(phi, wires=wires, id=id) + + @staticmethod + def compute_plxpr_decomposition(phi, wires): + + @qml.for_loop(0, 3, 1) + def loop_rx(i, phi): + qml.RX(phi, wires) + return jax.numpy.sin(phi) + + # pylint: disable=unused-variable + loop_rx(phi) + + +class CustomOpWhileLoop(Operation): + """Custom operation that contains a while loop in its decomposition""" + + num_wires = 1 + num_params = 1 + + def __init__(self, phi, wires, id=None): + super().__init__(phi, wires=wires, id=id) + + @staticmethod + def compute_plxpr_decomposition(phi, wires): + + def while_f(i): + return i < 3 + + @qml.while_loop(while_f) + def loop_fn(i): + qml.RX(phi, wires) + return i + 1 + + _ = loop_fn(0) + + +class CustomOpNestedCond(Operation): + """Custom operation that contains a nested conditional in its decomposition""" + + num_wires = 1 + num_params = 1 + + def __init__(self, phi, wires, id=None): + super().__init__(phi, wires=wires, id=id) + + @staticmethod + def compute_plxpr_decomposition(phi, wires): + + def true_fn(phi, wires): + + @qml.for_loop(0, 3, 1) + def loop_rx(i, phi): + qml.RX(phi, wires) + return jax.numpy.sin(phi) + + # pylint: disable=unused-variable + loop_rx(phi) + + def false_fn(phi, wires): + + def while_f(i): + return i < 3 + + @qml.while_loop(while_f) + def loop_fn(i): + qml.RX(phi, wires) + return i + 1 + + _ = loop_fn(0) + + qml.cond(phi > 0.5, true_fn, false_fn)(phi, wires) + + qml.RX(phi, wires=wires) + + +class CustomOpAutograph(Operation): + """Custom operation that contains a nested conditional in its decomposition""" + + num_wires = 1 + num_params = 1 + + def __init__(self, phi, wires, id=None): + super().__init__(phi, wires=wires, id=id) + + @staticmethod + def compute_plxpr_decomposition(phi, wires): + + if phi > 0.5: + qml.RX(phi, wires=wires) + + else: + qml.RY(phi, wires=wires) + + +class TestDynamicDecomposeInterpreter: + """Tests for the DynamicDecomposeInterpreter class""" + + def test_error_no_plxpr_decomposition(self): + """Test that an error is raised if an operator does not have a plxpr decomposition.""" + + with pytest.raises(qml.operation.DecompositionUndefinedError): + qml.RX(0.1, 0).compute_plxpr_decomposition() + + def test_no_plxpr_decomposition(self): + """Test that a function with a custom operation that does not have a plxpr decomposition is not decomposed.""" + + @DecomposeInterpreter() + def f(x): + qml.RY(x, wires=0) + + jaxpr = jax.make_jaxpr(f)(0.5) + assert len(jaxpr.eqns) == 1 + assert jaxpr.eqns[0].primitive == qml.RY._primitive + + def test_function_simple(self): + """Test that a function with a custom operation is correctly decomposed.""" + + @DecomposeInterpreter() + def f(): + qml.RY(0.1, wires=0) + SimpleCustomOp(wires=0) + return qml.expval(qml.Z(0)) + + jaxpr = jax.make_jaxpr(f)() + assert len(jaxpr.eqns) == 4 + assert jaxpr.eqns[0].primitive == qml.RY._primitive + assert jaxpr.eqns[1].primitive == qml.Hadamard._primitive + assert jaxpr.eqns[2].primitive == qml.PauliZ._primitive + assert jaxpr.eqns[3].primitive == qml.measurements.ExpectationMP._obs_primitive + + ############################ + ### QNode tests + ############################ + + @pytest.mark.parametrize("autograph", [True, False]) + def test_qnode_simple(self, autograph): + """Test that a QNode with a custom operation is correctly decomposed.""" + + @DecomposeInterpreter() + @qml.qnode(device=qml.device("default.qubit", wires=2), autograph=autograph) + def circuit(): + qml.RY(0.1, wires=0) + SimpleCustomOp(wires=0) + return qml.expval(qml.Z(0)) + + jaxpr = qml.capture.make_plxpr(circuit)() + + assert jaxpr.eqns[0].primitive == qnode_prim + qfunc_jaxpr = jaxpr.eqns[0].params["qfunc_jaxpr"] + assert qfunc_jaxpr.eqns[0].primitive == qml.RY._primitive + assert qfunc_jaxpr.eqns[1].primitive == qml.Hadamard._primitive + assert qfunc_jaxpr.eqns[2].primitive == qml.PauliZ._primitive + assert qfunc_jaxpr.eqns[3].primitive == qml.measurements.ExpectationMP._obs_primitive + + result = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts) + + @qml.qnode(device=qml.device("default.qubit", wires=2)) + def circuit_comparison(): + qml.RY(0.1, wires=0) + qml.Hadamard(wires=0) + return qml.expval(qml.Z(0)) + + assert qml.math.allclose(*result, circuit_comparison()) + + @pytest.mark.parametrize("autograph", [True, False]) + @pytest.mark.parametrize("wires", [[0, 1, 2, 3], [2, 3, 1, 0]]) + def test_multi_wire(self, wires, autograph): + """Test that a QNode with a multi-wire custom operation is correctly decomposed.""" + + @DecomposeInterpreter() + @qml.qnode(device=qml.device("default.qubit", wires=4), autograph=autograph) + def circuit(x, wires): + CustomOpMultiWire(x, wires=wires) + return qml.expval(qml.Z(0)), qml.probs(wires=1), qml.var(qml.Z(2)), qml.state() + + jaxpr = jax.make_jaxpr(circuit)(0.5, wires=wires) + + assert jaxpr.eqns[0].primitive == qnode_prim + qfunc_jaxpr = jaxpr.eqns[0].params["qfunc_jaxpr"] + assert qfunc_jaxpr.eqns[0].primitive == qml.CNOT._primitive + assert qfunc_jaxpr.eqns[1].primitive == qml.DoubleExcitation._primitive + assert qfunc_jaxpr.eqns[2].primitive == qml.CNOT._primitive + assert qfunc_jaxpr.eqns[3].primitive == qml.RX._primitive + assert qfunc_jaxpr.eqns[4].primitive == qml.RY._primitive + assert qfunc_jaxpr.eqns[5].primitive == qml.RZ._primitive + assert qfunc_jaxpr.eqns[6].primitive == qml.RX._primitive + + result = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, 0.5, *wires) + + @qml.qnode(device=qml.device("default.qubit", wires=4)) + def circuit_comparison(x, wires): + qml.CNOT([wires[0], wires[1]]) + qml.DoubleExcitation(x, wires) + qml.CNOT([wires[0], wires[1]]) + qml.RX(0.1, wires=wires[0]) + qml.RY(x, wires=wires[1]) + qml.RZ(x, wires=wires[2]) + qml.RX(0.2, wires=wires[3]) + return qml.expval(qml.Z(0)), qml.probs(wires=1), qml.var(qml.Z(2)), qml.state() + + comparison_result = circuit_comparison(0.5, wires) + for res, comp in zip(result, comparison_result): + assert qml.math.allclose(res, comp) + + @pytest.mark.parametrize("autograph", [True, False]) + @pytest.mark.parametrize("wires", [[0, 1, 2, 3], [2, 3, 1, 0]]) + @pytest.mark.parametrize("x", [0.2, 0.8]) + def test_qnode_const_hyperparams(self, wires, x, autograph): + """Test that a QNode with a constant in the custom operation is correctly decomposed.""" + + @DecomposeInterpreter() + @qml.qnode(device=qml.device("default.qubit", wires=4), autograph=autograph) + def circuit(x, wires): + CustomOpConstHyperparams(x, wires=wires) + return qml.expval(qml.Z(0)), qml.probs(wires=1), qml.var(qml.Z(2)), qml.state() + + jaxpr = jax.make_jaxpr(circuit)(x, wires=wires) + + assert jaxpr.eqns[0].primitive == qnode_prim + qfunc_jaxpr = jaxpr.eqns[0].params["qfunc_jaxpr"] + assert qfunc_jaxpr.eqns[0].primitive == qml.CNOT._primitive + assert qfunc_jaxpr.eqns[1].primitive == qml.RX._primitive + assert qfunc_jaxpr.eqns[2].primitive == qml.RX._primitive + assert qfunc_jaxpr.eqns[3].primitive == qml.RX._primitive + assert qfunc_jaxpr.eqns[4].primitive == qml.RY._primitive + assert qfunc_jaxpr.eqns[5].primitive == qml.RZ._primitive + + result = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, x, *wires) + + @qml.qnode(device=qml.device("default.qubit", wires=4)) + def circuit_comparison(x, wires): + qml.CNOT([wires[0], wires[1]]) + qml.RX(x, wires=wires[2]) + qml.RX(0.1, wires=wires[0]) + qml.RX(0.1, wires=wires[3]) + qml.RY(0.1, wires=wires[0]) + qml.RZ(x, wires=wires[2]) + return qml.expval(qml.Z(0)), qml.probs(wires=1), qml.var(qml.Z(2)), qml.state() + + comparison_result = circuit_comparison(x, wires) + for res, comp in zip(result, comparison_result): + assert qml.math.allclose(res, comp) + + @pytest.mark.parametrize("autograph", [True, False]) + @pytest.mark.parametrize("wire", [0, 1]) + @pytest.mark.parametrize("x", [0.2, 0.8]) + def test_qnode_cond(self, x, wire, autograph): + """Test that a QNode with a conditional custom operation is correctly decomposed.""" + + @DecomposeInterpreter() + @qml.qnode(device=qml.device("default.qubit", wires=2), autograph=autograph) + def circuit(x, wire): + CustomOpCond(x, wires=wire) + return qml.expval(qml.Z(wires=wire)) + + jaxpr = jax.make_jaxpr(circuit)(x, wire=wire) + + assert jaxpr.eqns[0].primitive == qnode_prim + qfunc_jaxpr = jaxpr.eqns[0].params["qfunc_jaxpr"] + assert qfunc_jaxpr.eqns[1].primitive == cond_prim + assert ( + qfunc_jaxpr.eqns[1].params["jaxpr_branches"][0].eqns[0].primitive == qml.RX._primitive + ) + assert ( + qfunc_jaxpr.eqns[1].params["jaxpr_branches"][1].eqns[0].primitive == qml.RY._primitive + ) + assert qfunc_jaxpr.eqns[2].primitive == qml.PauliZ._primitive + assert qfunc_jaxpr.eqns[3].primitive == qml.measurements.ExpectationMP._obs_primitive + + result = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, x, wire) + + @qml.qnode(device=qml.device("default.qubit", wires=2)) + def circuit_comparison(x, wire): + def true_fn(x, wire): + qml.RX(x, wires=wire) + + def false_fn(x, wire): + qml.RY(x, wires=wire) + + qml.cond(x > 0.5, true_fn, false_fn)(x, wire) + + return qml.expval(qml.Z(wires=wire)) + + assert qml.math.allclose(*result, circuit_comparison(x, wire)) + + @pytest.mark.parametrize("autograph", [True, False]) + @pytest.mark.parametrize("wire", [0, 1]) + def test_qnode_for_loop(self, wire, autograph): + """Test that a QNode with a for loop custom operation is correctly decomposed.""" + + @DecomposeInterpreter() + @qml.qnode(device=qml.device("default.qubit", wires=2), autograph=autograph) + def circuit(x, wire): + CustomOpForLoop(x, wires=wire) + return qml.expval(qml.Z(wires=wire)) + + jaxpr = jax.make_jaxpr(circuit)(0.5, wire) + + assert jaxpr.eqns[0].primitive == qnode_prim + qfunc_jaxpr = jaxpr.eqns[0].params["qfunc_jaxpr"] + assert qfunc_jaxpr.eqns[0].primitive == for_loop_prim + assert qfunc_jaxpr.eqns[0].params["jaxpr_body_fn"].eqns[0].primitive == qml.RX._primitive + assert qfunc_jaxpr.eqns[1].primitive == qml.PauliZ._primitive + assert qfunc_jaxpr.eqns[2].primitive == qml.measurements.ExpectationMP._obs_primitive + + result = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, 0.5, wire) + + @qml.qnode(device=qml.device("default.qubit", wires=2)) + def circuit_comparison(x, wire): + @qml.for_loop(0, 3, 1) + def loop_rx(i, phi): + qml.RX(phi, wires=wire) + return jax.numpy.sin(phi) + + # pylint: disable=unused-variable + loop_rx(x) + + return qml.expval(qml.Z(wires=wire)) + + assert qml.math.allclose(*result, circuit_comparison(0.5, wire)) + + @pytest.mark.parametrize("autograph", [True, False]) + @pytest.mark.parametrize("wire", [0, 1]) + def test_qnode_while_loop(self, wire, autograph): + """Test that a QNode with a while loop custom operation is correctly decomposed.""" + + @DecomposeInterpreter() + @qml.qnode(device=qml.device("default.qubit", wires=2), autograph=autograph) + def circuit(x, wire): + CustomOpWhileLoop(x, wires=wire) + return qml.expval(qml.Z(wires=wire)) + + jaxpr = jax.make_jaxpr(circuit)(0.5, wire) + + assert jaxpr.eqns[0].primitive == qnode_prim + qfunc_jaxpr = jaxpr.eqns[0].params["qfunc_jaxpr"] + assert qfunc_jaxpr.eqns[0].primitive == while_loop_prim + assert qfunc_jaxpr.eqns[0].params["jaxpr_body_fn"].eqns[0].primitive == qml.RX._primitive + assert qfunc_jaxpr.eqns[1].primitive == qml.PauliZ._primitive + assert qfunc_jaxpr.eqns[2].primitive == qml.measurements.ExpectationMP._obs_primitive + + result = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, 0.5, wire) + + @qml.qnode(device=qml.device("default.qubit", wires=2), autograph=False) + def circuit_comparison(x, wire): + @qml.while_loop(lambda i: i < 3) + def loop_fn(i): + qml.RX(x, wires=wire) + return i + 1 + + _ = loop_fn(0) + + return qml.expval(qml.Z(wires=wire)) + + assert qml.math.allclose(*result, circuit_comparison(0.5, wire)) + + @pytest.mark.parametrize("autograph", [True, False]) + @pytest.mark.parametrize("wire", [0, 1]) + @pytest.mark.parametrize("x", [0.2, 0.8]) + def test_qnode_nested_cond(self, x, wire, autograph): + """Test that a QNode with a nested conditional custom operation is correctly decomposed.""" + + @DecomposeInterpreter() + @qml.qnode(device=qml.device("default.qubit", wires=2), autograph=autograph) + def circuit(x, wire): + CustomOpNestedCond(x, wires=wire) + return qml.expval(qml.Z(wires=wire)) + + jaxpr = jax.make_jaxpr(circuit)(x, wire) + + assert jaxpr.eqns[0].primitive == qnode_prim + qfunc_jaxpr = jaxpr.eqns[0].params["qfunc_jaxpr"] + assert qfunc_jaxpr.eqns[1].primitive == cond_prim + assert qfunc_jaxpr.eqns[1].params["jaxpr_branches"][0].eqns[0].primitive == for_loop_prim + assert qfunc_jaxpr.eqns[1].params["jaxpr_branches"][1].eqns[0].primitive == while_loop_prim + assert qfunc_jaxpr.eqns[2].primitive == qml.RX._primitive + assert qfunc_jaxpr.eqns[3].primitive == qml.PauliZ._primitive + assert qfunc_jaxpr.eqns[4].primitive == qml.measurements.ExpectationMP._obs_primitive + + result = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, x, wire) + + @qml.qnode(device=qml.device("default.qubit", wires=2), autograph=False) + def circuit_comparison(x, wire): + def true_fn(x, wire): + + @qml.for_loop(0, 3, 1) + def loop_rx(i, phi): + qml.RX(phi, wires=wire) + return jax.numpy.sin(phi) + + # pylint: disable=unused-variable + loop_rx(x) + + def false_fn(x, wire): + @qml.while_loop(lambda i: i < 3) + def loop_fn(i): + qml.RX(x, wires=wire) + return i + 1 + + _ = loop_fn(0) + + qml.cond(x > 0.5, true_fn, false_fn)(x, wire) + qml.RX(x, wires=wire) + return qml.expval(qml.Z(wires=wire)) + + assert qml.math.allclose(*result, circuit_comparison(x, wire)) + + @pytest.mark.parametrize("wire", [0, 1]) + @pytest.mark.parametrize("x", [0.2, 0.8]) + def test_qnode_autograph(self, x, wire): + """Test that a QNode with a nested conditional custom operation is correctly decomposed.""" + + @DecomposeInterpreter() + @qml.qnode(device=qml.device("default.qubit", wires=2), autograph=True) + def circuit(x, wire): + CustomOpAutograph(x, wires=wire) + return qml.expval(qml.Z(wires=wire)) + + jaxpr = jax.make_jaxpr(circuit)(x, wire) + + assert jaxpr.eqns[0].primitive == qnode_prim + qfunc_jaxpr = jaxpr.eqns[0].params["qfunc_jaxpr"] + assert qfunc_jaxpr.eqns[1].primitive == cond_prim + assert ( + qfunc_jaxpr.eqns[1].params["jaxpr_branches"][0].eqns[0].primitive == qml.RX._primitive + ) + assert ( + qfunc_jaxpr.eqns[1].params["jaxpr_branches"][1].eqns[0].primitive == qml.RY._primitive + ) + assert qfunc_jaxpr.eqns[2].primitive == qml.PauliZ._primitive + assert qfunc_jaxpr.eqns[3].primitive == qml.measurements.ExpectationMP._obs_primitive + + result = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, x, wire) + + @qml.qnode(device=qml.device("default.qubit", wires=2), autograph=True) + def circuit_comparison(x, wire): + + if x > 0.5: + qml.RX(x, wires=wire) + else: + qml.RY(x, wires=wire) + + return qml.expval(qml.Z(wires=wire)) + + # Autograph requires to capture the function first + jaxpr_comparison = jax.make_jaxpr(circuit_comparison)(x, wire) + result_comparison = jax.core.eval_jaxpr( + jaxpr_comparison.jaxpr, jaxpr_comparison.consts, x, wire + ) + + assert qml.math.allclose(*result, *result_comparison) + + +class TestExpandPlxprTransformsDynamicDecompositions: + """Unit tests for ``expand_plxpr_transforms`` with dynamic decompositions.""" + + def test_expand_plxpr_transforms_simple(self): + + @partial(qml.transforms.decompose) + def circuit(): + SimpleCustomOp(wires=0) + return qml.probs(wires=[0, 1]) + + jaxpr = jax.make_jaxpr(circuit)() + + assert jaxpr.eqns[0].primitive == qml.transforms.decompose._primitive + + transformed_f = expand_plxpr_transforms(circuit) + transformed_jaxpr = jax.make_jaxpr(transformed_f)() + + assert transformed_jaxpr.eqns[0].primitive == qml.Hadamard._primitive + assert ( + transformed_jaxpr.eqns[1].primitive == qml.measurements.ProbabilityMP._wires_primitive + ) + + def test_expand_plxpr_transforms_cond(self): + @partial(qml.transforms.decompose) + def circuit(): + CustomOpCond(0.5, wires=0) + return qml.probs(wires=[0, 1]) + + jaxpr = jax.make_jaxpr(circuit)() + + assert jaxpr.eqns[0].primitive == qml.transforms.decompose._primitive + + transformed_f = expand_plxpr_transforms(circuit) + transformed_jaxpr = jax.make_jaxpr(transformed_f)() + + assert transformed_jaxpr.eqns[0].primitive == cond_prim + assert ( + transformed_jaxpr.eqns[1].primitive == qml.measurements.ProbabilityMP._wires_primitive + ) + + def test_expand_plxpr_transforms_for_loop(self): + @partial(qml.transforms.decompose) + def circuit(): + CustomOpForLoop(0.5, wires=0) + return qml.probs(wires=[0, 1]) + + jaxpr = jax.make_jaxpr(circuit)() + + assert jaxpr.eqns[0].primitive == qml.transforms.decompose._primitive + + transformed_f = expand_plxpr_transforms(circuit) + transformed_jaxpr = jax.make_jaxpr(transformed_f)() + + assert transformed_jaxpr.eqns[0].primitive == for_loop_prim + assert ( + transformed_jaxpr.eqns[1].primitive == qml.measurements.ProbabilityMP._wires_primitive + ) + + def test_expand_plxpr_transforms_while_loop(self): + @partial(qml.transforms.decompose) + def circuit(): + CustomOpWhileLoop(0.5, wires=0) + return qml.probs(wires=[0, 1]) + + jaxpr = jax.make_jaxpr(circuit)() + + assert jaxpr.eqns[0].primitive == qml.transforms.decompose._primitive + + transformed_f = expand_plxpr_transforms(circuit) + transformed_jaxpr = jax.make_jaxpr(transformed_f)() + + assert transformed_jaxpr.eqns[0].primitive == while_loop_prim + assert ( + transformed_jaxpr.eqns[1].primitive == qml.measurements.ProbabilityMP._wires_primitive + )