From 696acf3d99ab9865f8a87201ed437b67dfafe1ec Mon Sep 17 00:00:00 2001 From: PietropaoloFrisoni Date: Mon, 20 Jan 2025 10:56:48 -0500 Subject: [PATCH 01/33] E.C. From 9cdfae96c1a32768d94deb9781efc0b69fb05b2c Mon Sep 17 00:00:00 2001 From: Pietropaolo Frisoni Date: Mon, 20 Jan 2025 20:43:51 -0500 Subject: [PATCH 02/33] Creating an empty `DynamicDecomposeInterpreter` c;ass --- pennylane/transforms/decompose.py | 49 +++++++++++++++++++++++++++++++ 1 file changed, 49 insertions(+) diff --git a/pennylane/transforms/decompose.py b/pennylane/transforms/decompose.py index 9b28f5c891f..c53d86a6f86 100644 --- a/pennylane/transforms/decompose.py +++ b/pennylane/transforms/decompose.py @@ -187,6 +187,55 @@ def wrapper(*inner_args): DecomposeInterpreter, decompose_plxpr_to_plxpr = _get_plxpr_decompose() +@lru_cache +def _get_plxpr_dynamic_decompose(): # pylint: disable=missing-docstring + try: + # pylint: disable=import-outside-toplevel + from jax import make_jaxpr + except ImportError: # pragma: no cover + return None, None + + # pylint: disable=redefined-outer-name + + class DynamicDecomposeInterpreter(qml.capture.PlxprInterpreter): + """ """ + + def __init__(self, gate_set=None, max_expansion=None): + self.max_expansion = max_expansion + + if gate_set is None: + gate_set = set(qml.ops.__all__) + + if isinstance(gate_set, (str, type)): + gate_set = set([gate_set]) + + if isinstance(gate_set, Iterable): + gate_types = tuple(gate for gate in gate_set if isinstance(gate, type)) + gate_names = set(gate for gate in gate_set if isinstance(gate, str)) + self.gate_set = lambda op: (op.name in gate_names) or isinstance(op, gate_types) + else: + self.gate_set = gate_set + + super().__init__() + + def stopping_condition(self, op: qml.operation.Operator) -> bool: + """ """ + pass + + def decompose_operation(self, op: qml.operation.Operator): + """ """ + pass + + def interpret_operation_eqn(self, eqn): + """ """ + pass + + return DynamicDecomposeInterpreter + + +DynamicDecomposeInterpreter = _get_plxpr_dynamic_decompose() + + @partial(transform, plxpr_transform=decompose_plxpr_to_plxpr) def decompose(tape, gate_set=None, max_expansion=None): """Decomposes a quantum circuit into a user-specified gate set. From 1e138fa059862c2ac4cf0b42c7e2ba956ad4832f Mon Sep 17 00:00:00 2001 From: PietropaoloFrisoni Date: Tue, 21 Jan 2025 16:09:23 -0500 Subject: [PATCH 03/33] Sbattendo la testa contro il muro tante volte --- pennylane/transforms/decompose.py | 66 +++++++++++++++++++++++++++---- 1 file changed, 59 insertions(+), 7 deletions(-) diff --git a/pennylane/transforms/decompose.py b/pennylane/transforms/decompose.py index c53d86a6f86..3bb3a69ca2c 100644 --- a/pennylane/transforms/decompose.py +++ b/pennylane/transforms/decompose.py @@ -218,17 +218,69 @@ def __init__(self, gate_set=None, max_expansion=None): super().__init__() - def stopping_condition(self, op: qml.operation.Operator) -> bool: - """ """ - pass - - def decompose_operation(self, op: qml.operation.Operator): - """ """ + def eval_dynamic_decomposition(self, op): pass def interpret_operation_eqn(self, eqn): """ """ - pass + + from jax import make_jaxpr + + # eq: e' una roba del tipo: + # _:AbstractOperator() = Hadamard[n_wires=1] 0 + + invals = (self.read(invar) for invar in eqn.invars) + with qml.QueuingManager.stop_recording(): + op = eqn.primitive.impl(*invals, **eqn.params) + + # Ora abbiamo recuperato l'operatore concreto + print(f"op from DynamicDecomposeInterpreter: {op}") + + if hasattr(op, "_compute_plxpr_decomposition"): + + # Non sono sicuro di cosa stia facendo (questo argomento forse non ha senso) + args = op.wires + + jaxpr_decomp = make_jaxpr(op._compute_plxpr_decomposition)(*op.wires) + + # Adesso abbiamo catturato la decomposizione dinamica dell'operatore. + # Dobbiamo ora interpretare il jaxpr della decomposizione dinamica, ma senza chiamare + # di nuovo la funzione 'read' nella classe base per evitare di tornare a questo punto + + for arg, invar in zip(args, jaxpr_decomp.jaxpr.invars, strict=True): + self._env[invar] = arg + for const, constvar in zip( + jaxpr_decomp.consts, jaxpr_decomp.jaxpr.constvars, strict=True + ): + self._env[constvar] = const + + for inner_eqn in jaxpr_decomp.eqns: + + custom_handler = self._primitive_registrations.get(inner_eqn.primitive, None) + + if custom_handler: + invals = [self.read(invar) for invar in inner_eqn.invars] + outvals = custom_handler(self, *invals, **inner_eqn.params) + + elif isinstance(inner_eqn.outvars[0].aval, AbstractOperator): + outvals = super().interpret_operation_eqn(inner_eqn) + elif isinstance(inner_eqn.outvars[0].aval, AbstractMeasurement): + outvals = super().interpret_measurement_eqn(inner_eqn) + else: + invals = [self.read(invar) for invar in inner_eqn.invars] + outvals = inner_eqn.primitive.bind(*invals, **inner_eqn.params) + + if not inner_eqn.primitive.multiple_results: + outvals = [outvals] + + for inner_outvar, inner_outval in zip(inner_eqn.outvars, outvals, strict=True): + self._env[inner_outvar] = inner_outval + + else: + + # Se l'operatore non ha un metodo per calcolare la decomposizione dinamica, + # chiamiamo semplicemente la funzione di interpretazione base + return super().interpret_operation_eqn(eqn) return DynamicDecomposeInterpreter From 76c925044b2b01a25a4e3cd0eba38af55e8b0373 Mon Sep 17 00:00:00 2001 From: PietropaoloFrisoni Date: Wed, 22 Jan 2025 11:14:02 -0500 Subject: [PATCH 04/33] Current prototype version --- pennylane/transforms/decompose.py | 72 +++++++++---------- .../transforms/test_capture_decompose.py | 55 +++++++++++++- 2 files changed, 89 insertions(+), 38 deletions(-) diff --git a/pennylane/transforms/decompose.py b/pennylane/transforms/decompose.py index 3bb3a69ca2c..74410636fa3 100644 --- a/pennylane/transforms/decompose.py +++ b/pennylane/transforms/decompose.py @@ -20,9 +20,10 @@ import warnings from collections.abc import Callable, Generator, Iterable from functools import lru_cache, partial -from typing import Optional +from typing import Optional, Sequence import pennylane as qml +from pennylane.capture.primitives import AbstractMeasurement, AbstractOperator from pennylane.transforms.core import transform @@ -218,14 +219,40 @@ def __init__(self, gate_set=None, max_expansion=None): super().__init__() - def eval_dynamic_decomposition(self, op): - pass - - def interpret_operation_eqn(self, eqn): + def eval_dynamic_decomposition( + self, jaxpr_decomp: "jax.core.Jaxpr", consts: Sequence, *args + ): """ """ - from jax import make_jaxpr + for arg, invar in zip(args, jaxpr_decomp.invars, strict=True): + self._env[invar] = arg + for const, constvar in zip(consts, jaxpr_decomp.constvars, strict=True): + self._env[constvar] = const + + for inner_eqn in jaxpr_decomp.eqns: + + custom_handler = self._primitive_registrations.get(inner_eqn.primitive, None) + + if custom_handler: + invals = [self.read(invar) for invar in inner_eqn.invars] + outvals = custom_handler(self, *invals, **inner_eqn.params) + elif isinstance(inner_eqn.outvars[0].aval, AbstractOperator): + outvals = super().interpret_operation_eqn(inner_eqn) + elif isinstance(inner_eqn.outvars[0].aval, AbstractMeasurement): + outvals = super().interpret_measurement_eqn(inner_eqn) + else: + invals = [self.read(invar) for invar in inner_eqn.invars] + outvals = inner_eqn.primitive.bind(*invals, **inner_eqn.params) + + if not inner_eqn.primitive.multiple_results: + outvals = [outvals] + + for inner_outvar, inner_outval in zip(inner_eqn.outvars, outvals, strict=True): + self._env[inner_outvar] = inner_outval + + def interpret_operation_eqn(self, eqn): + """ """ # eq: e' una roba del tipo: # _:AbstractOperator() = Hadamard[n_wires=1] 0 @@ -234,47 +261,18 @@ def interpret_operation_eqn(self, eqn): op = eqn.primitive.impl(*invals, **eqn.params) # Ora abbiamo recuperato l'operatore concreto - print(f"op from DynamicDecomposeInterpreter: {op}") if hasattr(op, "_compute_plxpr_decomposition"): # Non sono sicuro di cosa stia facendo (questo argomento forse non ha senso) args = op.wires - jaxpr_decomp = make_jaxpr(op._compute_plxpr_decomposition)(*op.wires) + jaxpr_decomp = make_jaxpr(op._compute_plxpr_decomposition)(*args) # Adesso abbiamo catturato la decomposizione dinamica dell'operatore. # Dobbiamo ora interpretare il jaxpr della decomposizione dinamica, ma senza chiamare # di nuovo la funzione 'read' nella classe base per evitare di tornare a questo punto - - for arg, invar in zip(args, jaxpr_decomp.jaxpr.invars, strict=True): - self._env[invar] = arg - for const, constvar in zip( - jaxpr_decomp.consts, jaxpr_decomp.jaxpr.constvars, strict=True - ): - self._env[constvar] = const - - for inner_eqn in jaxpr_decomp.eqns: - - custom_handler = self._primitive_registrations.get(inner_eqn.primitive, None) - - if custom_handler: - invals = [self.read(invar) for invar in inner_eqn.invars] - outvals = custom_handler(self, *invals, **inner_eqn.params) - - elif isinstance(inner_eqn.outvars[0].aval, AbstractOperator): - outvals = super().interpret_operation_eqn(inner_eqn) - elif isinstance(inner_eqn.outvars[0].aval, AbstractMeasurement): - outvals = super().interpret_measurement_eqn(inner_eqn) - else: - invals = [self.read(invar) for invar in inner_eqn.invars] - outvals = inner_eqn.primitive.bind(*invals, **inner_eqn.params) - - if not inner_eqn.primitive.multiple_results: - outvals = [outvals] - - for inner_outvar, inner_outval in zip(inner_eqn.outvars, outvals, strict=True): - self._env[inner_outvar] = inner_outval + self.eval_dynamic_decomposition(jaxpr_decomp.jaxpr, jaxpr_decomp.consts, *args) else: diff --git a/tests/capture/transforms/test_capture_decompose.py b/tests/capture/transforms/test_capture_decompose.py index e2d834f43a8..50fdfaaed8e 100644 --- a/tests/capture/transforms/test_capture_decompose.py +++ b/tests/capture/transforms/test_capture_decompose.py @@ -28,7 +28,12 @@ qnode_prim, while_loop_prim, ) -from pennylane.transforms.decompose import DecomposeInterpreter, decompose_plxpr_to_plxpr +from pennylane.operation import Operation +from pennylane.transforms.decompose import ( + DecomposeInterpreter, + DynamicDecomposeInterpreter, + decompose_plxpr_to_plxpr, +) pytestmark = [pytest.mark.jax, pytest.mark.usefixtures("enable_disable_plxpr")] @@ -499,3 +504,51 @@ def circuit(x, y, z): assert transformed_jaxpr.eqns[2].primitive == qml.RZ._primitive assert transformed_jaxpr.eqns[3].primitive == qml.PauliZ._primitive assert transformed_jaxpr.eqns[4].primitive == qml.measurements.ExpectationMP._obs_primitive + + +class SimpleCustomOp(Operation): + num_wires = 1 + num_params = 0 + ndim_params = (0,) + basis = "Z" + grad_method = "A" + parameter_frequencies = [(1,)] + + def _init__(self, wires, id=None): + super().__init__(wires=wires, id=id) + + @staticmethod + def _compute_plxpr_decomposition(wires): + qml.RX(0.5, wires=wires) + + +class TestDynamicDecomposeInterpreter: + + def test_function_simple(self): + """ """ + + @DynamicDecomposeInterpreter() + def f(x): + qml.RY(x, wires=0) + SimpleCustomOp(wires=0) + return qml.expval(qml.Z(0)) + + jaxpr = jax.make_jaxpr(f)(0.5) + assert jaxpr.eqns[0].primitive == qml.RY._primitive + assert jaxpr.eqns[1].primitive == qml.RX._primitive + + def test_qnode_simple(self): + """ """ + + @DynamicDecomposeInterpreter() + @qml.qnode(device=qml.device("default.qubit", wires=2)) + def circuit(x): + qml.RY(x, wires=0) + SimpleCustomOp(wires=0) + return qml.expval(qml.Z(0)) + + jaxpr = jax.make_jaxpr(circuit)(0.5) + assert jaxpr.eqns[0].primitive == qnode_prim + qfunc_jaxpr = jaxpr.eqns[0].params["qfunc_jaxpr"] + assert qfunc_jaxpr.eqns[0].primitive == qml.RY._primitive + assert qfunc_jaxpr.eqns[1].primitive == qml.RX._primitive From cee6ec4f3cbec9dd260ca1ce0e288905a744658f Mon Sep 17 00:00:00 2001 From: Pietropaolo Frisoni Date: Wed, 22 Jan 2025 17:14:00 -0500 Subject: [PATCH 05/33] Fixing one more problem --- pennylane/transforms/decompose.py | 40 ++++++++++--------- .../transforms/test_capture_decompose.py | 29 ++++++++++++-- 2 files changed, 47 insertions(+), 22 deletions(-) diff --git a/pennylane/transforms/decompose.py b/pennylane/transforms/decompose.py index 74410636fa3..a936aaa6a1b 100644 --- a/pennylane/transforms/decompose.py +++ b/pennylane/transforms/decompose.py @@ -199,7 +199,10 @@ def _get_plxpr_dynamic_decompose(): # pylint: disable=missing-docstring # pylint: disable=redefined-outer-name class DynamicDecomposeInterpreter(qml.capture.PlxprInterpreter): - """ """ + """ + Experimental Plxpr Interpreter for applying a dynamic decomposition to operations program capture is enabled. + + """ def __init__(self, gate_set=None, max_expansion=None): self.max_expansion = max_expansion @@ -222,7 +225,15 @@ def __init__(self, gate_set=None, max_expansion=None): def eval_dynamic_decomposition( self, jaxpr_decomp: "jax.core.Jaxpr", consts: Sequence, *args ): - """ """ + """ + Evaluate a dynamic decomposition of a Jaxpr. + + Args: + jaxpr_decomp (jax.core.Jaxpr): the Jaxpr to evaluate + consts (Sequence): the constants to use in the evaluation + *args: the arguments to use in the evaluation + + """ for arg, invar in zip(args, jaxpr_decomp.invars, strict=True): self._env[invar] = arg @@ -251,33 +262,26 @@ def eval_dynamic_decomposition( for inner_outvar, inner_outval in zip(inner_eqn.outvars, outvals, strict=True): self._env[inner_outvar] = inner_outval - def interpret_operation_eqn(self, eqn): - """ """ - # eq: e' una roba del tipo: - # _:AbstractOperator() = Hadamard[n_wires=1] 0 + def interpret_operation_eqn(self, eqn: "jax.core.JaxprEqn"): + """ + Interpret an equation corresponding to an operator. + + Args: + eqn (jax.core.JaxprEqn): a jax equation for an operator. + """ invals = (self.read(invar) for invar in eqn.invars) with qml.QueuingManager.stop_recording(): op = eqn.primitive.impl(*invals, **eqn.params) - # Ora abbiamo recuperato l'operatore concreto - if hasattr(op, "_compute_plxpr_decomposition"): - # Non sono sicuro di cosa stia facendo (questo argomento forse non ha senso) - args = op.wires - - jaxpr_decomp = make_jaxpr(op._compute_plxpr_decomposition)(*args) - - # Adesso abbiamo catturato la decomposizione dinamica dell'operatore. - # Dobbiamo ora interpretare il jaxpr della decomposizione dinamica, ma senza chiamare - # di nuovo la funzione 'read' nella classe base per evitare di tornare a questo punto + jaxpr_decomp = op._plxpr_decomposition() + args = (*op.parameters, tuple(op.wires), *op.hyperparameters) self.eval_dynamic_decomposition(jaxpr_decomp.jaxpr, jaxpr_decomp.consts, *args) else: - # Se l'operatore non ha un metodo per calcolare la decomposizione dinamica, - # chiamiamo semplicemente la funzione di interpretazione base return super().interpret_operation_eqn(eqn) return DynamicDecomposeInterpreter diff --git a/tests/capture/transforms/test_capture_decompose.py b/tests/capture/transforms/test_capture_decompose.py index 50fdfaaed8e..f681b267c72 100644 --- a/tests/capture/transforms/test_capture_decompose.py +++ b/tests/capture/transforms/test_capture_decompose.py @@ -509,10 +509,6 @@ def circuit(x, y, z): class SimpleCustomOp(Operation): num_wires = 1 num_params = 0 - ndim_params = (0,) - basis = "Z" - grad_method = "A" - parameter_frequencies = [(1,)] def _init__(self, wires, id=None): super().__init__(wires=wires, id=id) @@ -522,6 +518,31 @@ def _compute_plxpr_decomposition(wires): qml.RX(0.5, wires=wires) +class CustomOpCond(Operation): + num_wires = 1 + num_params = 1 + + def __init__(self, phi, wires, id=None): + super().__init__(phi, wires=wires, id=id) + + def _plxpr_decomposition(self) -> "jax.core.Jaxpr": + + return jax.make_jaxpr(self._compute_plxpr_decomposition)( + *self.parameters, wires=tuple(self.wires), **self.hyperparameters + ) + + @staticmethod + def _compute_plxpr_decomposition(phi, wires): + + def true_fn(phi): + qml.RX(phi, wires=0) + + def false_fn(phi): + qml.RY(phi, wires=0) + + qml.cond(phi > 0.5, true_fn, false_fn)(phi) + + class TestDynamicDecomposeInterpreter: def test_function_simple(self): From d16a9e099d4a1371e830b7fb3300c27f12a929ee Mon Sep 17 00:00:00 2001 From: PietropaoloFrisoni Date: Thu, 23 Jan 2025 10:18:04 -0500 Subject: [PATCH 06/33] Moving tests to separate file --- .../transforms/test_capture_decompose.py | 71 ------ .../test_capture_dynamic_decompositions.py | 229 ++++++++++++++++++ 2 files changed, 229 insertions(+), 71 deletions(-) create mode 100644 tests/capture/transforms/test_capture_dynamic_decompositions.py diff --git a/tests/capture/transforms/test_capture_decompose.py b/tests/capture/transforms/test_capture_decompose.py index f681b267c72..89325a6427f 100644 --- a/tests/capture/transforms/test_capture_decompose.py +++ b/tests/capture/transforms/test_capture_decompose.py @@ -28,10 +28,8 @@ qnode_prim, while_loop_prim, ) -from pennylane.operation import Operation from pennylane.transforms.decompose import ( DecomposeInterpreter, - DynamicDecomposeInterpreter, decompose_plxpr_to_plxpr, ) @@ -504,72 +502,3 @@ def circuit(x, y, z): assert transformed_jaxpr.eqns[2].primitive == qml.RZ._primitive assert transformed_jaxpr.eqns[3].primitive == qml.PauliZ._primitive assert transformed_jaxpr.eqns[4].primitive == qml.measurements.ExpectationMP._obs_primitive - - -class SimpleCustomOp(Operation): - num_wires = 1 - num_params = 0 - - def _init__(self, wires, id=None): - super().__init__(wires=wires, id=id) - - @staticmethod - def _compute_plxpr_decomposition(wires): - qml.RX(0.5, wires=wires) - - -class CustomOpCond(Operation): - num_wires = 1 - num_params = 1 - - def __init__(self, phi, wires, id=None): - super().__init__(phi, wires=wires, id=id) - - def _plxpr_decomposition(self) -> "jax.core.Jaxpr": - - return jax.make_jaxpr(self._compute_plxpr_decomposition)( - *self.parameters, wires=tuple(self.wires), **self.hyperparameters - ) - - @staticmethod - def _compute_plxpr_decomposition(phi, wires): - - def true_fn(phi): - qml.RX(phi, wires=0) - - def false_fn(phi): - qml.RY(phi, wires=0) - - qml.cond(phi > 0.5, true_fn, false_fn)(phi) - - -class TestDynamicDecomposeInterpreter: - - def test_function_simple(self): - """ """ - - @DynamicDecomposeInterpreter() - def f(x): - qml.RY(x, wires=0) - SimpleCustomOp(wires=0) - return qml.expval(qml.Z(0)) - - jaxpr = jax.make_jaxpr(f)(0.5) - assert jaxpr.eqns[0].primitive == qml.RY._primitive - assert jaxpr.eqns[1].primitive == qml.RX._primitive - - def test_qnode_simple(self): - """ """ - - @DynamicDecomposeInterpreter() - @qml.qnode(device=qml.device("default.qubit", wires=2)) - def circuit(x): - qml.RY(x, wires=0) - SimpleCustomOp(wires=0) - return qml.expval(qml.Z(0)) - - jaxpr = jax.make_jaxpr(circuit)(0.5) - assert jaxpr.eqns[0].primitive == qnode_prim - qfunc_jaxpr = jaxpr.eqns[0].params["qfunc_jaxpr"] - assert qfunc_jaxpr.eqns[0].primitive == qml.RY._primitive - assert qfunc_jaxpr.eqns[1].primitive == qml.RX._primitive diff --git a/tests/capture/transforms/test_capture_dynamic_decompositions.py b/tests/capture/transforms/test_capture_dynamic_decompositions.py new file mode 100644 index 00000000000..8da6e810cdc --- /dev/null +++ b/tests/capture/transforms/test_capture_dynamic_decompositions.py @@ -0,0 +1,229 @@ +# Copyright 2024 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Unit tests for the ``DecomposeInterpreter`` class""" +# pylint:disable=protected-access,unused-argument, wrong-import-position +import pytest + +import pennylane as qml + +jax = pytest.importorskip("jax") + +from pennylane.capture.primitives import ( + adjoint_transform_prim, + cond_prim, + for_loop_prim, + grad_prim, + jacobian_prim, + qnode_prim, + while_loop_prim, +) +from pennylane.operation import Operation +from pennylane.transforms.decompose import ( + DynamicDecomposeInterpreter, +) + +pytestmark = [pytest.mark.jax, pytest.mark.usefixtures("enable_disable_plxpr")] + + +class SimpleCustomOp(Operation): + num_wires = 1 + num_params = 0 + + def _init__(self, wires, id=None): + super().__init__(wires=wires, id=id) + + def _plxpr_decomposition(self) -> "jax.core.Jaxpr": + + return jax.make_jaxpr(self._compute_plxpr_decomposition)( + *self.parameters, wires=tuple(self.wires), **self.hyperparameters + ) + + @staticmethod + def _compute_plxpr_decomposition(wires): + qml.RX(0.5, wires=0) + + +class CustomOpCond(Operation): + num_wires = 1 + num_params = 1 + + def __init__(self, phi, wires, id=None): + super().__init__(phi, wires=wires, id=id) + + def _plxpr_decomposition(self) -> "jax.core.Jaxpr": + + return jax.make_jaxpr(self._compute_plxpr_decomposition)( + *self.parameters, wires=tuple(self.wires), **self.hyperparameters + ) + + @staticmethod + def _compute_plxpr_decomposition(phi, wires): + + def true_fn(phi): + qml.RX(phi, wires=0) + + def false_fn(phi): + qml.RY(phi, wires=0) + + qml.cond(phi > 0.5, true_fn, false_fn)(phi) + + +class CustomOpForLoop(Operation): + num_wires = 1 + num_params = 1 + + def __init__(self, phi, wires, id=None): + super().__init__(phi, wires=wires, id=id) + + def _plxpr_decomposition(self) -> "jax.core.Jaxpr": + + return jax.make_jaxpr(self._compute_plxpr_decomposition)( + *self.parameters, wires=tuple(self.wires), **self.hyperparameters + ) + + @staticmethod + def _compute_plxpr_decomposition(phi, wires): + + @qml.for_loop(0, 3, 1) + def loop_rx(i, phi): + qml.RX(phi, wires=0) + return jax.numpy.sin(phi) + + final_x = loop_rx(phi) + + return qml.expval(qml.Z(0)) + + +class CustomOpWhileLoop(Operation): + num_wires = 1 + num_params = 1 + + def __init__(self, phi, wires, id=None): + super().__init__(phi, wires=wires, id=id) + + def _plxpr_decomposition(self) -> "jax.core.Jaxpr": + + return jax.make_jaxpr(self._compute_plxpr_decomposition)( + *self.parameters, wires=tuple(self.wires), **self.hyperparameters + ) + + @staticmethod + def _compute_plxpr_decomposition(phi, wires): + + @qml.while_loop(lambda i: i < 3) + def loop_rx(phi): + qml.RX(phi, wires=0) + return jax.numpy.sin(phi) + + loop_rx(phi) + + return qml.expval(qml.Z(0)) + + +class TestDynamicDecomposeInterpreter: + + def test_function_simple(self): + """ """ + + @DynamicDecomposeInterpreter() + def f(x): + qml.RY(x, wires=0) + SimpleCustomOp(wires=0) + return qml.expval(qml.Z(0)) + + jaxpr = jax.make_jaxpr(f)(0.5) + assert jaxpr.eqns[0].primitive == qml.RY._primitive + assert jaxpr.eqns[1].primitive == qml.RX._primitive + + ############################ + ### QNode + ############################ + + def test_qnode_simple(self): + """ """ + + @DynamicDecomposeInterpreter() + @qml.qnode(device=qml.device("default.qubit", wires=2)) + def circuit(x): + qml.RY(x, wires=0) + SimpleCustomOp(wires=0) + return qml.expval(qml.Z(0)) + + jaxpr = jax.make_jaxpr(circuit)(0.5) + + assert jaxpr.eqns[0].primitive == qnode_prim + qfunc_jaxpr = jaxpr.eqns[0].params["qfunc_jaxpr"] + assert qfunc_jaxpr.eqns[0].primitive == qml.RY._primitive + assert qfunc_jaxpr.eqns[1].primitive == qml.RX._primitive + assert qfunc_jaxpr.eqns[2].primitive == qml.PauliZ._primitive + assert qfunc_jaxpr.eqns[3].primitive == qml.measurements.ExpectationMP._obs_primitive + + def test_qnode_cond(self): + """ """ + + @DynamicDecomposeInterpreter() + @qml.qnode(device=qml.device("default.qubit", wires=2)) + def f(x): + CustomOpCond(x, wires=0) + return qml.expval(qml.Z(0)) + + jaxpr = jax.make_jaxpr(f)(0.5) + + assert jaxpr.eqns[0].primitive == qnode_prim + qfunc_jaxpr = jaxpr.eqns[0].params["qfunc_jaxpr"] + assert qfunc_jaxpr.eqns[1].primitive == cond_prim + assert ( + qfunc_jaxpr.eqns[1].params["jaxpr_branches"][0].eqns[0].primitive == qml.RX._primitive + ) + assert ( + qfunc_jaxpr.eqns[1].params["jaxpr_branches"][1].eqns[0].primitive == qml.RY._primitive + ) + assert qfunc_jaxpr.eqns[2].primitive == qml.PauliZ._primitive + assert qfunc_jaxpr.eqns[3].primitive == qml.measurements.ExpectationMP._obs_primitive + + def test_qnode_for_loop(self): + """ """ + + @DynamicDecomposeInterpreter() + @qml.qnode(device=qml.device("default.qubit", wires=2)) + def f(x): + CustomOpForLoop(x, wires=0) + return qml.expval(qml.Z(0)) + + jaxpr = jax.make_jaxpr(f)(0.5) + + assert jaxpr.eqns[0].primitive == qnode_prim + qfunc_jaxpr = jaxpr.eqns[0].params["qfunc_jaxpr"] + assert qfunc_jaxpr.eqns[0].primitive == for_loop_prim + assert qfunc_jaxpr.eqns[0].params["jaxpr_body_fn"].eqns[0].primitive == qml.RX._primitive + assert qfunc_jaxpr.eqns[1].primitive == qml.PauliZ._primitive + assert qfunc_jaxpr.eqns[2].primitive == qml.measurements.ExpectationMP._obs_primitive + + def test_qnode_while_loop(self): + """ """ + + @DynamicDecomposeInterpreter() + @qml.qnode(device=qml.device("default.qubit", wires=2)) + def f(x): + CustomOpWhileLoop(x, wires=0) + return qml.expval(qml.Z(0)) + + jaxpr = jax.make_jaxpr(f)(0.5) + + assert jaxpr.eqns[0].primitive == qnode_prim + qfunc_jaxpr = jaxpr.eqns[0].params["qfunc_jaxpr"] + assert qfunc_jaxpr.eqns[0].primitive == while_loop_prim + assert qfunc_jaxpr.eqns[0].params["jaxpr_body_fn"].eqns[0].primitive == qml.RX._primitive + assert qfunc_jaxpr.eqns[1].primitive == qml.PauliZ._primitive + assert qfunc_jaxpr.eqns[2].primitive == qml.measurements.ExpectationMP._obs_primitive From 18bc43c34edbe6ea7c95a706b0417a8faf8f6074 Mon Sep 17 00:00:00 2001 From: PietropaoloFrisoni Date: Thu, 23 Jan 2025 10:34:35 -0500 Subject: [PATCH 07/33] Pylint fixes (although premature) --- pennylane/transforms/decompose.py | 11 ++++---- .../test_capture_dynamic_decompositions.py | 27 ++++++++++++------- 2 files changed, 23 insertions(+), 15 deletions(-) diff --git a/pennylane/transforms/decompose.py b/pennylane/transforms/decompose.py index a936aaa6a1b..5f95eb221a0 100644 --- a/pennylane/transforms/decompose.py +++ b/pennylane/transforms/decompose.py @@ -192,7 +192,8 @@ def wrapper(*inner_args): def _get_plxpr_dynamic_decompose(): # pylint: disable=missing-docstring try: # pylint: disable=import-outside-toplevel - from jax import make_jaxpr + # pylint: disable=unused-import + import jax except ImportError: # pragma: no cover return None, None @@ -278,11 +279,11 @@ def interpret_operation_eqn(self, eqn: "jax.core.JaxprEqn"): jaxpr_decomp = op._plxpr_decomposition() args = (*op.parameters, tuple(op.wires), *op.hyperparameters) - self.eval_dynamic_decomposition(jaxpr_decomp.jaxpr, jaxpr_decomp.consts, *args) - - else: + return self.eval_dynamic_decomposition( + jaxpr_decomp.jaxpr, jaxpr_decomp.consts, *args + ) - return super().interpret_operation_eqn(eqn) + return super().interpret_operation_eqn(eqn) return DynamicDecomposeInterpreter diff --git a/tests/capture/transforms/test_capture_dynamic_decompositions.py b/tests/capture/transforms/test_capture_dynamic_decompositions.py index 8da6e810cdc..70eaa0b8d45 100644 --- a/tests/capture/transforms/test_capture_dynamic_decompositions.py +++ b/tests/capture/transforms/test_capture_dynamic_decompositions.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Unit tests for the ``DecomposeInterpreter`` class""" -# pylint:disable=protected-access,unused-argument, wrong-import-position +# pylint:disable=protected-access,unused-argument, wrong-import-position, no-value-for-parameter, import pytest import pennylane as qml @@ -20,11 +20,8 @@ jax = pytest.importorskip("jax") from pennylane.capture.primitives import ( - adjoint_transform_prim, cond_prim, for_loop_prim, - grad_prim, - jacobian_prim, qnode_prim, while_loop_prim, ) @@ -37,6 +34,8 @@ class SimpleCustomOp(Operation): + """Simple custom operation that contains a single gate in its decomposition""" + num_wires = 1 num_params = 0 @@ -55,6 +54,8 @@ def _compute_plxpr_decomposition(wires): class CustomOpCond(Operation): + """Custom operation that contains a conditional in its decomposition""" + num_wires = 1 num_params = 1 @@ -80,6 +81,8 @@ def false_fn(phi): class CustomOpForLoop(Operation): + """Custom operation that contains a for loop in its decomposition""" + num_wires = 1 num_params = 1 @@ -100,12 +103,15 @@ def loop_rx(i, phi): qml.RX(phi, wires=0) return jax.numpy.sin(phi) - final_x = loop_rx(phi) + # pylint: disable=unused-variable + loop_rx(phi) return qml.expval(qml.Z(0)) class CustomOpWhileLoop(Operation): + """Custom operation that contains a while loop in its decomposition""" + num_wires = 1 num_params = 1 @@ -132,9 +138,10 @@ def loop_rx(phi): class TestDynamicDecomposeInterpreter: + """Tests for the DynamicDecomposeInterpreter class""" def test_function_simple(self): - """ """ + """Test that a function with a custom operation is correctly decomposed.""" @DynamicDecomposeInterpreter() def f(x): @@ -151,7 +158,7 @@ def f(x): ############################ def test_qnode_simple(self): - """ """ + """Test that a QNode with a custom operation is correctly decomposed.""" @DynamicDecomposeInterpreter() @qml.qnode(device=qml.device("default.qubit", wires=2)) @@ -170,7 +177,7 @@ def circuit(x): assert qfunc_jaxpr.eqns[3].primitive == qml.measurements.ExpectationMP._obs_primitive def test_qnode_cond(self): - """ """ + """Test that a QNode with a conditional custom operation is correctly decomposed.""" @DynamicDecomposeInterpreter() @qml.qnode(device=qml.device("default.qubit", wires=2)) @@ -193,7 +200,7 @@ def f(x): assert qfunc_jaxpr.eqns[3].primitive == qml.measurements.ExpectationMP._obs_primitive def test_qnode_for_loop(self): - """ """ + """Test that a QNode with a for loop custom operation is correctly decomposed.""" @DynamicDecomposeInterpreter() @qml.qnode(device=qml.device("default.qubit", wires=2)) @@ -211,7 +218,7 @@ def f(x): assert qfunc_jaxpr.eqns[2].primitive == qml.measurements.ExpectationMP._obs_primitive def test_qnode_while_loop(self): - """ """ + """Test that a QNode with a while loop custom operation is correctly decomposed.""" @DynamicDecomposeInterpreter() @qml.qnode(device=qml.device("default.qubit", wires=2)) From e2e8fd075bc2940b4ccb9994840ac5e3e0bf88b9 Mon Sep 17 00:00:00 2001 From: PietropaoloFrisoni Date: Thu, 23 Jan 2025 11:24:18 -0500 Subject: [PATCH 08/33] Removing reundandt tuple calls --- pennylane/transforms/decompose.py | 2 +- .../test_capture_dynamic_decompositions.py | 22 +++++++++---------- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/pennylane/transforms/decompose.py b/pennylane/transforms/decompose.py index 5f95eb221a0..6e61d823072 100644 --- a/pennylane/transforms/decompose.py +++ b/pennylane/transforms/decompose.py @@ -278,7 +278,7 @@ def interpret_operation_eqn(self, eqn: "jax.core.JaxprEqn"): if hasattr(op, "_compute_plxpr_decomposition"): jaxpr_decomp = op._plxpr_decomposition() - args = (*op.parameters, tuple(op.wires), *op.hyperparameters) + args = (*op.parameters, *op.wires, *op.hyperparameters) return self.eval_dynamic_decomposition( jaxpr_decomp.jaxpr, jaxpr_decomp.consts, *args ) diff --git a/tests/capture/transforms/test_capture_dynamic_decompositions.py b/tests/capture/transforms/test_capture_dynamic_decompositions.py index 70eaa0b8d45..eb35f99ebb4 100644 --- a/tests/capture/transforms/test_capture_dynamic_decompositions.py +++ b/tests/capture/transforms/test_capture_dynamic_decompositions.py @@ -45,12 +45,12 @@ def _init__(self, wires, id=None): def _plxpr_decomposition(self) -> "jax.core.Jaxpr": return jax.make_jaxpr(self._compute_plxpr_decomposition)( - *self.parameters, wires=tuple(self.wires), **self.hyperparameters + *self.parameters, *self.wires, **self.hyperparameters ) @staticmethod def _compute_plxpr_decomposition(wires): - qml.RX(0.5, wires=0) + qml.Hadamard(wires=wires) class CustomOpCond(Operation): @@ -65,17 +65,17 @@ def __init__(self, phi, wires, id=None): def _plxpr_decomposition(self) -> "jax.core.Jaxpr": return jax.make_jaxpr(self._compute_plxpr_decomposition)( - *self.parameters, wires=tuple(self.wires), **self.hyperparameters + *self.parameters, *self.wires, **self.hyperparameters ) @staticmethod def _compute_plxpr_decomposition(phi, wires): def true_fn(phi): - qml.RX(phi, wires=0) + qml.RX(phi, wires=wires) def false_fn(phi): - qml.RY(phi, wires=0) + qml.RY(phi, wires=wires) qml.cond(phi > 0.5, true_fn, false_fn)(phi) @@ -92,7 +92,7 @@ def __init__(self, phi, wires, id=None): def _plxpr_decomposition(self) -> "jax.core.Jaxpr": return jax.make_jaxpr(self._compute_plxpr_decomposition)( - *self.parameters, wires=tuple(self.wires), **self.hyperparameters + *self.parameters, *self.wires, **self.hyperparameters ) @staticmethod @@ -100,7 +100,7 @@ def _compute_plxpr_decomposition(phi, wires): @qml.for_loop(0, 3, 1) def loop_rx(i, phi): - qml.RX(phi, wires=0) + qml.RX(phi, wires=wires) return jax.numpy.sin(phi) # pylint: disable=unused-variable @@ -121,7 +121,7 @@ def __init__(self, phi, wires, id=None): def _plxpr_decomposition(self) -> "jax.core.Jaxpr": return jax.make_jaxpr(self._compute_plxpr_decomposition)( - *self.parameters, wires=tuple(self.wires), **self.hyperparameters + *self.parameters, *self.wires, **self.hyperparameters ) @staticmethod @@ -129,7 +129,7 @@ def _compute_plxpr_decomposition(phi, wires): @qml.while_loop(lambda i: i < 3) def loop_rx(phi): - qml.RX(phi, wires=0) + qml.RX(phi, wires=wires) return jax.numpy.sin(phi) loop_rx(phi) @@ -151,7 +151,7 @@ def f(x): jaxpr = jax.make_jaxpr(f)(0.5) assert jaxpr.eqns[0].primitive == qml.RY._primitive - assert jaxpr.eqns[1].primitive == qml.RX._primitive + assert jaxpr.eqns[1].primitive == qml.Hadamard._primitive ############################ ### QNode @@ -172,7 +172,7 @@ def circuit(x): assert jaxpr.eqns[0].primitive == qnode_prim qfunc_jaxpr = jaxpr.eqns[0].params["qfunc_jaxpr"] assert qfunc_jaxpr.eqns[0].primitive == qml.RY._primitive - assert qfunc_jaxpr.eqns[1].primitive == qml.RX._primitive + assert qfunc_jaxpr.eqns[1].primitive == qml.Hadamard._primitive assert qfunc_jaxpr.eqns[2].primitive == qml.PauliZ._primitive assert qfunc_jaxpr.eqns[3].primitive == qml.measurements.ExpectationMP._obs_primitive From 0abd620c9212d2e6d2ed6e0922a7973934068385 Mon Sep 17 00:00:00 2001 From: PietropaoloFrisoni Date: Thu, 23 Jan 2025 14:30:36 -0500 Subject: [PATCH 09/33] Tests with dynamic wires --- .../test_capture_dynamic_decompositions.py | 122 +++++++++++++----- 1 file changed, 93 insertions(+), 29 deletions(-) diff --git a/tests/capture/transforms/test_capture_dynamic_decompositions.py b/tests/capture/transforms/test_capture_dynamic_decompositions.py index eb35f99ebb4..a57eacc5c3f 100644 --- a/tests/capture/transforms/test_capture_dynamic_decompositions.py +++ b/tests/capture/transforms/test_capture_dynamic_decompositions.py @@ -37,10 +37,10 @@ class SimpleCustomOp(Operation): """Simple custom operation that contains a single gate in its decomposition""" num_wires = 1 - num_params = 0 + num_params = 1 - def _init__(self, wires, id=None): - super().__init__(wires=wires, id=id) + def _init__(self, phi, wires, id=None): + super().__init__(phi, wires=wires, id=id) def _plxpr_decomposition(self) -> "jax.core.Jaxpr": @@ -49,8 +49,8 @@ def _plxpr_decomposition(self) -> "jax.core.Jaxpr": ) @staticmethod - def _compute_plxpr_decomposition(wires): - qml.Hadamard(wires=wires) + def _compute_plxpr_decomposition(phi, wires): + qml.RX(phi, wires=wires) class CustomOpCond(Operation): @@ -71,13 +71,13 @@ def _plxpr_decomposition(self) -> "jax.core.Jaxpr": @staticmethod def _compute_plxpr_decomposition(phi, wires): - def true_fn(phi): + def true_fn(phi, wires): qml.RX(phi, wires=wires) - def false_fn(phi): + def false_fn(phi, wires): qml.RY(phi, wires=wires) - qml.cond(phi > 0.5, true_fn, false_fn)(phi) + qml.cond(phi > 0.5, true_fn, false_fn)(phi, wires) class CustomOpForLoop(Operation): @@ -128,11 +128,11 @@ def _plxpr_decomposition(self) -> "jax.core.Jaxpr": def _compute_plxpr_decomposition(phi, wires): @qml.while_loop(lambda i: i < 3) - def loop_rx(phi): + def loop_fn(i): qml.RX(phi, wires=wires) - return jax.numpy.sin(phi) + return i + 1 - loop_rx(phi) + _ = loop_fn(0) return qml.expval(qml.Z(0)) @@ -146,46 +146,59 @@ def test_function_simple(self): @DynamicDecomposeInterpreter() def f(x): qml.RY(x, wires=0) - SimpleCustomOp(wires=0) + SimpleCustomOp(x, wires=0) return qml.expval(qml.Z(0)) jaxpr = jax.make_jaxpr(f)(0.5) assert jaxpr.eqns[0].primitive == qml.RY._primitive - assert jaxpr.eqns[1].primitive == qml.Hadamard._primitive + assert jaxpr.eqns[1].primitive == qml.RX._primitive ############################ ### QNode ############################ - def test_qnode_simple(self): + @pytest.mark.parametrize("x", [0.2, 0.8]) + def test_qnode_simple(self, x): """Test that a QNode with a custom operation is correctly decomposed.""" @DynamicDecomposeInterpreter() @qml.qnode(device=qml.device("default.qubit", wires=2)) def circuit(x): qml.RY(x, wires=0) - SimpleCustomOp(wires=0) + SimpleCustomOp(x, wires=0) return qml.expval(qml.Z(0)) - jaxpr = jax.make_jaxpr(circuit)(0.5) + jaxpr = jax.make_jaxpr(circuit)(x) assert jaxpr.eqns[0].primitive == qnode_prim qfunc_jaxpr = jaxpr.eqns[0].params["qfunc_jaxpr"] assert qfunc_jaxpr.eqns[0].primitive == qml.RY._primitive - assert qfunc_jaxpr.eqns[1].primitive == qml.Hadamard._primitive + assert qfunc_jaxpr.eqns[1].primitive == qml.RX._primitive assert qfunc_jaxpr.eqns[2].primitive == qml.PauliZ._primitive assert qfunc_jaxpr.eqns[3].primitive == qml.measurements.ExpectationMP._obs_primitive - def test_qnode_cond(self): + result = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, x) + + @qml.qnode(device=qml.device("default.qubit", wires=2)) + def circuit_comparison(x): + qml.RY(x, wires=0) + qml.RX(x, wires=0) + return qml.expval(qml.Z(0)) + + assert jax.numpy.allclose(*result, circuit_comparison(x)) + + @pytest.mark.parametrize("wire", [0, 1]) + @pytest.mark.parametrize("x", [0.2, 0.8]) + def test_qnode_cond(self, x, wire): """Test that a QNode with a conditional custom operation is correctly decomposed.""" @DynamicDecomposeInterpreter() @qml.qnode(device=qml.device("default.qubit", wires=2)) - def f(x): - CustomOpCond(x, wires=0) + def circuit(x, wire): + CustomOpCond(x, wires=wire) return qml.expval(qml.Z(0)) - jaxpr = jax.make_jaxpr(f)(0.5) + jaxpr = jax.make_jaxpr(circuit)(x, wire) assert jaxpr.eqns[0].primitive == qnode_prim qfunc_jaxpr = jaxpr.eqns[0].params["qfunc_jaxpr"] @@ -199,16 +212,34 @@ def f(x): assert qfunc_jaxpr.eqns[2].primitive == qml.PauliZ._primitive assert qfunc_jaxpr.eqns[3].primitive == qml.measurements.ExpectationMP._obs_primitive - def test_qnode_for_loop(self): + result = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, x, wire) + + @qml.qnode(device=qml.device("default.qubit", wires=2)) + def circuit_comparison(x, wire): + def true_fn(x, wire): + qml.RX(x, wires=wire) + + def false_fn(x, wire): + qml.RY(x, wires=wire) + + qml.cond(x > 0.5, true_fn, false_fn)(x, wire) + + return qml.expval(qml.Z(0)) + + assert jax.numpy.allclose(*result, circuit_comparison(x, wire)) + + @pytest.mark.parametrize("wire", [0, 1]) + @pytest.mark.parametrize("x", [0.2, 0.8]) + def test_qnode_for_loop(self, x, wire): """Test that a QNode with a for loop custom operation is correctly decomposed.""" @DynamicDecomposeInterpreter() @qml.qnode(device=qml.device("default.qubit", wires=2)) - def f(x): - CustomOpForLoop(x, wires=0) + def circuit(x, wire): + CustomOpForLoop(x, wires=wire) return qml.expval(qml.Z(0)) - jaxpr = jax.make_jaxpr(f)(0.5) + jaxpr = jax.make_jaxpr(circuit)(x, wire) assert jaxpr.eqns[0].primitive == qnode_prim qfunc_jaxpr = jaxpr.eqns[0].params["qfunc_jaxpr"] @@ -217,16 +248,34 @@ def f(x): assert qfunc_jaxpr.eqns[1].primitive == qml.PauliZ._primitive assert qfunc_jaxpr.eqns[2].primitive == qml.measurements.ExpectationMP._obs_primitive - def test_qnode_while_loop(self): + result = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, x, wire) + + @qml.qnode(device=qml.device("default.qubit", wires=2)) + def circuit_comparison(x, wire): + @qml.for_loop(0, 3, 1) + def loop_rx(i, phi): + qml.RX(phi, wires=wire) + return jax.numpy.sin(phi) + + # pylint: disable=unused-variable + loop_rx(x) + + return qml.expval(qml.Z(0)) + + assert jax.numpy.allclose(*result, circuit_comparison(x, wire)) + + @pytest.mark.parametrize("wire", [0, 1]) + @pytest.mark.parametrize("x", [0.2, 0.8]) + def test_qnode_while_loop(self, x, wire): """Test that a QNode with a while loop custom operation is correctly decomposed.""" @DynamicDecomposeInterpreter() @qml.qnode(device=qml.device("default.qubit", wires=2)) - def f(x): - CustomOpWhileLoop(x, wires=0) + def circuit(x, wire): + CustomOpWhileLoop(x, wires=wire) return qml.expval(qml.Z(0)) - jaxpr = jax.make_jaxpr(f)(0.5) + jaxpr = jax.make_jaxpr(circuit)(x, wire) assert jaxpr.eqns[0].primitive == qnode_prim qfunc_jaxpr = jaxpr.eqns[0].params["qfunc_jaxpr"] @@ -234,3 +283,18 @@ def f(x): assert qfunc_jaxpr.eqns[0].params["jaxpr_body_fn"].eqns[0].primitive == qml.RX._primitive assert qfunc_jaxpr.eqns[1].primitive == qml.PauliZ._primitive assert qfunc_jaxpr.eqns[2].primitive == qml.measurements.ExpectationMP._obs_primitive + + result = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, x, wire) + + @qml.qnode(device=qml.device("default.qubit", wires=2)) + def circuit_comparison(x, wire): + @qml.while_loop(lambda i: i < 3) + def loop_fn(i): + qml.RX(x, wires=wire) + return i + 1 + + _ = loop_fn(0) + + return qml.expval(qml.Z(0)) + + assert jax.numpy.allclose(*result, circuit_comparison(x, wire)) From 1ae399ba706891e7cc5dfbf36a16907508fc3831 Mon Sep 17 00:00:00 2001 From: PietropaoloFrisoni Date: Thu, 23 Jan 2025 15:19:16 -0500 Subject: [PATCH 10/33] Adding Autograph test --- .../test_capture_dynamic_decompositions.py | 68 +++++++++++++++++++ 1 file changed, 68 insertions(+) diff --git a/tests/capture/transforms/test_capture_dynamic_decompositions.py b/tests/capture/transforms/test_capture_dynamic_decompositions.py index a57eacc5c3f..a01c1377e5b 100644 --- a/tests/capture/transforms/test_capture_dynamic_decompositions.py +++ b/tests/capture/transforms/test_capture_dynamic_decompositions.py @@ -137,6 +137,31 @@ def loop_fn(i): return qml.expval(qml.Z(0)) +class CustomOpAutograph(Operation): + """Custom operation that contains a nested conditional in its decomposition""" + + num_wires = 1 + num_params = 1 + + def __init__(self, phi, wires, id=None): + super().__init__(phi, wires=wires, id=id) + + def _plxpr_decomposition(self) -> "jax.core.Jaxpr": + + return qml.capture.make_plxpr(self._compute_plxpr_decomposition)( + *self.parameters, *self.wires, **self.hyperparameters + ) + + @staticmethod + def _compute_plxpr_decomposition(phi, wires): + + if phi > 0.5: + qml.RX(phi, wires=wires) + + else: + qml.RY(phi, wires=wires) + + class TestDynamicDecomposeInterpreter: """Tests for the DynamicDecomposeInterpreter class""" @@ -298,3 +323,46 @@ def loop_fn(i): return qml.expval(qml.Z(0)) assert jax.numpy.allclose(*result, circuit_comparison(x, wire)) + + @pytest.mark.parametrize("wire", [0, 1]) + @pytest.mark.parametrize("x", [0.2, 0.8]) + def test_qnode_autograph(self, x, wire): + """Test that a QNode with a nested conditional custom operation is correctly decomposed.""" + + @DynamicDecomposeInterpreter() + @qml.qnode(device=qml.device("default.qubit", wires=2)) + def circuit(x, wire): + CustomOpAutograph(x, wires=wire) + return qml.expval(qml.Z(0)) + + jaxpr = jax.make_jaxpr(circuit)(x, wire) + + assert jaxpr.eqns[0].primitive == qnode_prim + qfunc_jaxpr = jaxpr.eqns[0].params["qfunc_jaxpr"] + assert qfunc_jaxpr.eqns[1].primitive == cond_prim + assert ( + qfunc_jaxpr.eqns[1].params["jaxpr_branches"][0].eqns[0].primitive == qml.RX._primitive + ) + assert ( + qfunc_jaxpr.eqns[1].params["jaxpr_branches"][1].eqns[0].primitive == qml.RY._primitive + ) + assert qfunc_jaxpr.eqns[2].primitive == qml.PauliZ._primitive + assert qfunc_jaxpr.eqns[3].primitive == qml.measurements.ExpectationMP._obs_primitive + + result = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, x, wire) + + @qml.qnode(device=qml.device("default.qubit", wires=2)) + def circuit_comparison(x, wire): + if x > 0.5: + qml.RX(x, wires=wire) + else: + qml.RY(x, wires=wire) + + return qml.expval(qml.Z(0)) + + jaxpr_comparison = qml.capture.make_plxpr(circuit_comparison)(x, wire) + result_comparison = jax.core.eval_jaxpr( + jaxpr_comparison.jaxpr, jaxpr_comparison.consts, x, wire + ) + + assert jax.numpy.allclose(*result, *result_comparison) From c5f2ae596988ab1ccfbe3a053200d33d175863c7 Mon Sep 17 00:00:00 2001 From: PietropaoloFrisoni Date: Fri, 24 Jan 2025 16:40:05 -0500 Subject: [PATCH 11/33] Removing unused parameters and adding a few tests --- pennylane/transforms/decompose.py | 18 ---- .../test_capture_dynamic_decompositions.py | 100 +++++++++++++++--- 2 files changed, 88 insertions(+), 30 deletions(-) diff --git a/pennylane/transforms/decompose.py b/pennylane/transforms/decompose.py index 6e61d823072..ffa2a2410d4 100644 --- a/pennylane/transforms/decompose.py +++ b/pennylane/transforms/decompose.py @@ -205,24 +205,6 @@ class DynamicDecomposeInterpreter(qml.capture.PlxprInterpreter): """ - def __init__(self, gate_set=None, max_expansion=None): - self.max_expansion = max_expansion - - if gate_set is None: - gate_set = set(qml.ops.__all__) - - if isinstance(gate_set, (str, type)): - gate_set = set([gate_set]) - - if isinstance(gate_set, Iterable): - gate_types = tuple(gate for gate in gate_set if isinstance(gate, type)) - gate_names = set(gate for gate in gate_set if isinstance(gate, str)) - self.gate_set = lambda op: (op.name in gate_names) or isinstance(op, gate_types) - else: - self.gate_set = gate_set - - super().__init__() - def eval_dynamic_decomposition( self, jaxpr_decomp: "jax.core.Jaxpr", consts: Sequence, *args ): diff --git a/tests/capture/transforms/test_capture_dynamic_decompositions.py b/tests/capture/transforms/test_capture_dynamic_decompositions.py index a01c1377e5b..b160232f02a 100644 --- a/tests/capture/transforms/test_capture_dynamic_decompositions.py +++ b/tests/capture/transforms/test_capture_dynamic_decompositions.py @@ -53,6 +53,32 @@ def _compute_plxpr_decomposition(phi, wires): qml.RX(phi, wires=wires) +class CustomOpMultiWire(Operation): + """Custom operation that acts on multiple wires""" + + num_wires = 4 + num_params = 1 + + def __init__(self, phi, wires, id=None): + super().__init__(phi, wires=wires, id=id) + + def _plxpr_decomposition(self) -> "jax.core.Jaxpr": + + return jax.make_jaxpr(self._compute_plxpr_decomposition)( + *self.parameters, *self.wires, **self.hyperparameters + ) + + @staticmethod + def _compute_plxpr_decomposition(phi, *wires): + qml.CNOT([wires[0], wires[1]]) + qml.DoubleExcitation(phi, wires) + qml.CNOT([wires[0], wires[1]]) + qml.RX(phi, wires=wires[0]) + qml.RY(phi, wires=wires[1]) + qml.RZ(phi, wires=wires[2]) + qml.RX(phi, wires=wires[3]) + + class CustomOpCond(Operation): """Custom operation that contains a conditional in its decomposition""" @@ -165,6 +191,17 @@ def _compute_plxpr_decomposition(phi, wires): class TestDynamicDecomposeInterpreter: """Tests for the DynamicDecomposeInterpreter class""" + def test_no_plxpr_decomposition(self): + """Test that a function with a custom operation that does not have a plxpr decomposition is not decomposed.""" + + @DynamicDecomposeInterpreter() + def f(x): + qml.RY(x, wires=0) + + jaxpr = jax.make_jaxpr(f)(0.5) + assert len(jaxpr.eqns) == 1 + assert jaxpr.eqns[0].primitive == qml.RY._primitive + def test_function_simple(self): """Test that a function with a custom operation is correctly decomposed.""" @@ -175,11 +212,14 @@ def f(x): return qml.expval(qml.Z(0)) jaxpr = jax.make_jaxpr(f)(0.5) + assert len(jaxpr.eqns) == 4 assert jaxpr.eqns[0].primitive == qml.RY._primitive assert jaxpr.eqns[1].primitive == qml.RX._primitive + assert jaxpr.eqns[2].primitive == qml.PauliZ._primitive + assert jaxpr.eqns[3].primitive == qml.measurements.ExpectationMP._obs_primitive ############################ - ### QNode + ### QNode tests ############################ @pytest.mark.parametrize("x", [0.2, 0.8]) @@ -212,6 +252,43 @@ def circuit_comparison(x): assert jax.numpy.allclose(*result, circuit_comparison(x)) + @pytest.mark.parametrize("wires", [[0, 1, 2, 3], [2, 3, 1, 0]]) + def test_multi_wire(self, wires): + """Test that a QNode with a multi-wire custom operation is correctly decomposed.""" + + @DynamicDecomposeInterpreter() + @qml.qnode(device=qml.device("default.qubit", wires=4)) + def circuit(x, wires): + CustomOpMultiWire(x, wires=wires) + return qml.expval(qml.Z(0)) + + jaxpr = jax.make_jaxpr(circuit)(0.5, wires=wires) + + assert jaxpr.eqns[0].primitive == qnode_prim + qfunc_jaxpr = jaxpr.eqns[0].params["qfunc_jaxpr"] + assert qfunc_jaxpr.eqns[0].primitive == qml.CNOT._primitive + assert qfunc_jaxpr.eqns[1].primitive == qml.DoubleExcitation._primitive + assert qfunc_jaxpr.eqns[2].primitive == qml.CNOT._primitive + assert qfunc_jaxpr.eqns[3].primitive == qml.RX._primitive + assert qfunc_jaxpr.eqns[4].primitive == qml.RY._primitive + assert qfunc_jaxpr.eqns[5].primitive == qml.RZ._primitive + assert qfunc_jaxpr.eqns[6].primitive == qml.RX._primitive + + result = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, 0.5, *wires) + + @qml.qnode(device=qml.device("default.qubit", wires=4)) + def circuit_comparison(x, wires): + qml.CNOT([wires[0], wires[1]]) + qml.DoubleExcitation(x, wires) + qml.CNOT([wires[0], wires[1]]) + qml.RX(x, wires=wires[0]) + qml.RY(x, wires=wires[1]) + qml.RZ(x, wires=wires[2]) + qml.RX(x, wires=wires[3]) + return qml.expval(qml.Z(0)) + + assert jax.numpy.allclose(*result, circuit_comparison(0.5, wires)) + @pytest.mark.parametrize("wire", [0, 1]) @pytest.mark.parametrize("x", [0.2, 0.8]) def test_qnode_cond(self, x, wire): @@ -223,7 +300,7 @@ def circuit(x, wire): CustomOpCond(x, wires=wire) return qml.expval(qml.Z(0)) - jaxpr = jax.make_jaxpr(circuit)(x, wire) + jaxpr = jax.make_jaxpr(circuit)(x, wire=wire) assert jaxpr.eqns[0].primitive == qnode_prim qfunc_jaxpr = jaxpr.eqns[0].params["qfunc_jaxpr"] @@ -254,8 +331,7 @@ def false_fn(x, wire): assert jax.numpy.allclose(*result, circuit_comparison(x, wire)) @pytest.mark.parametrize("wire", [0, 1]) - @pytest.mark.parametrize("x", [0.2, 0.8]) - def test_qnode_for_loop(self, x, wire): + def test_qnode_for_loop(self, wire): """Test that a QNode with a for loop custom operation is correctly decomposed.""" @DynamicDecomposeInterpreter() @@ -264,7 +340,7 @@ def circuit(x, wire): CustomOpForLoop(x, wires=wire) return qml.expval(qml.Z(0)) - jaxpr = jax.make_jaxpr(circuit)(x, wire) + jaxpr = jax.make_jaxpr(circuit)(0.5, wire) assert jaxpr.eqns[0].primitive == qnode_prim qfunc_jaxpr = jaxpr.eqns[0].params["qfunc_jaxpr"] @@ -273,7 +349,7 @@ def circuit(x, wire): assert qfunc_jaxpr.eqns[1].primitive == qml.PauliZ._primitive assert qfunc_jaxpr.eqns[2].primitive == qml.measurements.ExpectationMP._obs_primitive - result = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, x, wire) + result = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, 0.5, wire) @qml.qnode(device=qml.device("default.qubit", wires=2)) def circuit_comparison(x, wire): @@ -287,11 +363,10 @@ def loop_rx(i, phi): return qml.expval(qml.Z(0)) - assert jax.numpy.allclose(*result, circuit_comparison(x, wire)) + assert jax.numpy.allclose(*result, circuit_comparison(0.5, wire)) @pytest.mark.parametrize("wire", [0, 1]) - @pytest.mark.parametrize("x", [0.2, 0.8]) - def test_qnode_while_loop(self, x, wire): + def test_qnode_while_loop(self, wire): """Test that a QNode with a while loop custom operation is correctly decomposed.""" @DynamicDecomposeInterpreter() @@ -300,7 +375,7 @@ def circuit(x, wire): CustomOpWhileLoop(x, wires=wire) return qml.expval(qml.Z(0)) - jaxpr = jax.make_jaxpr(circuit)(x, wire) + jaxpr = jax.make_jaxpr(circuit)(0.5, wire) assert jaxpr.eqns[0].primitive == qnode_prim qfunc_jaxpr = jaxpr.eqns[0].params["qfunc_jaxpr"] @@ -309,7 +384,7 @@ def circuit(x, wire): assert qfunc_jaxpr.eqns[1].primitive == qml.PauliZ._primitive assert qfunc_jaxpr.eqns[2].primitive == qml.measurements.ExpectationMP._obs_primitive - result = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, x, wire) + result = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, 0.5, wire) @qml.qnode(device=qml.device("default.qubit", wires=2)) def circuit_comparison(x, wire): @@ -322,7 +397,7 @@ def loop_fn(i): return qml.expval(qml.Z(0)) - assert jax.numpy.allclose(*result, circuit_comparison(x, wire)) + assert jax.numpy.allclose(*result, circuit_comparison(0.5, wire)) @pytest.mark.parametrize("wire", [0, 1]) @pytest.mark.parametrize("x", [0.2, 0.8]) @@ -353,6 +428,7 @@ def circuit(x, wire): @qml.qnode(device=qml.device("default.qubit", wires=2)) def circuit_comparison(x, wire): + if x > 0.5: qml.RX(x, wires=wire) else: From 497440c14129f0b17dcf83667c20b2522e4af0fc Mon Sep 17 00:00:00 2001 From: PietropaoloFrisoni Date: Fri, 24 Jan 2025 17:00:45 -0500 Subject: [PATCH 12/33] Adding a few more tests --- .../test_capture_dynamic_decompositions.py | 46 +++++++++---------- 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/tests/capture/transforms/test_capture_dynamic_decompositions.py b/tests/capture/transforms/test_capture_dynamic_decompositions.py index b160232f02a..7f19e209b77 100644 --- a/tests/capture/transforms/test_capture_dynamic_decompositions.py +++ b/tests/capture/transforms/test_capture_dynamic_decompositions.py @@ -37,7 +37,7 @@ class SimpleCustomOp(Operation): """Simple custom operation that contains a single gate in its decomposition""" num_wires = 1 - num_params = 1 + num_params = 0 def _init__(self, phi, wires, id=None): super().__init__(phi, wires=wires, id=id) @@ -49,8 +49,8 @@ def _plxpr_decomposition(self) -> "jax.core.Jaxpr": ) @staticmethod - def _compute_plxpr_decomposition(phi, wires): - qml.RX(phi, wires=wires) + def _compute_plxpr_decomposition(wires): + qml.Hadamard(wires=wires) class CustomOpMultiWire(Operation): @@ -70,9 +70,9 @@ def _plxpr_decomposition(self) -> "jax.core.Jaxpr": @staticmethod def _compute_plxpr_decomposition(phi, *wires): - qml.CNOT([wires[0], wires[1]]) - qml.DoubleExcitation(phi, wires) - qml.CNOT([wires[0], wires[1]]) + qml.CNOT(wires=[wires[0], wires[1]]) + qml.DoubleExcitation(phi, wires=wires) + qml.CNOT(wires=[wires[0], wires[1]]) qml.RX(phi, wires=wires[0]) qml.RY(phi, wires=wires[1]) qml.RZ(phi, wires=wires[2]) @@ -206,15 +206,15 @@ def test_function_simple(self): """Test that a function with a custom operation is correctly decomposed.""" @DynamicDecomposeInterpreter() - def f(x): - qml.RY(x, wires=0) - SimpleCustomOp(x, wires=0) + def f(): + qml.RY(0.1, wires=0) + SimpleCustomOp(wires=0) return qml.expval(qml.Z(0)) - jaxpr = jax.make_jaxpr(f)(0.5) + jaxpr = jax.make_jaxpr(f)() assert len(jaxpr.eqns) == 4 assert jaxpr.eqns[0].primitive == qml.RY._primitive - assert jaxpr.eqns[1].primitive == qml.RX._primitive + assert jaxpr.eqns[1].primitive == qml.Hadamard._primitive assert jaxpr.eqns[2].primitive == qml.PauliZ._primitive assert jaxpr.eqns[3].primitive == qml.measurements.ExpectationMP._obs_primitive @@ -222,35 +222,34 @@ def f(x): ### QNode tests ############################ - @pytest.mark.parametrize("x", [0.2, 0.8]) - def test_qnode_simple(self, x): + def test_qnode_simple(self): """Test that a QNode with a custom operation is correctly decomposed.""" @DynamicDecomposeInterpreter() @qml.qnode(device=qml.device("default.qubit", wires=2)) - def circuit(x): - qml.RY(x, wires=0) - SimpleCustomOp(x, wires=0) + def circuit(): + qml.RY(0.1, wires=0) + SimpleCustomOp(wires=0) return qml.expval(qml.Z(0)) - jaxpr = jax.make_jaxpr(circuit)(x) + jaxpr = jax.make_jaxpr(circuit)() assert jaxpr.eqns[0].primitive == qnode_prim qfunc_jaxpr = jaxpr.eqns[0].params["qfunc_jaxpr"] assert qfunc_jaxpr.eqns[0].primitive == qml.RY._primitive - assert qfunc_jaxpr.eqns[1].primitive == qml.RX._primitive + assert qfunc_jaxpr.eqns[1].primitive == qml.Hadamard._primitive assert qfunc_jaxpr.eqns[2].primitive == qml.PauliZ._primitive assert qfunc_jaxpr.eqns[3].primitive == qml.measurements.ExpectationMP._obs_primitive - result = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, x) + result = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts) @qml.qnode(device=qml.device("default.qubit", wires=2)) - def circuit_comparison(x): - qml.RY(x, wires=0) - qml.RX(x, wires=0) + def circuit_comparison(): + qml.RY(0.1, wires=0) + qml.Hadamard(wires=0) return qml.expval(qml.Z(0)) - assert jax.numpy.allclose(*result, circuit_comparison(x)) + assert jax.numpy.allclose(*result, circuit_comparison()) @pytest.mark.parametrize("wires", [[0, 1, 2, 3], [2, 3, 1, 0]]) def test_multi_wire(self, wires): @@ -436,6 +435,7 @@ def circuit_comparison(x, wire): return qml.expval(qml.Z(0)) + # Autograph requires to capture the function first jaxpr_comparison = qml.capture.make_plxpr(circuit_comparison)(x, wire) result_comparison = jax.core.eval_jaxpr( jaxpr_comparison.jaxpr, jaxpr_comparison.consts, x, wire From c7da133f297f5b4d77ce2834b252e3721f4a76d4 Mon Sep 17 00:00:00 2001 From: PietropaoloFrisoni Date: Fri, 24 Jan 2025 17:05:10 -0500 Subject: [PATCH 13/33] Removing import --- pennylane/transforms/decompose.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pennylane/transforms/decompose.py b/pennylane/transforms/decompose.py index ffa2a2410d4..0c37f9fb422 100644 --- a/pennylane/transforms/decompose.py +++ b/pennylane/transforms/decompose.py @@ -23,7 +23,6 @@ from typing import Optional, Sequence import pennylane as qml -from pennylane.capture.primitives import AbstractMeasurement, AbstractOperator from pennylane.transforms.core import transform @@ -194,6 +193,7 @@ def _get_plxpr_dynamic_decompose(): # pylint: disable=missing-docstring # pylint: disable=import-outside-toplevel # pylint: disable=unused-import import jax + from pennylane.capture.primitives import AbstractMeasurement, AbstractOperator except ImportError: # pragma: no cover return None, None From 2f0417ce139e77dbbf8111628436653e3d1b82bc Mon Sep 17 00:00:00 2001 From: PietropaoloFrisoni Date: Fri, 24 Jan 2025 17:15:23 -0500 Subject: [PATCH 14/33] Pylint --- tests/capture/transforms/test_capture_decompose.py | 6 ++---- .../transforms/test_capture_dynamic_decompositions.py | 4 ++-- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/tests/capture/transforms/test_capture_decompose.py b/tests/capture/transforms/test_capture_decompose.py index 89325a6427f..8494738d0e7 100644 --- a/tests/capture/transforms/test_capture_decompose.py +++ b/tests/capture/transforms/test_capture_decompose.py @@ -28,10 +28,8 @@ qnode_prim, while_loop_prim, ) -from pennylane.transforms.decompose import ( - DecomposeInterpreter, - decompose_plxpr_to_plxpr, -) +from pennylane.transforms.decompose import DecomposeInterpreter, decompose_plxpr_to_plxpr + pytestmark = [pytest.mark.jax, pytest.mark.usefixtures("enable_disable_plxpr")] diff --git a/tests/capture/transforms/test_capture_dynamic_decompositions.py b/tests/capture/transforms/test_capture_dynamic_decompositions.py index 7f19e209b77..565aba22278 100644 --- a/tests/capture/transforms/test_capture_dynamic_decompositions.py +++ b/tests/capture/transforms/test_capture_dynamic_decompositions.py @@ -11,8 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Unit tests for the ``DecomposeInterpreter`` class""" -# pylint:disable=protected-access,unused-argument, wrong-import-position, no-value-for-parameter, +"""Unit tests for the ``DynamicDecomposeInterpreter`` class.""" +# pylint:disable=protected-access,unused-argument, wrong-import-position, no-value-for-parameter, too-few-public-methods import pytest import pennylane as qml From e9ff110081692b6366d8649d8db89226b75624a1 Mon Sep 17 00:00:00 2001 From: Pietropaolo Frisoni Date: Mon, 27 Jan 2025 12:22:05 -0500 Subject: [PATCH 15/33] Adding test with hyperparameters --- pennylane/transforms/decompose.py | 25 ++++++++----------- .../test_capture_dynamic_decompositions.py | 18 ++++++++----- 2 files changed, 23 insertions(+), 20 deletions(-) diff --git a/pennylane/transforms/decompose.py b/pennylane/transforms/decompose.py index 0c37f9fb422..2970d6d8515 100644 --- a/pennylane/transforms/decompose.py +++ b/pennylane/transforms/decompose.py @@ -20,7 +20,7 @@ import warnings from collections.abc import Callable, Generator, Iterable from functools import lru_cache, partial -from typing import Optional, Sequence +from typing import Optional import pennylane as qml from pennylane.transforms.core import transform @@ -193,7 +193,9 @@ def _get_plxpr_dynamic_decompose(): # pylint: disable=missing-docstring # pylint: disable=import-outside-toplevel # pylint: disable=unused-import import jax - from pennylane.capture.primitives import AbstractMeasurement, AbstractOperator + + from pennylane.capture.primitives import (AbstractMeasurement, + AbstractOperator) except ImportError: # pragma: no cover return None, None @@ -202,26 +204,19 @@ def _get_plxpr_dynamic_decompose(): # pylint: disable=missing-docstring class DynamicDecomposeInterpreter(qml.capture.PlxprInterpreter): """ Experimental Plxpr Interpreter for applying a dynamic decomposition to operations program capture is enabled. - """ - def eval_dynamic_decomposition( - self, jaxpr_decomp: "jax.core.Jaxpr", consts: Sequence, *args - ): + def eval_dynamic_decomposition(self, jaxpr_decomp: "jax.core.Jaxpr", *args): """ Evaluate a dynamic decomposition of a Jaxpr. Args: jaxpr_decomp (jax.core.Jaxpr): the Jaxpr to evaluate - consts (Sequence): the constants to use in the evaluation *args: the arguments to use in the evaluation - """ for arg, invar in zip(args, jaxpr_decomp.invars, strict=True): self._env[invar] = arg - for const, constvar in zip(consts, jaxpr_decomp.constvars, strict=True): - self._env[constvar] = const for inner_eqn in jaxpr_decomp.eqns: @@ -232,6 +227,7 @@ def eval_dynamic_decomposition( outvals = custom_handler(self, *invals, **inner_eqn.params) elif isinstance(inner_eqn.outvars[0].aval, AbstractOperator): + # This does not currently support nested decompositions outvals = super().interpret_operation_eqn(inner_eqn) elif isinstance(inner_eqn.outvars[0].aval, AbstractMeasurement): outvals = super().interpret_measurement_eqn(inner_eqn) @@ -260,10 +256,11 @@ def interpret_operation_eqn(self, eqn: "jax.core.JaxprEqn"): if hasattr(op, "_compute_plxpr_decomposition"): jaxpr_decomp = op._plxpr_decomposition() - args = (*op.parameters, *op.wires, *op.hyperparameters) - return self.eval_dynamic_decomposition( - jaxpr_decomp.jaxpr, jaxpr_decomp.consts, *args - ) + args = (*op.parameters, *op.wires, *op.hyperparameters.values()) + + # We assume that the JAXPR of the decomposition does not contain constants + # and that all the required parameters are passed as arguments + return self.eval_dynamic_decomposition(jaxpr_decomp.jaxpr, *args) return super().interpret_operation_eqn(eqn) diff --git a/tests/capture/transforms/test_capture_dynamic_decompositions.py b/tests/capture/transforms/test_capture_dynamic_decompositions.py index 565aba22278..d261e9f0ff6 100644 --- a/tests/capture/transforms/test_capture_dynamic_decompositions.py +++ b/tests/capture/transforms/test_capture_dynamic_decompositions.py @@ -60,23 +60,29 @@ class CustomOpMultiWire(Operation): num_params = 1 def __init__(self, phi, wires, id=None): + + self.hyperparameters["key_1"] = 0.1 + self.hyperparameters["key_2"] = 0.2 + super().__init__(phi, wires=wires, id=id) def _plxpr_decomposition(self) -> "jax.core.Jaxpr": return jax.make_jaxpr(self._compute_plxpr_decomposition)( - *self.parameters, *self.wires, **self.hyperparameters + *self.parameters, *self.wires, *self.hyperparameters.values() ) @staticmethod - def _compute_plxpr_decomposition(phi, *wires): + def _compute_plxpr_decomposition(phi, *args): + wires = args[:4] + hyperparameters = args[4:] qml.CNOT(wires=[wires[0], wires[1]]) qml.DoubleExcitation(phi, wires=wires) qml.CNOT(wires=[wires[0], wires[1]]) - qml.RX(phi, wires=wires[0]) + qml.RX(hyperparameters[0], wires=wires[0]) qml.RY(phi, wires=wires[1]) qml.RZ(phi, wires=wires[2]) - qml.RX(phi, wires=wires[3]) + qml.RX(hyperparameters[1], wires=wires[3]) class CustomOpCond(Operation): @@ -280,10 +286,10 @@ def circuit_comparison(x, wires): qml.CNOT([wires[0], wires[1]]) qml.DoubleExcitation(x, wires) qml.CNOT([wires[0], wires[1]]) - qml.RX(x, wires=wires[0]) + qml.RX(0.1, wires=wires[0]) qml.RY(x, wires=wires[1]) qml.RZ(x, wires=wires[2]) - qml.RX(x, wires=wires[3]) + qml.RX(0.2, wires=wires[3]) return qml.expval(qml.Z(0)) assert jax.numpy.allclose(*result, circuit_comparison(0.5, wires)) From f2437b4bc85222ce7c91ebf269412a416963787a Mon Sep 17 00:00:00 2001 From: Pietropaolo Frisoni Date: Mon, 27 Jan 2025 12:26:01 -0500 Subject: [PATCH 16/33] Black --- pennylane/transforms/decompose.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pennylane/transforms/decompose.py b/pennylane/transforms/decompose.py index 2970d6d8515..261086d7f78 100644 --- a/pennylane/transforms/decompose.py +++ b/pennylane/transforms/decompose.py @@ -194,8 +194,7 @@ def _get_plxpr_dynamic_decompose(): # pylint: disable=missing-docstring # pylint: disable=unused-import import jax - from pennylane.capture.primitives import (AbstractMeasurement, - AbstractOperator) + from pennylane.capture.primitives import AbstractMeasurement, AbstractOperator except ImportError: # pragma: no cover return None, None From 8c615c5940cab7e508db1cc4424c00bb341fbadf Mon Sep 17 00:00:00 2001 From: PietropaoloFrisoni Date: Mon, 27 Jan 2025 14:05:15 -0500 Subject: [PATCH 17/33] A few more tests --- .../transforms/test_capture_decompose.py | 1 - .../test_capture_dynamic_decompositions.py | 140 ++++++++++++++---- 2 files changed, 115 insertions(+), 26 deletions(-) diff --git a/tests/capture/transforms/test_capture_decompose.py b/tests/capture/transforms/test_capture_decompose.py index 8494738d0e7..e2d834f43a8 100644 --- a/tests/capture/transforms/test_capture_decompose.py +++ b/tests/capture/transforms/test_capture_decompose.py @@ -30,7 +30,6 @@ ) from pennylane.transforms.decompose import DecomposeInterpreter, decompose_plxpr_to_plxpr - pytestmark = [pytest.mark.jax, pytest.mark.usefixtures("enable_disable_plxpr")] diff --git a/tests/capture/transforms/test_capture_dynamic_decompositions.py b/tests/capture/transforms/test_capture_dynamic_decompositions.py index d261e9f0ff6..43e8b92f1c5 100644 --- a/tests/capture/transforms/test_capture_dynamic_decompositions.py +++ b/tests/capture/transforms/test_capture_dynamic_decompositions.py @@ -19,16 +19,9 @@ jax = pytest.importorskip("jax") -from pennylane.capture.primitives import ( - cond_prim, - for_loop_prim, - qnode_prim, - while_loop_prim, -) +from pennylane.capture.primitives import cond_prim, for_loop_prim, qnode_prim, while_loop_prim from pennylane.operation import Operation -from pennylane.transforms.decompose import ( - DynamicDecomposeInterpreter, -) +from pennylane.transforms.decompose import DynamicDecomposeInterpreter pytestmark = [pytest.mark.jax, pytest.mark.usefixtures("enable_disable_plxpr")] @@ -169,6 +162,48 @@ def loop_fn(i): return qml.expval(qml.Z(0)) +class CustomOpNestedCond(Operation): + """Custom operation that contains a nested conditional in its decomposition""" + + num_wires = 1 + num_params = 1 + + def __init__(self, phi, wires, id=None): + super().__init__(phi, wires=wires, id=id) + + def _plxpr_decomposition(self) -> "jax.core.Jaxpr": + + return jax.make_jaxpr(self._compute_plxpr_decomposition)( + *self.parameters, *self.wires, **self.hyperparameters + ) + + @staticmethod + def _compute_plxpr_decomposition(phi, wires): + + def true_fn(phi, wires): + + @qml.for_loop(0, 3, 1) + def loop_rx(i, phi): + qml.RX(phi, wires=wires) + return jax.numpy.sin(phi) + + # pylint: disable=unused-variable + loop_rx(phi) + + def false_fn(phi, wires): + + @qml.while_loop(lambda i: i < 3) + def loop_fn(i): + qml.RX(phi, wires=wires) + return i + 1 + + _ = loop_fn(0) + + qml.cond(phi > 0.5, true_fn, false_fn)(phi, wires) + + qml.RX(phi, wires=wires) + + class CustomOpAutograph(Operation): """Custom operation that contains a nested conditional in its decomposition""" @@ -255,7 +290,7 @@ def circuit_comparison(): qml.Hadamard(wires=0) return qml.expval(qml.Z(0)) - assert jax.numpy.allclose(*result, circuit_comparison()) + assert qml.math.allclose(*result, circuit_comparison()) @pytest.mark.parametrize("wires", [[0, 1, 2, 3], [2, 3, 1, 0]]) def test_multi_wire(self, wires): @@ -265,7 +300,7 @@ def test_multi_wire(self, wires): @qml.qnode(device=qml.device("default.qubit", wires=4)) def circuit(x, wires): CustomOpMultiWire(x, wires=wires) - return qml.expval(qml.Z(0)) + return qml.expval(qml.Z(0)), qml.expval(qml.Z(1)), qml.var(qml.Z(2)), qml.state() jaxpr = jax.make_jaxpr(circuit)(0.5, wires=wires) @@ -290,9 +325,14 @@ def circuit_comparison(x, wires): qml.RY(x, wires=wires[1]) qml.RZ(x, wires=wires[2]) qml.RX(0.2, wires=wires[3]) - return qml.expval(qml.Z(0)) + return qml.expval(qml.Z(0)), qml.expval(qml.Z(1)), qml.var(qml.Z(2)), qml.state() + + comparison_result = circuit_comparison(0.5, wires) - assert jax.numpy.allclose(*result, circuit_comparison(0.5, wires)) + assert qml.math.allclose(result[0], comparison_result[0]) + assert qml.math.allclose(result[1], comparison_result[1]) + assert qml.math.allclose(result[2], comparison_result[2]) + assert qml.math.allclose(result[3], comparison_result[3]) @pytest.mark.parametrize("wire", [0, 1]) @pytest.mark.parametrize("x", [0.2, 0.8]) @@ -303,7 +343,7 @@ def test_qnode_cond(self, x, wire): @qml.qnode(device=qml.device("default.qubit", wires=2)) def circuit(x, wire): CustomOpCond(x, wires=wire) - return qml.expval(qml.Z(0)) + return qml.expval(qml.Z(wires=wire)) jaxpr = jax.make_jaxpr(circuit)(x, wire=wire) @@ -331,9 +371,9 @@ def false_fn(x, wire): qml.cond(x > 0.5, true_fn, false_fn)(x, wire) - return qml.expval(qml.Z(0)) + return qml.expval(qml.Z(wires=wire)) - assert jax.numpy.allclose(*result, circuit_comparison(x, wire)) + assert qml.math.allclose(*result, circuit_comparison(x, wire)) @pytest.mark.parametrize("wire", [0, 1]) def test_qnode_for_loop(self, wire): @@ -343,7 +383,7 @@ def test_qnode_for_loop(self, wire): @qml.qnode(device=qml.device("default.qubit", wires=2)) def circuit(x, wire): CustomOpForLoop(x, wires=wire) - return qml.expval(qml.Z(0)) + return qml.expval(qml.Z(wires=wire)) jaxpr = jax.make_jaxpr(circuit)(0.5, wire) @@ -366,9 +406,9 @@ def loop_rx(i, phi): # pylint: disable=unused-variable loop_rx(x) - return qml.expval(qml.Z(0)) + return qml.expval(qml.Z(wires=wire)) - assert jax.numpy.allclose(*result, circuit_comparison(0.5, wire)) + assert qml.math.allclose(*result, circuit_comparison(0.5, wire)) @pytest.mark.parametrize("wire", [0, 1]) def test_qnode_while_loop(self, wire): @@ -378,7 +418,7 @@ def test_qnode_while_loop(self, wire): @qml.qnode(device=qml.device("default.qubit", wires=2)) def circuit(x, wire): CustomOpWhileLoop(x, wires=wire) - return qml.expval(qml.Z(0)) + return qml.expval(qml.Z(wires=wire)) jaxpr = jax.make_jaxpr(circuit)(0.5, wire) @@ -400,9 +440,59 @@ def loop_fn(i): _ = loop_fn(0) - return qml.expval(qml.Z(0)) + return qml.expval(qml.Z(wires=wire)) + + assert qml.math.allclose(*result, circuit_comparison(0.5, wire)) + + @pytest.mark.parametrize("wire", [0, 1]) + @pytest.mark.parametrize("x", [0.2, 0.8]) + def test_qnode_nested_cond(self, x, wire): + """Test that a QNode with a nested conditional custom operation is correctly decomposed.""" + + @DynamicDecomposeInterpreter() + @qml.qnode(device=qml.device("default.qubit", wires=2)) + def circuit(x, wire): + CustomOpNestedCond(x, wires=wire) + return qml.expval(qml.Z(wires=wire)) + + jaxpr = jax.make_jaxpr(circuit)(x, wire) - assert jax.numpy.allclose(*result, circuit_comparison(0.5, wire)) + assert jaxpr.eqns[0].primitive == qnode_prim + qfunc_jaxpr = jaxpr.eqns[0].params["qfunc_jaxpr"] + assert qfunc_jaxpr.eqns[1].primitive == cond_prim + assert qfunc_jaxpr.eqns[1].params["jaxpr_branches"][0].eqns[0].primitive == for_loop_prim + assert qfunc_jaxpr.eqns[1].params["jaxpr_branches"][1].eqns[0].primitive == while_loop_prim + assert qfunc_jaxpr.eqns[2].primitive == qml.RX._primitive + assert qfunc_jaxpr.eqns[3].primitive == qml.PauliZ._primitive + assert qfunc_jaxpr.eqns[4].primitive == qml.measurements.ExpectationMP._obs_primitive + + result = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, x, wire) + + @qml.qnode(device=qml.device("default.qubit", wires=2)) + def circuit_comparison(x, wire): + def true_fn(x, wire): + + @qml.for_loop(0, 3, 1) + def loop_rx(i, phi): + qml.RX(phi, wires=wire) + return jax.numpy.sin(phi) + + # pylint: disable=unused-variable + loop_rx(x) + + def false_fn(x, wire): + @qml.while_loop(lambda i: i < 3) + def loop_fn(i): + qml.RX(x, wires=wire) + return i + 1 + + _ = loop_fn(0) + + qml.cond(x > 0.5, true_fn, false_fn)(x, wire) + qml.RX(x, wires=wire) + return qml.expval(qml.Z(wires=wire)) + + assert qml.math.allclose(*result, circuit_comparison(x, wire)) @pytest.mark.parametrize("wire", [0, 1]) @pytest.mark.parametrize("x", [0.2, 0.8]) @@ -413,7 +503,7 @@ def test_qnode_autograph(self, x, wire): @qml.qnode(device=qml.device("default.qubit", wires=2)) def circuit(x, wire): CustomOpAutograph(x, wires=wire) - return qml.expval(qml.Z(0)) + return qml.expval(qml.Z(wires=wire)) jaxpr = jax.make_jaxpr(circuit)(x, wire) @@ -439,7 +529,7 @@ def circuit_comparison(x, wire): else: qml.RY(x, wires=wire) - return qml.expval(qml.Z(0)) + return qml.expval(qml.Z(wires=wire)) # Autograph requires to capture the function first jaxpr_comparison = qml.capture.make_plxpr(circuit_comparison)(x, wire) @@ -447,4 +537,4 @@ def circuit_comparison(x, wire): jaxpr_comparison.jaxpr, jaxpr_comparison.consts, x, wire ) - assert jax.numpy.allclose(*result, *result_comparison) + assert qml.math.allclose(*result, *result_comparison) From 9fef95b6f65cfa9f3ee4e937bf826d40f60868b7 Mon Sep 17 00:00:00 2001 From: PietropaoloFrisoni Date: Mon, 27 Jan 2025 14:23:36 -0500 Subject: [PATCH 18/33] Changelog --- doc/releases/changelog-dev.md | 20 ++++++++----- .../test_capture_dynamic_decompositions.py | 28 ++++++------------- 2 files changed, 21 insertions(+), 27 deletions(-) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 558070c802c..bb1de59b430 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -6,13 +6,6 @@

Improvements 🛠

-* The higher order primitives in program capture can now accept inputs with abstract shapes. - [(#6786)](https://github.com/PennyLaneAI/pennylane/pull/6786) - -* The `PlxprInterpreter` classes can now handle creating dynamic arrays via `jnp.ones`, `jnp.zeros`, - `jnp.arange`, and `jnp.full`. - [#6865)](https://github.com/PennyLaneAI/pennylane/pull/6865) - * `QNode` objects now have an `update` method that allows for re-configuring settings like `diff_method`, `mcm_method`, and more. This allows for easier on-the-fly adjustments to workflows. Any arguments not specified will retain their original value. [(#6803)](https://github.com/PennyLaneAI/pennylane/pull/6803) @@ -54,6 +47,19 @@ * The requested `diff_method` is now validated when program capture is enabled. [(#6852)](https://github.com/PennyLaneAI/pennylane/pull/6852) + +

Capturing and representing hybrid programs

+ +* Implemented a new `DynamicDecomposeInterpreter` to capture decompositions of operators with control-flow instructions. + [(#6859)](https://github.com/PennyLaneAI/pennylane/pull/6859) + +* The higher order primitives in program capture can now accept inputs with abstract shapes. + [(#6786)](https://github.com/PennyLaneAI/pennylane/pull/6786) + +* The `PlxprInterpreter` classes can now handle creating dynamic arrays via `jnp.ones`, `jnp.zeros`, + `jnp.arange`, and `jnp.full`. + [#6865)](https://github.com/PennyLaneAI/pennylane/pull/6865) +

Breaking changes 💔

* `MultiControlledX` no longer accepts strings as control values. diff --git a/tests/capture/transforms/test_capture_dynamic_decompositions.py b/tests/capture/transforms/test_capture_dynamic_decompositions.py index 43e8b92f1c5..688841e821c 100644 --- a/tests/capture/transforms/test_capture_dynamic_decompositions.py +++ b/tests/capture/transforms/test_capture_dynamic_decompositions.py @@ -37,9 +37,7 @@ def _init__(self, phi, wires, id=None): def _plxpr_decomposition(self) -> "jax.core.Jaxpr": - return jax.make_jaxpr(self._compute_plxpr_decomposition)( - *self.parameters, *self.wires, **self.hyperparameters - ) + return jax.make_jaxpr(self._compute_plxpr_decomposition)(*self.parameters, *self.wires) @staticmethod def _compute_plxpr_decomposition(wires): @@ -89,9 +87,7 @@ def __init__(self, phi, wires, id=None): def _plxpr_decomposition(self) -> "jax.core.Jaxpr": - return jax.make_jaxpr(self._compute_plxpr_decomposition)( - *self.parameters, *self.wires, **self.hyperparameters - ) + return jax.make_jaxpr(self._compute_plxpr_decomposition)(*self.parameters, *self.wires) @staticmethod def _compute_plxpr_decomposition(phi, wires): @@ -116,9 +112,7 @@ def __init__(self, phi, wires, id=None): def _plxpr_decomposition(self) -> "jax.core.Jaxpr": - return jax.make_jaxpr(self._compute_plxpr_decomposition)( - *self.parameters, *self.wires, **self.hyperparameters - ) + return jax.make_jaxpr(self._compute_plxpr_decomposition)(*self.parameters, *self.wires) @staticmethod def _compute_plxpr_decomposition(phi, wires): @@ -131,8 +125,6 @@ def loop_rx(i, phi): # pylint: disable=unused-variable loop_rx(phi) - return qml.expval(qml.Z(0)) - class CustomOpWhileLoop(Operation): """Custom operation that contains a while loop in its decomposition""" @@ -145,9 +137,7 @@ def __init__(self, phi, wires, id=None): def _plxpr_decomposition(self) -> "jax.core.Jaxpr": - return jax.make_jaxpr(self._compute_plxpr_decomposition)( - *self.parameters, *self.wires, **self.hyperparameters - ) + return jax.make_jaxpr(self._compute_plxpr_decomposition)(*self.parameters, *self.wires) @staticmethod def _compute_plxpr_decomposition(phi, wires): @@ -173,9 +163,7 @@ def __init__(self, phi, wires, id=None): def _plxpr_decomposition(self) -> "jax.core.Jaxpr": - return jax.make_jaxpr(self._compute_plxpr_decomposition)( - *self.parameters, *self.wires, **self.hyperparameters - ) + return jax.make_jaxpr(self._compute_plxpr_decomposition)(*self.parameters, *self.wires) @staticmethod def _compute_plxpr_decomposition(phi, wires): @@ -216,7 +204,7 @@ def __init__(self, phi, wires, id=None): def _plxpr_decomposition(self) -> "jax.core.Jaxpr": return qml.capture.make_plxpr(self._compute_plxpr_decomposition)( - *self.parameters, *self.wires, **self.hyperparameters + *self.parameters, *self.wires ) @staticmethod @@ -300,7 +288,7 @@ def test_multi_wire(self, wires): @qml.qnode(device=qml.device("default.qubit", wires=4)) def circuit(x, wires): CustomOpMultiWire(x, wires=wires) - return qml.expval(qml.Z(0)), qml.expval(qml.Z(1)), qml.var(qml.Z(2)), qml.state() + return qml.expval(qml.Z(0)), qml.probs(wires=1), qml.var(qml.Z(2)), qml.state() jaxpr = jax.make_jaxpr(circuit)(0.5, wires=wires) @@ -325,7 +313,7 @@ def circuit_comparison(x, wires): qml.RY(x, wires=wires[1]) qml.RZ(x, wires=wires[2]) qml.RX(0.2, wires=wires[3]) - return qml.expval(qml.Z(0)), qml.expval(qml.Z(1)), qml.var(qml.Z(2)), qml.state() + return qml.expval(qml.Z(0)), qml.probs(wires=1), qml.var(qml.Z(2)), qml.state() comparison_result = circuit_comparison(0.5, wires) From 4a56150907f09dfbd70f1a223cf9a07fdd2f6274 Mon Sep 17 00:00:00 2001 From: Pietropaolo Frisoni Date: Tue, 28 Jan 2025 10:39:14 -0500 Subject: [PATCH 19/33] Removing redundant operations --- pennylane/transforms/decompose.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/pennylane/transforms/decompose.py b/pennylane/transforms/decompose.py index 261086d7f78..b7a2cc8e591 100644 --- a/pennylane/transforms/decompose.py +++ b/pennylane/transforms/decompose.py @@ -202,7 +202,7 @@ def _get_plxpr_dynamic_decompose(): # pylint: disable=missing-docstring class DynamicDecomposeInterpreter(qml.capture.PlxprInterpreter): """ - Experimental Plxpr Interpreter for applying a dynamic decomposition to operations program capture is enabled. + Experimental Plxpr Interpreter for applying a dynamic decomposition to operations when program capture is enabled. """ def eval_dynamic_decomposition(self, jaxpr_decomp: "jax.core.Jaxpr", *args): @@ -252,16 +252,20 @@ def interpret_operation_eqn(self, eqn: "jax.core.JaxprEqn"): with qml.QueuingManager.stop_recording(): op = eqn.primitive.impl(*invals, **eqn.params) - if hasattr(op, "_compute_plxpr_decomposition"): + if isinstance(eqn.outvars[0], jax.core.DropVar): - jaxpr_decomp = op._plxpr_decomposition() - args = (*op.parameters, *op.wires, *op.hyperparameters.values()) + if hasattr(op, "_plxpr_decomposition"): - # We assume that the JAXPR of the decomposition does not contain constants - # and that all the required parameters are passed as arguments - return self.eval_dynamic_decomposition(jaxpr_decomp.jaxpr, *args) + jaxpr_decomp = op._plxpr_decomposition() + args = (*op.parameters, *op.wires, *op.hyperparameters.values()) - return super().interpret_operation_eqn(eqn) + # We assume that the JAXPR of the decomposition does not contain constants + # and that all the required parameters are passed as arguments + return self.eval_dynamic_decomposition(jaxpr_decomp.jaxpr, *args) + + return super().interpret_operation(op) + + return op return DynamicDecomposeInterpreter From 9792df6a4752c8b24e4716398ba542913c8d67b3 Mon Sep 17 00:00:00 2001 From: PietropaoloFrisoni Date: Wed, 29 Jan 2025 15:49:00 -0500 Subject: [PATCH 20/33] Pre-binding hyperparameters [ci skip] --- pennylane/transforms/decompose.py | 5 +---- .../test_capture_dynamic_decompositions.py | 20 +++++++++++-------- 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/pennylane/transforms/decompose.py b/pennylane/transforms/decompose.py index b7a2cc8e591..39409c5eb31 100644 --- a/pennylane/transforms/decompose.py +++ b/pennylane/transforms/decompose.py @@ -257,10 +257,7 @@ def interpret_operation_eqn(self, eqn: "jax.core.JaxprEqn"): if hasattr(op, "_plxpr_decomposition"): jaxpr_decomp = op._plxpr_decomposition() - args = (*op.parameters, *op.wires, *op.hyperparameters.values()) - - # We assume that the JAXPR of the decomposition does not contain constants - # and that all the required parameters are passed as arguments + args = (*op.parameters, *op.wires) return self.eval_dynamic_decomposition(jaxpr_decomp.jaxpr, *args) return super().interpret_operation(op) diff --git a/tests/capture/transforms/test_capture_dynamic_decompositions.py b/tests/capture/transforms/test_capture_dynamic_decompositions.py index 688841e821c..0e7f6008c42 100644 --- a/tests/capture/transforms/test_capture_dynamic_decompositions.py +++ b/tests/capture/transforms/test_capture_dynamic_decompositions.py @@ -13,6 +13,8 @@ # limitations under the License. """Unit tests for the ``DynamicDecomposeInterpreter`` class.""" # pylint:disable=protected-access,unused-argument, wrong-import-position, no-value-for-parameter, too-few-public-methods +from functools import partial + import pytest import pennylane as qml @@ -58,22 +60,24 @@ def __init__(self, phi, wires, id=None): super().__init__(phi, wires=wires, id=id) def _plxpr_decomposition(self) -> "jax.core.Jaxpr": - - return jax.make_jaxpr(self._compute_plxpr_decomposition)( - *self.parameters, *self.wires, *self.hyperparameters.values() + args = (*self.parameters, *self.wires) + return jax.make_jaxpr(partial(self._compute_plxpr_decomposition, **self.hyperparameters))( + *args ) @staticmethod - def _compute_plxpr_decomposition(phi, *args): - wires = args[:4] - hyperparameters = args[4:] + def _compute_plxpr_decomposition(*args, **hyperparameters): + + phi = args[0] + wires = args[1:] + qml.CNOT(wires=[wires[0], wires[1]]) qml.DoubleExcitation(phi, wires=wires) qml.CNOT(wires=[wires[0], wires[1]]) - qml.RX(hyperparameters[0], wires=wires[0]) + qml.RX(hyperparameters["key_1"], wires=wires[0]) qml.RY(phi, wires=wires[1]) qml.RZ(phi, wires=wires[2]) - qml.RX(hyperparameters[1], wires=wires[3]) + qml.RX(hyperparameters["key_2"], wires=wires[3]) class CustomOpCond(Operation): From aa422b35080c1dd75953f28eaa9a7832fa72c6a7 Mon Sep 17 00:00:00 2001 From: Pietropaolo Frisoni Date: Thu, 30 Jan 2025 10:11:40 -0500 Subject: [PATCH 21/33] Removing redundant method --- pennylane/transforms/decompose.py | 5 +- .../test_capture_dynamic_decompositions.py | 64 +++++-------------- 2 files changed, 20 insertions(+), 49 deletions(-) diff --git a/pennylane/transforms/decompose.py b/pennylane/transforms/decompose.py index 39409c5eb31..3c6181239ff 100644 --- a/pennylane/transforms/decompose.py +++ b/pennylane/transforms/decompose.py @@ -256,8 +256,11 @@ def interpret_operation_eqn(self, eqn: "jax.core.JaxprEqn"): if hasattr(op, "_plxpr_decomposition"): - jaxpr_decomp = op._plxpr_decomposition() args = (*op.parameters, *op.wires) + jaxpr_decomp = qml.capture.make_plxpr( + partial(op._plxpr_decomposition, **op.hyperparameters) + )(*args) + return self.eval_dynamic_decomposition(jaxpr_decomp.jaxpr, *args) return super().interpret_operation(op) diff --git a/tests/capture/transforms/test_capture_dynamic_decompositions.py b/tests/capture/transforms/test_capture_dynamic_decompositions.py index 0e7f6008c42..c1631e22a55 100644 --- a/tests/capture/transforms/test_capture_dynamic_decompositions.py +++ b/tests/capture/transforms/test_capture_dynamic_decompositions.py @@ -37,12 +37,8 @@ class SimpleCustomOp(Operation): def _init__(self, phi, wires, id=None): super().__init__(phi, wires=wires, id=id) - def _plxpr_decomposition(self) -> "jax.core.Jaxpr": - - return jax.make_jaxpr(self._compute_plxpr_decomposition)(*self.parameters, *self.wires) - @staticmethod - def _compute_plxpr_decomposition(wires): + def _plxpr_decomposition(wires): qml.Hadamard(wires=wires) @@ -59,25 +55,19 @@ def __init__(self, phi, wires, id=None): super().__init__(phi, wires=wires, id=id) - def _plxpr_decomposition(self) -> "jax.core.Jaxpr": - args = (*self.parameters, *self.wires) - return jax.make_jaxpr(partial(self._compute_plxpr_decomposition, **self.hyperparameters))( - *args - ) - @staticmethod - def _compute_plxpr_decomposition(*args, **hyperparameters): + def _plxpr_decomposition(*args, **hyperparameters): phi = args[0] wires = args[1:] - qml.CNOT(wires=[wires[0], wires[1]]) - qml.DoubleExcitation(phi, wires=wires) - qml.CNOT(wires=[wires[0], wires[1]]) - qml.RX(hyperparameters["key_1"], wires=wires[0]) - qml.RY(phi, wires=wires[1]) - qml.RZ(phi, wires=wires[2]) - qml.RX(hyperparameters["key_2"], wires=wires[3]) + qml.CNOT([wires[0], wires[1]]) + qml.DoubleExcitation(phi, wires) + qml.CNOT([wires[0], wires[1]]) + qml.RX(hyperparameters["key_1"], wires[0]) + qml.RY(phi, wires[1]) + qml.RZ(phi, wires[2]) + qml.RX(hyperparameters["key_2"], wires[3]) class CustomOpCond(Operation): @@ -89,12 +79,8 @@ class CustomOpCond(Operation): def __init__(self, phi, wires, id=None): super().__init__(phi, wires=wires, id=id) - def _plxpr_decomposition(self) -> "jax.core.Jaxpr": - - return jax.make_jaxpr(self._compute_plxpr_decomposition)(*self.parameters, *self.wires) - @staticmethod - def _compute_plxpr_decomposition(phi, wires): + def _plxpr_decomposition(phi, wires): def true_fn(phi, wires): qml.RX(phi, wires=wires) @@ -114,12 +100,8 @@ class CustomOpForLoop(Operation): def __init__(self, phi, wires, id=None): super().__init__(phi, wires=wires, id=id) - def _plxpr_decomposition(self) -> "jax.core.Jaxpr": - - return jax.make_jaxpr(self._compute_plxpr_decomposition)(*self.parameters, *self.wires) - @staticmethod - def _compute_plxpr_decomposition(phi, wires): + def _plxpr_decomposition(phi, wires): @qml.for_loop(0, 3, 1) def loop_rx(i, phi): @@ -139,12 +121,8 @@ class CustomOpWhileLoop(Operation): def __init__(self, phi, wires, id=None): super().__init__(phi, wires=wires, id=id) - def _plxpr_decomposition(self) -> "jax.core.Jaxpr": - - return jax.make_jaxpr(self._compute_plxpr_decomposition)(*self.parameters, *self.wires) - @staticmethod - def _compute_plxpr_decomposition(phi, wires): + def _plxpr_decomposition(phi, wires): @qml.while_loop(lambda i: i < 3) def loop_fn(i): @@ -165,18 +143,14 @@ class CustomOpNestedCond(Operation): def __init__(self, phi, wires, id=None): super().__init__(phi, wires=wires, id=id) - def _plxpr_decomposition(self) -> "jax.core.Jaxpr": - - return jax.make_jaxpr(self._compute_plxpr_decomposition)(*self.parameters, *self.wires) - @staticmethod - def _compute_plxpr_decomposition(phi, wires): + def _plxpr_decomposition(phi, wires): def true_fn(phi, wires): @qml.for_loop(0, 3, 1) def loop_rx(i, phi): - qml.RX(phi, wires=wires) + qml.RX(phi, wires) return jax.numpy.sin(phi) # pylint: disable=unused-variable @@ -186,7 +160,7 @@ def false_fn(phi, wires): @qml.while_loop(lambda i: i < 3) def loop_fn(i): - qml.RX(phi, wires=wires) + qml.RX(phi, wires) return i + 1 _ = loop_fn(0) @@ -205,14 +179,8 @@ class CustomOpAutograph(Operation): def __init__(self, phi, wires, id=None): super().__init__(phi, wires=wires, id=id) - def _plxpr_decomposition(self) -> "jax.core.Jaxpr": - - return qml.capture.make_plxpr(self._compute_plxpr_decomposition)( - *self.parameters, *self.wires - ) - @staticmethod - def _compute_plxpr_decomposition(phi, wires): + def _plxpr_decomposition(phi, wires): if phi > 0.5: qml.RX(phi, wires=wires) From b7a18cdd45077cca44b9f605905f2fae98709380 Mon Sep 17 00:00:00 2001 From: Pietropaolo Frisoni Date: Thu, 30 Jan 2025 10:21:29 -0500 Subject: [PATCH 22/33] Pylint --- tests/capture/transforms/test_capture_dynamic_decompositions.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/capture/transforms/test_capture_dynamic_decompositions.py b/tests/capture/transforms/test_capture_dynamic_decompositions.py index c1631e22a55..9fcf3b21f60 100644 --- a/tests/capture/transforms/test_capture_dynamic_decompositions.py +++ b/tests/capture/transforms/test_capture_dynamic_decompositions.py @@ -13,8 +13,6 @@ # limitations under the License. """Unit tests for the ``DynamicDecomposeInterpreter`` class.""" # pylint:disable=protected-access,unused-argument, wrong-import-position, no-value-for-parameter, too-few-public-methods -from functools import partial - import pytest import pennylane as qml From 97cba03dc57eb4001820771d1dc8b5165b27f991 Mon Sep 17 00:00:00 2001 From: Pietropaolo Frisoni Date: Thu, 30 Jan 2025 15:40:31 -0500 Subject: [PATCH 23/33] Testing CI failures (JAX imports) --- pennylane/operation.py | 25 +++++++++++++++++++ pennylane/transforms/decompose.py | 8 ++---- .../test_capture_dynamic_decompositions.py | 24 +++++++++++------- 3 files changed, 42 insertions(+), 15 deletions(-) diff --git a/pennylane/operation.py b/pennylane/operation.py index 17546747dc4..925701d5a7f 100644 --- a/pennylane/operation.py +++ b/pennylane/operation.py @@ -226,6 +226,7 @@ from collections.abc import Hashable, Iterable from enum import IntEnum from typing import Any, Callable, Literal, Optional, Union +from functools import partial import numpy as np from scipy.sparse import csr_matrix @@ -1322,6 +1323,30 @@ def decomposition(self) -> list["Operator"]: *self.parameters, wires=self.wires, **self.hyperparameters ) + @classproperty + def _has_plxpr_decomposition(cls) -> bool: + """Whether or not the Operator returns a defined plxpr decomposition.""" + + return ( + cls._compute_plxpr_decomposition != Operator._compute_plxpr_decomposition + or cls._plxpr_decomposition != Operator._plxpr_decomposition + ) + + def _plxpr_decomposition(self) -> "jax.core.Jaxpr": + """Representation of the operator as a plxpr decomposition.""" + + args = (*self.parameters, *self.wires) + jaxpr_decomp = qml.capture.make_plxpr( + partial(self._compute_plxpr_decomposition, **self.hyperparameters) + )(*args) + + return jaxpr_decomp + + @staticmethod + def _compute_plxpr_decomposition(*args, **hyperparameters): + """Experimental method to compute the plxpr decomposition of the operator.""" + raise DecompositionUndefinedError + @staticmethod def compute_decomposition( *params: TensorLike, diff --git a/pennylane/transforms/decompose.py b/pennylane/transforms/decompose.py index 3c6181239ff..fb5cb03e6f2 100644 --- a/pennylane/transforms/decompose.py +++ b/pennylane/transforms/decompose.py @@ -254,13 +254,9 @@ def interpret_operation_eqn(self, eqn: "jax.core.JaxprEqn"): if isinstance(eqn.outvars[0], jax.core.DropVar): - if hasattr(op, "_plxpr_decomposition"): - + if op._has_plxpr_decomposition: + jaxpr_decomp = op._plxpr_decomposition() args = (*op.parameters, *op.wires) - jaxpr_decomp = qml.capture.make_plxpr( - partial(op._plxpr_decomposition, **op.hyperparameters) - )(*args) - return self.eval_dynamic_decomposition(jaxpr_decomp.jaxpr, *args) return super().interpret_operation(op) diff --git a/tests/capture/transforms/test_capture_dynamic_decompositions.py b/tests/capture/transforms/test_capture_dynamic_decompositions.py index 9fcf3b21f60..b5eb678b571 100644 --- a/tests/capture/transforms/test_capture_dynamic_decompositions.py +++ b/tests/capture/transforms/test_capture_dynamic_decompositions.py @@ -36,7 +36,7 @@ def _init__(self, phi, wires, id=None): super().__init__(phi, wires=wires, id=id) @staticmethod - def _plxpr_decomposition(wires): + def _compute_plxpr_decomposition(wires): qml.Hadamard(wires=wires) @@ -54,7 +54,7 @@ def __init__(self, phi, wires, id=None): super().__init__(phi, wires=wires, id=id) @staticmethod - def _plxpr_decomposition(*args, **hyperparameters): + def _compute_plxpr_decomposition(*args, **hyperparameters): phi = args[0] wires = args[1:] @@ -78,7 +78,7 @@ def __init__(self, phi, wires, id=None): super().__init__(phi, wires=wires, id=id) @staticmethod - def _plxpr_decomposition(phi, wires): + def _compute_plxpr_decomposition(phi, wires): def true_fn(phi, wires): qml.RX(phi, wires=wires) @@ -99,11 +99,11 @@ def __init__(self, phi, wires, id=None): super().__init__(phi, wires=wires, id=id) @staticmethod - def _plxpr_decomposition(phi, wires): + def _compute_plxpr_decomposition(phi, wires): @qml.for_loop(0, 3, 1) def loop_rx(i, phi): - qml.RX(phi, wires=wires) + qml.RX(phi, wires) return jax.numpy.sin(phi) # pylint: disable=unused-variable @@ -120,11 +120,11 @@ def __init__(self, phi, wires, id=None): super().__init__(phi, wires=wires, id=id) @staticmethod - def _plxpr_decomposition(phi, wires): + def _compute_plxpr_decomposition(phi, wires): @qml.while_loop(lambda i: i < 3) def loop_fn(i): - qml.RX(phi, wires=wires) + qml.RX(phi, wires) return i + 1 _ = loop_fn(0) @@ -142,7 +142,7 @@ def __init__(self, phi, wires, id=None): super().__init__(phi, wires=wires, id=id) @staticmethod - def _plxpr_decomposition(phi, wires): + def _compute_plxpr_decomposition(phi, wires): def true_fn(phi, wires): @@ -178,7 +178,7 @@ def __init__(self, phi, wires, id=None): super().__init__(phi, wires=wires, id=id) @staticmethod - def _plxpr_decomposition(phi, wires): + def _compute_plxpr_decomposition(phi, wires): if phi > 0.5: qml.RX(phi, wires=wires) @@ -190,6 +190,12 @@ def _plxpr_decomposition(phi, wires): class TestDynamicDecomposeInterpreter: """Tests for the DynamicDecomposeInterpreter class""" + def test_error_no_plxpr_decomposition(self): + """Test that an error is raised if an operator does not have a plxpr decomposition.""" + + with pytest.raises(qml.operation.DecompositionUndefinedError): + qml.RX(0.1, 0)._compute_plxpr_decomposition() + def test_no_plxpr_decomposition(self): """Test that a function with a custom operation that does not have a plxpr decomposition is not decomposed.""" From e05963e0db9b9c6213a990264fe51e95c290434b Mon Sep 17 00:00:00 2001 From: Pietropaolo Frisoni Date: Thu, 30 Jan 2025 16:00:50 -0500 Subject: [PATCH 24/33] isort --- pennylane/operation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pennylane/operation.py b/pennylane/operation.py index 925701d5a7f..8936d96ad1b 100644 --- a/pennylane/operation.py +++ b/pennylane/operation.py @@ -225,8 +225,8 @@ import warnings from collections.abc import Hashable, Iterable from enum import IntEnum -from typing import Any, Callable, Literal, Optional, Union from functools import partial +from typing import Any, Callable, Literal, Optional, Union import numpy as np from scipy.sparse import csr_matrix From bc594caf81b8a9daba8f3208416e2d8aacb7f5f3 Mon Sep 17 00:00:00 2001 From: PietropaoloFrisoni Date: Fri, 31 Jan 2025 14:32:47 -0500 Subject: [PATCH 25/33] Support for consts and hyperparameters --- pennylane/capture/base_interpreter.py | 84 ++++++++++++++----- pennylane/transforms/decompose.py | 9 +- .../test_capture_dynamic_decompositions.py | 76 +++++++++++++++++ 3 files changed, 144 insertions(+), 25 deletions(-) diff --git a/pennylane/capture/base_interpreter.py b/pennylane/capture/base_interpreter.py index 43fdb5ee32a..97c059c0749 100644 --- a/pennylane/capture/base_interpreter.py +++ b/pennylane/capture/base_interpreter.py @@ -97,7 +97,7 @@ def jaxpr_to_jaxpr( f = partial(interpreter.eval, jaxpr, consts) - return jax.make_jaxpr(f)(*args).jaxpr + return jax.make_jaxpr(f)(*args) class PlxprInterpreter: @@ -454,9 +454,11 @@ def handle_adjoint_transform(self, *invals, jaxpr, lazy, n_consts): """Interpret an adjoint transform primitive.""" consts = invals[:n_consts] args = invals[n_consts:] - jaxpr = jaxpr_to_jaxpr(copy(self), jaxpr, consts, *args) - return adjoint_transform_prim.bind(*invals, jaxpr=jaxpr, lazy=lazy, n_consts=n_consts) + + return adjoint_transform_prim.bind( + *jaxpr.consts, *args, jaxpr=jaxpr.jaxpr, lazy=lazy, n_consts=len(jaxpr.consts) + ) # pylint: disable=too-many-arguments @@ -468,12 +470,14 @@ def handle_ctrl_transform(self, *invals, n_control, jaxpr, control_values, work_ jaxpr = jaxpr_to_jaxpr(copy(self), jaxpr, consts, *args) return ctrl_transform_prim.bind( - *invals, + *jaxpr.consts, + *args, + *invals[-n_control:], n_control=n_control, - jaxpr=jaxpr, + jaxpr=jaxpr.jaxpr, control_values=control_values, work_wires=work_wires, - n_consts=n_consts, + n_consts=len(jaxpr.consts), ) @@ -482,19 +486,24 @@ def handle_for_loop( self, start, stop, step, *args, jaxpr_body_fn, consts_slice, args_slice, abstract_shapes_slice ): """Handle a for loop primitive.""" + consts = args[consts_slice] init_state = args[args_slice] abstract_shapes = args[abstract_shapes_slice] - new_jaxpr_body_fn = jaxpr_to_jaxpr( - copy(self), jaxpr_body_fn, args[consts_slice], *abstract_shapes, start, *init_state + copy(self), jaxpr_body_fn, consts, *abstract_shapes, start, *init_state ) + consts_slice = slice(0, len(new_jaxpr_body_fn.consts)) + abstract_shapes_slice = slice(consts_slice.stop, consts_slice.stop + len(abstract_shapes)) + args_slice = slice(abstract_shapes_slice.stop, None) return for_loop_prim.bind( start, stop, step, - *args, - jaxpr_body_fn=new_jaxpr_body_fn, + *new_jaxpr_body_fn.consts, + *abstract_shapes, + *init_state, + jaxpr_body_fn=new_jaxpr_body_fn.jaxpr, consts_slice=consts_slice, args_slice=args_slice, abstract_shapes_slice=abstract_shapes_slice, @@ -507,15 +516,30 @@ def handle_cond(self, *invals, jaxpr_branches, consts_slices, args_slice): args = invals[args_slice] new_jaxprs = [] + new_consts = [] + new_consts_slices = [] + end_const_ind = len(jaxpr_branches) + for const_slice, jaxpr in zip(consts_slices, jaxpr_branches): consts = invals[const_slice] if jaxpr is None: new_jaxprs.append(None) + new_consts_slices.append(slice(0, 0)) else: - new_jaxprs.append(jaxpr_to_jaxpr(copy(self), jaxpr, consts, *args)) + new_jaxpr = jaxpr_to_jaxpr(copy(self), jaxpr, consts, *args) + new_jaxprs.append(new_jaxpr.jaxpr) + new_consts.extend(new_jaxpr.consts) + new_consts_slices.append(slice(end_const_ind, end_const_ind + len(new_jaxpr.consts))) + end_const_ind += len(new_jaxpr.consts) + new_args_slice = slice(end_const_ind, None) return cond_prim.bind( - *invals, jaxpr_branches=new_jaxprs, consts_slices=consts_slices, args_slice=args_slice + *invals[: len(jaxpr_branches)], + *new_consts, + *args, + jaxpr_branches=new_jaxprs, + consts_slices=new_consts_slices, + args_slice=new_args_slice, ) @@ -543,12 +567,20 @@ def handle_while_loop( copy(self), jaxpr_cond_fn, consts_cond, *abstract_shapes, *init_state ) + body_consts = slice(0, len(new_jaxpr_body_fn.consts)) + cond_consts = slice(body_consts.stop, body_consts.stop + len(new_jaxpr_cond_fn.consts)) + abstract_shapes_slice = slice(cond_consts.stop, cond_consts.stop + len(abstract_shapes)) + args_slice = slice(abstract_shapes_slice.stop, None) + return while_loop_prim.bind( - *invals, - jaxpr_body_fn=new_jaxpr_body_fn, - jaxpr_cond_fn=new_jaxpr_cond_fn, - body_slice=body_slice, - cond_slice=cond_slice, + *new_jaxpr_body_fn.consts, + *new_jaxpr_cond_fn.consts, + *abstract_shapes, + *init_state, + jaxpr_body_fn=new_jaxpr_body_fn.jaxpr, + jaxpr_cond_fn=new_jaxpr_cond_fn.jaxpr, + body_slice=body_consts, + cond_slice=cond_consts, args_slice=args_slice, abstract_shapes_slice=abstract_shapes_slice, ) @@ -559,17 +591,19 @@ def handle_while_loop( def handle_qnode(self, *invals, shots, qnode, device, qnode_kwargs, qfunc_jaxpr, n_consts): """Handle a qnode primitive.""" consts = invals[:n_consts] + args = invals[n_consts:] - new_qfunc_jaxpr = jaxpr_to_jaxpr(copy(self), qfunc_jaxpr, consts, *invals[n_consts:]) + new_qfunc_jaxpr = jaxpr_to_jaxpr(copy(self), qfunc_jaxpr, consts, *args) return qnode_prim.bind( - *invals, + *new_qfunc_jaxpr.consts, + *args, shots=shots, qnode=qnode, device=device, qnode_kwargs=qnode_kwargs, - qfunc_jaxpr=new_qfunc_jaxpr, - n_consts=n_consts, + qfunc_jaxpr=new_qfunc_jaxpr.jaxpr, + n_consts=len(new_qfunc_jaxpr.consts), ) @@ -579,7 +613,9 @@ def handle_grad(self, *invals, jaxpr, n_consts, **params): consts = invals[:n_consts] args = invals[n_consts:] new_jaxpr = jaxpr_to_jaxpr(copy(self), jaxpr, consts, *args) - return grad_prim.bind(*invals, jaxpr=new_jaxpr, n_consts=n_consts, **params) + return grad_prim.bind( + *new_jaxpr.consts, *args, jaxpr=new_jaxpr.jaxpr, n_consts=len(new_jaxpr.consts), **params + ) @PlxprInterpreter.register_primitive(jacobian_prim) @@ -588,7 +624,9 @@ def handle_jacobian(self, *invals, jaxpr, n_consts, **params): consts = invals[:n_consts] args = invals[n_consts:] new_jaxpr = jaxpr_to_jaxpr(copy(self), jaxpr, consts, *args) - return jacobian_prim.bind(*invals, jaxpr=new_jaxpr, n_consts=n_consts, **params) + return jacobian_prim.bind( + *new_jaxpr.consts, *args, jaxpr=new_jaxpr.jaxpr, n_consts=len(new_jaxpr.consts), **params + ) def flatten_while_loop( diff --git a/pennylane/transforms/decompose.py b/pennylane/transforms/decompose.py index fb5cb03e6f2..a1d480dc798 100644 --- a/pennylane/transforms/decompose.py +++ b/pennylane/transforms/decompose.py @@ -205,7 +205,7 @@ class DynamicDecomposeInterpreter(qml.capture.PlxprInterpreter): Experimental Plxpr Interpreter for applying a dynamic decomposition to operations when program capture is enabled. """ - def eval_dynamic_decomposition(self, jaxpr_decomp: "jax.core.Jaxpr", *args): + def eval_dynamic_decomposition(self, jaxpr_decomp: "jax.core.Jaxpr", consts, *args): """ Evaluate a dynamic decomposition of a Jaxpr. @@ -217,6 +217,9 @@ def eval_dynamic_decomposition(self, jaxpr_decomp: "jax.core.Jaxpr", *args): for arg, invar in zip(args, jaxpr_decomp.invars, strict=True): self._env[invar] = arg + for const, constvar in zip(consts, jaxpr_decomp.constvars, strict=True): + self._env[constvar] = const + for inner_eqn in jaxpr_decomp.eqns: custom_handler = self._primitive_registrations.get(inner_eqn.primitive, None) @@ -257,7 +260,9 @@ def interpret_operation_eqn(self, eqn: "jax.core.JaxprEqn"): if op._has_plxpr_decomposition: jaxpr_decomp = op._plxpr_decomposition() args = (*op.parameters, *op.wires) - return self.eval_dynamic_decomposition(jaxpr_decomp.jaxpr, *args) + return self.eval_dynamic_decomposition( + jaxpr_decomp.jaxpr, jaxpr_decomp.consts, *args + ) return super().interpret_operation(op) diff --git a/tests/capture/transforms/test_capture_dynamic_decompositions.py b/tests/capture/transforms/test_capture_dynamic_decompositions.py index b5eb678b571..a77483b9e5f 100644 --- a/tests/capture/transforms/test_capture_dynamic_decompositions.py +++ b/tests/capture/transforms/test_capture_dynamic_decompositions.py @@ -40,6 +40,41 @@ def _compute_plxpr_decomposition(wires): qml.Hadamard(wires=wires) +const = jax.numpy.array(0.1) + + +class CustomOpConstHyperparams(Operation): + """Custom operation that contains constants and hyperparameters in its decomposition""" + + num_wires = 4 + num_params = 1 + + def __init__(self, phi, wires, id=None): + + self._hyperparameters = { + "key": const, + "CNOT": qml.CNOT, + "RX": qml.RX, + "phi": phi, + } + + super().__init__(phi, wires=wires, id=id) + + @staticmethod + def _compute_plxpr_decomposition(*args, **hyperparameters): + + phi = args[0] + wires = args[1:] + + hyperparameters["CNOT"](wires=[wires[0], wires[1]]) + hyperparameters["RX"](phi, wires=wires[2]) + hyperparameters["RX"](hyperparameters["key"], wires=wires[0]) + hyperparameters["RX"](const, wires=wires[3]) + + qml.RY(hyperparameters["key"], wires[0]) + qml.RZ(hyperparameters["phi"], wires[2]) + + class CustomOpMultiWire(Operation): """Custom operation that acts on multiple wires""" @@ -298,6 +333,47 @@ def circuit_comparison(x, wires): assert qml.math.allclose(result[2], comparison_result[2]) assert qml.math.allclose(result[3], comparison_result[3]) + @pytest.mark.parametrize("wires", [[0, 1, 2, 3], [2, 3, 1, 0]]) + @pytest.mark.parametrize("x", [0.2, 0.8]) + def test_qnode_const_hyperparams(self, wires, x): + """Test that a QNode with a constant in the custom operation is correctly decomposed.""" + + @DynamicDecomposeInterpreter() + @qml.qnode(device=qml.device("default.qubit", wires=4)) + def circuit(x, wires): + CustomOpConstHyperparams(x, wires=wires) + return qml.expval(qml.Z(0)), qml.probs(wires=1), qml.var(qml.Z(2)), qml.state() + + jaxpr = jax.make_jaxpr(circuit)(x, wires=wires) + + assert jaxpr.eqns[0].primitive == qnode_prim + qfunc_jaxpr = jaxpr.eqns[0].params["qfunc_jaxpr"] + assert qfunc_jaxpr.eqns[0].primitive == qml.CNOT._primitive + assert qfunc_jaxpr.eqns[1].primitive == qml.RX._primitive + assert qfunc_jaxpr.eqns[2].primitive == qml.RX._primitive + assert qfunc_jaxpr.eqns[3].primitive == qml.RX._primitive + assert qfunc_jaxpr.eqns[4].primitive == qml.RY._primitive + assert qfunc_jaxpr.eqns[5].primitive == qml.RZ._primitive + + result = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, x, *wires) + + @qml.qnode(device=qml.device("default.qubit", wires=4)) + def circuit_comparison(x, wires): + qml.CNOT([wires[0], wires[1]]) + qml.RX(x, wires=wires[2]) + qml.RX(0.1, wires=wires[0]) + qml.RX(0.1, wires=wires[3]) + qml.RY(0.1, wires=wires[0]) + qml.RZ(x, wires=wires[2]) + return qml.expval(qml.Z(0)), qml.probs(wires=1), qml.var(qml.Z(2)), qml.state() + + comparison_result = circuit_comparison(x, wires) + + assert qml.math.allclose(result[0], comparison_result[0]) + assert qml.math.allclose(result[1], comparison_result[1]) + assert qml.math.allclose(result[2], comparison_result[2]) + assert qml.math.allclose(result[3], comparison_result[3]) + @pytest.mark.parametrize("wire", [0, 1]) @pytest.mark.parametrize("x", [0.2, 0.8]) def test_qnode_cond(self, x, wire): From 1ea4a14df4389bddb8c062708644498a05b1b9cc Mon Sep 17 00:00:00 2001 From: PietropaoloFrisoni Date: Fri, 31 Jan 2025 14:59:36 -0500 Subject: [PATCH 26/33] Fixes neede after autograph PR merged on master --- .../test_capture_dynamic_decompositions.py | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/tests/capture/transforms/test_capture_dynamic_decompositions.py b/tests/capture/transforms/test_capture_dynamic_decompositions.py index a77483b9e5f..fc9d342ac98 100644 --- a/tests/capture/transforms/test_capture_dynamic_decompositions.py +++ b/tests/capture/transforms/test_capture_dynamic_decompositions.py @@ -266,7 +266,7 @@ def test_qnode_simple(self): """Test that a QNode with a custom operation is correctly decomposed.""" @DynamicDecomposeInterpreter() - @qml.qnode(device=qml.device("default.qubit", wires=2)) + @qml.qnode(device=qml.device("default.qubit", wires=2), autograph=False) def circuit(): qml.RY(0.1, wires=0) SimpleCustomOp(wires=0) @@ -296,7 +296,7 @@ def test_multi_wire(self, wires): """Test that a QNode with a multi-wire custom operation is correctly decomposed.""" @DynamicDecomposeInterpreter() - @qml.qnode(device=qml.device("default.qubit", wires=4)) + @qml.qnode(device=qml.device("default.qubit", wires=4), autograph=False) def circuit(x, wires): CustomOpMultiWire(x, wires=wires) return qml.expval(qml.Z(0)), qml.probs(wires=1), qml.var(qml.Z(2)), qml.state() @@ -339,7 +339,7 @@ def test_qnode_const_hyperparams(self, wires, x): """Test that a QNode with a constant in the custom operation is correctly decomposed.""" @DynamicDecomposeInterpreter() - @qml.qnode(device=qml.device("default.qubit", wires=4)) + @qml.qnode(device=qml.device("default.qubit", wires=4), autograph=False) def circuit(x, wires): CustomOpConstHyperparams(x, wires=wires) return qml.expval(qml.Z(0)), qml.probs(wires=1), qml.var(qml.Z(2)), qml.state() @@ -380,7 +380,7 @@ def test_qnode_cond(self, x, wire): """Test that a QNode with a conditional custom operation is correctly decomposed.""" @DynamicDecomposeInterpreter() - @qml.qnode(device=qml.device("default.qubit", wires=2)) + @qml.qnode(device=qml.device("default.qubit", wires=2), autograph=False) def circuit(x, wire): CustomOpCond(x, wires=wire) return qml.expval(qml.Z(wires=wire)) @@ -420,7 +420,7 @@ def test_qnode_for_loop(self, wire): """Test that a QNode with a for loop custom operation is correctly decomposed.""" @DynamicDecomposeInterpreter() - @qml.qnode(device=qml.device("default.qubit", wires=2)) + @qml.qnode(device=qml.device("default.qubit", wires=2), autograph=False) def circuit(x, wire): CustomOpForLoop(x, wires=wire) return qml.expval(qml.Z(wires=wire)) @@ -455,7 +455,7 @@ def test_qnode_while_loop(self, wire): """Test that a QNode with a while loop custom operation is correctly decomposed.""" @DynamicDecomposeInterpreter() - @qml.qnode(device=qml.device("default.qubit", wires=2)) + @qml.qnode(device=qml.device("default.qubit", wires=2), autograph=False) def circuit(x, wire): CustomOpWhileLoop(x, wires=wire) return qml.expval(qml.Z(wires=wire)) @@ -471,7 +471,7 @@ def circuit(x, wire): result = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, 0.5, wire) - @qml.qnode(device=qml.device("default.qubit", wires=2)) + @qml.qnode(device=qml.device("default.qubit", wires=2), autograph=False) def circuit_comparison(x, wire): @qml.while_loop(lambda i: i < 3) def loop_fn(i): @@ -490,7 +490,7 @@ def test_qnode_nested_cond(self, x, wire): """Test that a QNode with a nested conditional custom operation is correctly decomposed.""" @DynamicDecomposeInterpreter() - @qml.qnode(device=qml.device("default.qubit", wires=2)) + @qml.qnode(device=qml.device("default.qubit", wires=2), autograph=False) def circuit(x, wire): CustomOpNestedCond(x, wires=wire) return qml.expval(qml.Z(wires=wire)) @@ -508,7 +508,7 @@ def circuit(x, wire): result = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, x, wire) - @qml.qnode(device=qml.device("default.qubit", wires=2)) + @qml.qnode(device=qml.device("default.qubit", wires=2), autograph=False) def circuit_comparison(x, wire): def true_fn(x, wire): @@ -540,7 +540,7 @@ def test_qnode_autograph(self, x, wire): """Test that a QNode with a nested conditional custom operation is correctly decomposed.""" @DynamicDecomposeInterpreter() - @qml.qnode(device=qml.device("default.qubit", wires=2)) + @qml.qnode(device=qml.device("default.qubit", wires=2), autograph=False) def circuit(x, wire): CustomOpAutograph(x, wires=wire) return qml.expval(qml.Z(wires=wire)) @@ -561,7 +561,7 @@ def circuit(x, wire): result = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, x, wire) - @qml.qnode(device=qml.device("default.qubit", wires=2)) + @qml.qnode(device=qml.device("default.qubit", wires=2), autograph=False) def circuit_comparison(x, wire): if x > 0.5: From f576486b92db0191ef64340fd4b5daa0dfeb2661 Mon Sep 17 00:00:00 2001 From: PietropaoloFrisoni Date: Wed, 5 Feb 2025 10:24:03 -0500 Subject: [PATCH 27/33] [ci skip] From fd817cca806f86fde3f668affe5233dbce0d32b5 Mon Sep 17 00:00:00 2001 From: PietropaoloFrisoni Date: Wed, 5 Feb 2025 16:03:56 -0500 Subject: [PATCH 28/33] Removing `DynamicDecomposeInterpreter` --- doc/releases/changelog-dev.md | 3 +- pennylane/operation.py | 51 ++++----- pennylane/transforms/decompose.py | 100 +++--------------- .../test_capture_dynamic_decompositions.py | 63 ++++++----- 4 files changed, 76 insertions(+), 141 deletions(-) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 70c7e65197c..cd786fe6d1f 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -150,7 +150,8 @@

Capturing and representing hybrid programs

-* Implemented a new `DynamicDecomposeInterpreter` to capture decompositions of operators with control-flow instructions. +* Implemented a `compute_plxpr_decomposition` method in the `qml.operation.Operator` class to apply dynamic decompositions + with program capture enabled. [(#6859)](https://github.com/PennyLaneAI/pennylane/pull/6859) * The higher order primitives in program capture can now accept inputs with abstract shapes. diff --git a/pennylane/operation.py b/pennylane/operation.py index 8936d96ad1b..7a71c158aab 100644 --- a/pennylane/operation.py +++ b/pennylane/operation.py @@ -1323,30 +1323,6 @@ def decomposition(self) -> list["Operator"]: *self.parameters, wires=self.wires, **self.hyperparameters ) - @classproperty - def _has_plxpr_decomposition(cls) -> bool: - """Whether or not the Operator returns a defined plxpr decomposition.""" - - return ( - cls._compute_plxpr_decomposition != Operator._compute_plxpr_decomposition - or cls._plxpr_decomposition != Operator._plxpr_decomposition - ) - - def _plxpr_decomposition(self) -> "jax.core.Jaxpr": - """Representation of the operator as a plxpr decomposition.""" - - args = (*self.parameters, *self.wires) - jaxpr_decomp = qml.capture.make_plxpr( - partial(self._compute_plxpr_decomposition, **self.hyperparameters) - )(*args) - - return jaxpr_decomp - - @staticmethod - def _compute_plxpr_decomposition(*args, **hyperparameters): - """Experimental method to compute the plxpr decomposition of the operator.""" - raise DecompositionUndefinedError - @staticmethod def compute_decomposition( *params: TensorLike, @@ -1374,6 +1350,33 @@ def compute_decomposition( """ raise DecompositionUndefinedError + @classproperty + def has_plxpr_decomposition(cls) -> bool: + """Whether or not the Operator returns a defined plxpr decomposition.""" + return cls.compute_plxpr_decomposition != Operator.compute_plxpr_decomposition + + @staticmethod + def compute_plxpr_decomposition(*args, **hyperparameters) -> None: + r"""Experimental method to compute the dynamic decomposition of the operator with program capture enabled. + + When the program capture feature is enabled with ``qml.capture.enable()``, the decomposition of the operator + is computed with this method if it is defined. Otherwise, the :meth:`~.Operator.compute_decomposition` method is used. + + If this method is defined, the control flow operations within the method are recorded in the JAX representation + of the operator's decomposition. + + This method is experimental and subject to change. + + .. seealso:: :meth:`~.Operator.compute_decomposition`. + + Args: + *args (list): positional arguments passed to the operator, including trainable parameters and wires + **hyperparameters (dict): non-trainable hyperparameters of the operator, as stored in the ``hyperparameters`` attribute + + """ + + raise DecompositionUndefinedError + # pylint: disable=no-self-argument, comparison-with-callable @classproperty def has_diagonalizing_gates(cls) -> bool: diff --git a/pennylane/transforms/decompose.py b/pennylane/transforms/decompose.py index c2e3fb6a8f1..e73384bdc0b 100644 --- a/pennylane/transforms/decompose.py +++ b/pennylane/transforms/decompose.py @@ -157,7 +157,18 @@ def interpret_operation_eqn(self, eqn): with qml.QueuingManager.stop_recording(): op = eqn.primitive.impl(*invals, **eqn.params) if eqn.outvars[0].__class__.__name__ == "DropVar": - return self.decompose_operation(op) + + if op.has_plxpr_decomposition: + + args = (*op.parameters, *op.wires) + qml.capture.run_autograph(op.compute_plxpr_decomposition)( + *args, **op.hyperparameters + ) + + else: + + return self.decompose_operation(op) + return op # pylint: disable=unused-variable,missing-function-docstring @@ -184,93 +195,6 @@ def wrapper(*inner_args): DecomposeInterpreter, decompose_plxpr_to_plxpr = _get_plxpr_decompose() -@lru_cache -def _get_plxpr_dynamic_decompose(): # pylint: disable=missing-docstring - try: - # pylint: disable=import-outside-toplevel - # pylint: disable=unused-import - import jax - - from pennylane.capture.primitives import AbstractMeasurement, AbstractOperator - except ImportError: # pragma: no cover - return None, None - - # pylint: disable=redefined-outer-name - - class DynamicDecomposeInterpreter(qml.capture.PlxprInterpreter): - """ - Experimental Plxpr Interpreter for applying a dynamic decomposition to operations when program capture is enabled. - """ - - def eval_dynamic_decomposition(self, jaxpr_decomp: "jax.core.Jaxpr", consts, *args): - """ - Evaluate a dynamic decomposition of a Jaxpr. - - Args: - jaxpr_decomp (jax.core.Jaxpr): the Jaxpr to evaluate - *args: the arguments to use in the evaluation - """ - - for arg, invar in zip(args, jaxpr_decomp.invars, strict=True): - self._env[invar] = arg - - for const, constvar in zip(consts, jaxpr_decomp.constvars, strict=True): - self._env[constvar] = const - - for inner_eqn in jaxpr_decomp.eqns: - - custom_handler = self._primitive_registrations.get(inner_eqn.primitive, None) - - if custom_handler: - invals = [self.read(invar) for invar in inner_eqn.invars] - outvals = custom_handler(self, *invals, **inner_eqn.params) - - elif isinstance(inner_eqn.outvars[0].aval, AbstractOperator): - # This does not currently support nested decompositions - outvals = super().interpret_operation_eqn(inner_eqn) - elif isinstance(inner_eqn.outvars[0].aval, AbstractMeasurement): - outvals = super().interpret_measurement_eqn(inner_eqn) - else: - invals = [self.read(invar) for invar in inner_eqn.invars] - outvals = inner_eqn.primitive.bind(*invals, **inner_eqn.params) - - if not inner_eqn.primitive.multiple_results: - outvals = [outvals] - - for inner_outvar, inner_outval in zip(inner_eqn.outvars, outvals, strict=True): - self._env[inner_outvar] = inner_outval - - def interpret_operation_eqn(self, eqn: "jax.core.JaxprEqn"): - """ - Interpret an equation corresponding to an operator. - - Args: - eqn (jax.core.JaxprEqn): a jax equation for an operator. - """ - - invals = (self.read(invar) for invar in eqn.invars) - with qml.QueuingManager.stop_recording(): - op = eqn.primitive.impl(*invals, **eqn.params) - - if isinstance(eqn.outvars[0], jax.core.DropVar): - - if op._has_plxpr_decomposition: - jaxpr_decomp = op._plxpr_decomposition() - args = (*op.parameters, *op.wires) - return self.eval_dynamic_decomposition( - jaxpr_decomp.jaxpr, jaxpr_decomp.consts, *args - ) - - return super().interpret_operation(op) - - return op - - return DynamicDecomposeInterpreter - - -DynamicDecomposeInterpreter = _get_plxpr_dynamic_decompose() - - @partial(transform, plxpr_transform=decompose_plxpr_to_plxpr) def decompose(tape, gate_set=None, max_expansion=None): """Decomposes a quantum circuit into a user-specified gate set. diff --git a/tests/capture/transforms/test_capture_dynamic_decompositions.py b/tests/capture/transforms/test_capture_dynamic_decompositions.py index fc9d342ac98..f8bec181433 100644 --- a/tests/capture/transforms/test_capture_dynamic_decompositions.py +++ b/tests/capture/transforms/test_capture_dynamic_decompositions.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Unit tests for the ``DynamicDecomposeInterpreter`` class.""" +"""Unit tests for the ``DecomposeInterpreter`` class with dynamic decompositions.""" # pylint:disable=protected-access,unused-argument, wrong-import-position, no-value-for-parameter, too-few-public-methods import pytest @@ -21,7 +21,7 @@ from pennylane.capture.primitives import cond_prim, for_loop_prim, qnode_prim, while_loop_prim from pennylane.operation import Operation -from pennylane.transforms.decompose import DynamicDecomposeInterpreter +from pennylane.transforms.decompose import DecomposeInterpreter pytestmark = [pytest.mark.jax, pytest.mark.usefixtures("enable_disable_plxpr")] @@ -32,12 +32,13 @@ class SimpleCustomOp(Operation): num_wires = 1 num_params = 0 - def _init__(self, phi, wires, id=None): - super().__init__(phi, wires=wires, id=id) + def _init__(self, wires, id=None): + super().__init__(wires=wires, id=id) @staticmethod - def _compute_plxpr_decomposition(wires): - qml.Hadamard(wires=wires) + def compute_plxpr_decomposition(wires): + + return qml.Hadamard(wires=wires) const = jax.numpy.array(0.1) @@ -61,7 +62,7 @@ def __init__(self, phi, wires, id=None): super().__init__(phi, wires=wires, id=id) @staticmethod - def _compute_plxpr_decomposition(*args, **hyperparameters): + def compute_plxpr_decomposition(*args, **hyperparameters): phi = args[0] wires = args[1:] @@ -89,7 +90,7 @@ def __init__(self, phi, wires, id=None): super().__init__(phi, wires=wires, id=id) @staticmethod - def _compute_plxpr_decomposition(*args, **hyperparameters): + def compute_plxpr_decomposition(*args, **hyperparameters): phi = args[0] wires = args[1:] @@ -113,7 +114,7 @@ def __init__(self, phi, wires, id=None): super().__init__(phi, wires=wires, id=id) @staticmethod - def _compute_plxpr_decomposition(phi, wires): + def compute_plxpr_decomposition(phi, wires): def true_fn(phi, wires): qml.RX(phi, wires=wires) @@ -134,7 +135,7 @@ def __init__(self, phi, wires, id=None): super().__init__(phi, wires=wires, id=id) @staticmethod - def _compute_plxpr_decomposition(phi, wires): + def compute_plxpr_decomposition(phi, wires): @qml.for_loop(0, 3, 1) def loop_rx(i, phi): @@ -155,9 +156,12 @@ def __init__(self, phi, wires, id=None): super().__init__(phi, wires=wires, id=id) @staticmethod - def _compute_plxpr_decomposition(phi, wires): + def compute_plxpr_decomposition(phi, wires): + + def while_f(i): + return i < 3 - @qml.while_loop(lambda i: i < 3) + @qml.while_loop(while_f) def loop_fn(i): qml.RX(phi, wires) return i + 1 @@ -177,7 +181,7 @@ def __init__(self, phi, wires, id=None): super().__init__(phi, wires=wires, id=id) @staticmethod - def _compute_plxpr_decomposition(phi, wires): + def compute_plxpr_decomposition(phi, wires): def true_fn(phi, wires): @@ -191,7 +195,10 @@ def loop_rx(i, phi): def false_fn(phi, wires): - @qml.while_loop(lambda i: i < 3) + def while_f(i): + return i < 3 + + @qml.while_loop(while_f) def loop_fn(i): qml.RX(phi, wires) return i + 1 @@ -213,7 +220,7 @@ def __init__(self, phi, wires, id=None): super().__init__(phi, wires=wires, id=id) @staticmethod - def _compute_plxpr_decomposition(phi, wires): + def compute_plxpr_decomposition(phi, wires): if phi > 0.5: qml.RX(phi, wires=wires) @@ -229,12 +236,12 @@ def test_error_no_plxpr_decomposition(self): """Test that an error is raised if an operator does not have a plxpr decomposition.""" with pytest.raises(qml.operation.DecompositionUndefinedError): - qml.RX(0.1, 0)._compute_plxpr_decomposition() + qml.RX(0.1, 0).compute_plxpr_decomposition() def test_no_plxpr_decomposition(self): """Test that a function with a custom operation that does not have a plxpr decomposition is not decomposed.""" - @DynamicDecomposeInterpreter() + @DecomposeInterpreter() def f(x): qml.RY(x, wires=0) @@ -245,7 +252,7 @@ def f(x): def test_function_simple(self): """Test that a function with a custom operation is correctly decomposed.""" - @DynamicDecomposeInterpreter() + @DecomposeInterpreter() def f(): qml.RY(0.1, wires=0) SimpleCustomOp(wires=0) @@ -265,14 +272,14 @@ def f(): def test_qnode_simple(self): """Test that a QNode with a custom operation is correctly decomposed.""" - @DynamicDecomposeInterpreter() + @DecomposeInterpreter() @qml.qnode(device=qml.device("default.qubit", wires=2), autograph=False) def circuit(): qml.RY(0.1, wires=0) SimpleCustomOp(wires=0) return qml.expval(qml.Z(0)) - jaxpr = jax.make_jaxpr(circuit)() + jaxpr = qml.capture.make_plxpr(circuit)() assert jaxpr.eqns[0].primitive == qnode_prim qfunc_jaxpr = jaxpr.eqns[0].params["qfunc_jaxpr"] @@ -295,7 +302,7 @@ def circuit_comparison(): def test_multi_wire(self, wires): """Test that a QNode with a multi-wire custom operation is correctly decomposed.""" - @DynamicDecomposeInterpreter() + @DecomposeInterpreter() @qml.qnode(device=qml.device("default.qubit", wires=4), autograph=False) def circuit(x, wires): CustomOpMultiWire(x, wires=wires) @@ -338,7 +345,7 @@ def circuit_comparison(x, wires): def test_qnode_const_hyperparams(self, wires, x): """Test that a QNode with a constant in the custom operation is correctly decomposed.""" - @DynamicDecomposeInterpreter() + @DecomposeInterpreter() @qml.qnode(device=qml.device("default.qubit", wires=4), autograph=False) def circuit(x, wires): CustomOpConstHyperparams(x, wires=wires) @@ -379,7 +386,7 @@ def circuit_comparison(x, wires): def test_qnode_cond(self, x, wire): """Test that a QNode with a conditional custom operation is correctly decomposed.""" - @DynamicDecomposeInterpreter() + @DecomposeInterpreter() @qml.qnode(device=qml.device("default.qubit", wires=2), autograph=False) def circuit(x, wire): CustomOpCond(x, wires=wire) @@ -419,7 +426,7 @@ def false_fn(x, wire): def test_qnode_for_loop(self, wire): """Test that a QNode with a for loop custom operation is correctly decomposed.""" - @DynamicDecomposeInterpreter() + @DecomposeInterpreter() @qml.qnode(device=qml.device("default.qubit", wires=2), autograph=False) def circuit(x, wire): CustomOpForLoop(x, wires=wire) @@ -454,7 +461,7 @@ def loop_rx(i, phi): def test_qnode_while_loop(self, wire): """Test that a QNode with a while loop custom operation is correctly decomposed.""" - @DynamicDecomposeInterpreter() + @DecomposeInterpreter() @qml.qnode(device=qml.device("default.qubit", wires=2), autograph=False) def circuit(x, wire): CustomOpWhileLoop(x, wires=wire) @@ -489,7 +496,7 @@ def loop_fn(i): def test_qnode_nested_cond(self, x, wire): """Test that a QNode with a nested conditional custom operation is correctly decomposed.""" - @DynamicDecomposeInterpreter() + @DecomposeInterpreter() @qml.qnode(device=qml.device("default.qubit", wires=2), autograph=False) def circuit(x, wire): CustomOpNestedCond(x, wires=wire) @@ -539,13 +546,13 @@ def loop_fn(i): def test_qnode_autograph(self, x, wire): """Test that a QNode with a nested conditional custom operation is correctly decomposed.""" - @DynamicDecomposeInterpreter() + @DecomposeInterpreter() @qml.qnode(device=qml.device("default.qubit", wires=2), autograph=False) def circuit(x, wire): CustomOpAutograph(x, wires=wire) return qml.expval(qml.Z(wires=wire)) - jaxpr = jax.make_jaxpr(circuit)(x, wire) + jaxpr = qml.capture.make_plxpr(circuit)(x, wire) assert jaxpr.eqns[0].primitive == qnode_prim qfunc_jaxpr = jaxpr.eqns[0].params["qfunc_jaxpr"] From e7a27751ea860599fbe0c33e0630a3bfed4099b2 Mon Sep 17 00:00:00 2001 From: PietropaoloFrisoni Date: Wed, 5 Feb 2025 16:05:58 -0500 Subject: [PATCH 29/33] pylint --- pennylane/operation.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pennylane/operation.py b/pennylane/operation.py index 7a71c158aab..f22f658ee6e 100644 --- a/pennylane/operation.py +++ b/pennylane/operation.py @@ -225,7 +225,6 @@ import warnings from collections.abc import Hashable, Iterable from enum import IntEnum -from functools import partial from typing import Any, Callable, Literal, Optional, Union import numpy as np From e9834af9b543328b361e96a3107b85f40375fc57 Mon Sep 17 00:00:00 2001 From: Pietropaolo Frisoni Date: Mon, 10 Feb 2025 09:47:08 -0500 Subject: [PATCH 30/33] Suggestions from code review (more tests) --- .../test_capture_dynamic_decompositions.py | 82 ++++++++++++++++++- 1 file changed, 80 insertions(+), 2 deletions(-) diff --git a/tests/capture/transforms/test_capture_dynamic_decompositions.py b/tests/capture/transforms/test_capture_dynamic_decompositions.py index f8bec181433..812ce88b749 100644 --- a/tests/capture/transforms/test_capture_dynamic_decompositions.py +++ b/tests/capture/transforms/test_capture_dynamic_decompositions.py @@ -19,6 +19,9 @@ jax = pytest.importorskip("jax") +from functools import partial + +from pennylane.capture import expand_plxpr_transforms from pennylane.capture.primitives import cond_prim, for_loop_prim, qnode_prim, while_loop_prim from pennylane.operation import Operation from pennylane.transforms.decompose import DecomposeInterpreter @@ -168,8 +171,6 @@ def loop_fn(i): _ = loop_fn(0) - return qml.expval(qml.Z(0)) - class CustomOpNestedCond(Operation): """Custom operation that contains a nested conditional in its decomposition""" @@ -585,3 +586,80 @@ def circuit_comparison(x, wire): ) assert qml.math.allclose(*result, *result_comparison) + + +class TestExpandPlxprTransformsDynamicDecompositions: + """Unit tests for ``expand_plxpr_transforms`` with dynamic decompositions.""" + + def test_expand_plxpr_transforms_simple(self): + + @partial(qml.transforms.decompose) + def circuit(): + SimpleCustomOp(wires=0) + return qml.probs(wires=[0, 1]) + + jaxpr = jax.make_jaxpr(circuit)() + + assert jaxpr.eqns[0].primitive == qml.transforms.decompose._primitive + + transformed_f = expand_plxpr_transforms(circuit) + transformed_jaxpr = jax.make_jaxpr(transformed_f)() + + assert transformed_jaxpr.eqns[0].primitive == qml.Hadamard._primitive + assert ( + transformed_jaxpr.eqns[1].primitive == qml.measurements.ProbabilityMP._wires_primitive + ) + + def test_expand_plxpr_transforms_cond(self): + @partial(qml.transforms.decompose) + def circuit(): + CustomOpCond(0.5, wires=0) + return qml.probs(wires=[0, 1]) + + jaxpr = jax.make_jaxpr(circuit)() + + assert jaxpr.eqns[0].primitive == qml.transforms.decompose._primitive + + transformed_f = expand_plxpr_transforms(circuit) + transformed_jaxpr = jax.make_jaxpr(transformed_f)() + + assert transformed_jaxpr.eqns[0].primitive == cond_prim + assert ( + transformed_jaxpr.eqns[1].primitive == qml.measurements.ProbabilityMP._wires_primitive + ) + + def test_expand_plxpr_transforms_for_loop(self): + @partial(qml.transforms.decompose) + def circuit(): + CustomOpForLoop(0.5, wires=0) + return qml.probs(wires=[0, 1]) + + jaxpr = jax.make_jaxpr(circuit)() + + assert jaxpr.eqns[0].primitive == qml.transforms.decompose._primitive + + transformed_f = expand_plxpr_transforms(circuit) + transformed_jaxpr = jax.make_jaxpr(transformed_f)() + + assert transformed_jaxpr.eqns[0].primitive == for_loop_prim + assert ( + transformed_jaxpr.eqns[1].primitive == qml.measurements.ProbabilityMP._wires_primitive + ) + + def test_expand_plxpr_transforms_while_loop(self): + @partial(qml.transforms.decompose) + def circuit(): + CustomOpWhileLoop(0.5, wires=0) + return qml.probs(wires=[0, 1]) + + jaxpr = jax.make_jaxpr(circuit)() + + assert jaxpr.eqns[0].primitive == qml.transforms.decompose._primitive + + transformed_f = expand_plxpr_transforms(circuit) + transformed_jaxpr = jax.make_jaxpr(transformed_f)() + + assert transformed_jaxpr.eqns[0].primitive == while_loop_prim + assert ( + transformed_jaxpr.eqns[1].primitive == qml.measurements.ProbabilityMP._wires_primitive + ) From bd9473b3ad8dee91f923bb47bb22542db331f035 Mon Sep 17 00:00:00 2001 From: Pietropaolo Frisoni Date: Mon, 10 Feb 2025 09:56:45 -0500 Subject: [PATCH 31/33] disabling wrong iimport order in test file (conflict between isort and pylint) --- tests/capture/transforms/test_capture_dynamic_decompositions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/capture/transforms/test_capture_dynamic_decompositions.py b/tests/capture/transforms/test_capture_dynamic_decompositions.py index 812ce88b749..3b0b401212a 100644 --- a/tests/capture/transforms/test_capture_dynamic_decompositions.py +++ b/tests/capture/transforms/test_capture_dynamic_decompositions.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Unit tests for the ``DecomposeInterpreter`` class with dynamic decompositions.""" -# pylint:disable=protected-access,unused-argument, wrong-import-position, no-value-for-parameter, too-few-public-methods +# pylint:disable=protected-access,unused-argument, wrong-import-position, no-value-for-parameter, too-few-public-methods, wrong-import-order import pytest import pennylane as qml From 98fe17e0f64d26f313747e69b8514cd94b923b58 Mon Sep 17 00:00:00 2001 From: Pietropaolo Frisoni Date: Mon, 10 Feb 2025 14:30:22 -0500 Subject: [PATCH 32/33] Suggestions from code review --- .../test_capture_dynamic_decompositions.py | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/tests/capture/transforms/test_capture_dynamic_decompositions.py b/tests/capture/transforms/test_capture_dynamic_decompositions.py index 7a39be3331f..84a7fc93ad2 100644 --- a/tests/capture/transforms/test_capture_dynamic_decompositions.py +++ b/tests/capture/transforms/test_capture_dynamic_decompositions.py @@ -337,11 +337,8 @@ def circuit_comparison(x, wires): return qml.expval(qml.Z(0)), qml.probs(wires=1), qml.var(qml.Z(2)), qml.state() comparison_result = circuit_comparison(0.5, wires) - - assert qml.math.allclose(result[0], comparison_result[0]) - assert qml.math.allclose(result[1], comparison_result[1]) - assert qml.math.allclose(result[2], comparison_result[2]) - assert qml.math.allclose(result[3], comparison_result[3]) + for res, comp in zip(result, comparison_result): + assert qml.math.allclose(res, comp) @pytest.mark.parametrize("autograph", [True, False]) @pytest.mark.parametrize("wires", [[0, 1, 2, 3], [2, 3, 1, 0]]) @@ -379,11 +376,8 @@ def circuit_comparison(x, wires): return qml.expval(qml.Z(0)), qml.probs(wires=1), qml.var(qml.Z(2)), qml.state() comparison_result = circuit_comparison(x, wires) - - assert qml.math.allclose(result[0], comparison_result[0]) - assert qml.math.allclose(result[1], comparison_result[1]) - assert qml.math.allclose(result[2], comparison_result[2]) - assert qml.math.allclose(result[3], comparison_result[3]) + for res, comp in zip(result, comparison_result): + assert qml.math.allclose(res, comp) @pytest.mark.parametrize("autograph", [True, False]) @pytest.mark.parametrize("wire", [0, 1]) From 902e4f869298cc3e9228278267b52d6ea1ac4164 Mon Sep 17 00:00:00 2001 From: Pietropaolo Frisoni Date: Mon, 10 Feb 2025 15:35:05 -0500 Subject: [PATCH 33/33] Refactoring changelog with program capture entries --- doc/releases/changelog-dev.md | 75 +++++++++++++++++------------------ 1 file changed, 37 insertions(+), 38 deletions(-) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 4915da3955c..28223c99c1d 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -26,37 +26,6 @@ is greater than `0.4.28`. [(#6864)](https://github.com/PennyLaneAI/pennylane/pull/6864) -* Python control flow (`if/else`, `for`, `while`) is now supported when program capture is enabled by setting - `autograph=True` at the QNode level. - [(#6837)](https://github.com/PennyLaneAI/pennylane/pull/6837) - - ```python - qml.capture.enable() - - dev = qml.device("default.qubit", wires=[0, 1, 2]) - - @qml.qnode(dev, autograph=True) - def circuit(num_loops: int): - for i in range(num_loops): - if i % 2 == 0: - qml.H(i) - else: - qml.RX(1,i) - return qml.state() - ``` - - ```pycon - >>> print(qml.draw(circuit)(num_loops=3)) - 0: ──H────────┤ State - 1: ──RX(1.00)─┤ State - 2: ──H────────┤ State - >>> circuit(3) - Array([0.43879125+0.j , 0.43879125+0.j , - 0. -0.23971277j, 0. -0.23971277j, - 0.43879125+0.j , 0.43879125+0.j , - 0. -0.23971277j, 0. -0.23971277j], dtype=complex64) - ``` - * Added the `qml.workflow.construct_execution_config(qnode)(*args,**kwargs)` helper function. Users can now construct the execution configuration from a particular `QNode` instance. [(#6901)](https://github.com/PennyLaneAI/pennylane/pull/6901) @@ -147,19 +116,49 @@ * `null.qubit` can now execute jaxpr. [(#6924)](https://github.com/PennyLaneAI/pennylane/pull/6924) -* Autograph can now be used with custom operations defined outside of the pennylane namespace. - [(#6931)](https://github.com/PennyLaneAI/pennylane/pull/6931) - -

Capturing and representing hybrid programs

-* Add a `qml.capture.pause()` context manager for pausing program capture in an error-safe way. - [(#6911)](https://github.com/PennyLaneAI/pennylane/pull/6911) - * Implemented a `compute_plxpr_decomposition` method in the `qml.operation.Operator` class to apply dynamic decompositions with program capture enabled. [(#6859)](https://github.com/PennyLaneAI/pennylane/pull/6859) + * Autograph can now be used with custom operations defined outside of the pennylane namespace. + [(#6931)](https://github.com/PennyLaneAI/pennylane/pull/6931) + + * Add a `qml.capture.pause()` context manager for pausing program capture in an error-safe way. + [(#6911)](https://github.com/PennyLaneAI/pennylane/pull/6911) + +* Python control flow (`if/else`, `for`, `while`) is now supported when program capture is enabled by setting + `autograph=True` at the QNode level. + [(#6837)](https://github.com/PennyLaneAI/pennylane/pull/6837) + + ```python + qml.capture.enable() + + dev = qml.device("default.qubit", wires=[0, 1, 2]) + + @qml.qnode(dev, autograph=True) + def circuit(num_loops: int): + for i in range(num_loops): + if i % 2 == 0: + qml.H(i) + else: + qml.RX(1,i) + return qml.state() + ``` + + ```pycon + >>> print(qml.draw(circuit)(num_loops=3)) + 0: ──H────────┤ State + 1: ──RX(1.00)─┤ State + 2: ──H────────┤ State + >>> circuit(3) + Array([0.43879125+0.j , 0.43879125+0.j , + 0. -0.23971277j, 0. -0.23971277j, + 0.43879125+0.j , 0.43879125+0.j , + 0. -0.23971277j, 0. -0.23971277j], dtype=complex64) + ``` + * The higher order primitives in program capture can now accept inputs with abstract shapes. [(#6786)](https://github.com/PennyLaneAI/pennylane/pull/6786)