diff --git a/doc/development/plugins.rst b/doc/development/plugins.rst index 4a15f2c7519..086773d9bab 100644 --- a/doc/development/plugins.rst +++ b/doc/development/plugins.rst @@ -472,12 +472,15 @@ pieces of functionality: Note that these properties are only applicable to devices that provided derivatives or VJPs. If your device does not provide derivatives, you can safely ignore these properties. -The workflow options are ``use_device_gradient``, ``use_device_jacobian_product``, and ``grad_on_execution``. +The workflow options are ``use_device_gradient``, ``use_device_jacobian_product``, ``grad_on_execution``, +and ``convert_to_numpy``. ``use_device_gradient=True`` indicates that workflow should request derivatives from the device. ``grad_on_execution=True`` indicates a preference to use ``execute_and_compute_derivatives`` instead -of ``execute`` followed by ``compute_derivatives``. Finally, ``use_device_jacobian_product`` indicates +of ``execute`` followed by ``compute_derivatives``. ``use_device_jacobian_product`` indicates a request to call ``compute_vjp`` instead of ``compute_derivatives``. Note that if ``use_device_jacobian_product`` -is ``True``, this takes precedence over calculating the full jacobian. +is ``True``, this takes precedence over calculating the full jacobian. If the device can accept ML framework parameters, like +jax, ``convert_to_numpy=False`` should be specified. Then the parameters will not be converted, and special +interface-specific processing (like executing inside a ``jax.pure_callback`` when using ``jax.jit``) will be needed. >>> config = qml.devices.ExecutionConfig(gradient_method="adjoint") >>> processed_config = qml.device('default.qubit').setup_execution_config(config) @@ -487,6 +490,8 @@ True True >>> processed_config.grad_on_execution True +>>> processed_config.convert_to_numpy +True Execution --------- diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 1810651f858..df4ac71dd44 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -6,6 +6,12 @@

Improvements 🛠

+* Finite shot and parameter-shift executions on `default.qubit` can now + be natively jitted end-to-end, leading to performance improvements. + Devices can now configure whether or not ML framework data is sent to them + via an `ExecutionConfig.convert_to_numpy` parameter. + [(#6788)](https://github.com/PennyLaneAI/pennylane/pull/6788) + * The coefficients of observables now have improved differentiability. [(#6598)](https://github.com/PennyLaneAI/pennylane/pull/6598) diff --git a/pennylane/devices/default_mixed.py b/pennylane/devices/default_mixed.py index 9e18ba09249..d55f9eec623 100644 --- a/pennylane/devices/default_mixed.py +++ b/pennylane/devices/default_mixed.py @@ -1008,8 +1008,8 @@ def _setup_execution_config(self, execution_config: ExecutionConfig) -> Executio "best", } updated_values["grad_on_execution"] = False - if not execution_config.gradient_method in {"best", "backprop", None}: - execution_config.interface = None + if execution_config.gradient_method not in {"best", "backprop", None}: + updated_values["interface"] = None # Add device options updated_values["device_options"] = dict(execution_config.device_options) # copy diff --git a/pennylane/devices/default_qubit.py b/pennylane/devices/default_qubit.py index fcdb25d2783..48f8ffbc686 100644 --- a/pennylane/devices/default_qubit.py +++ b/pennylane/devices/default_qubit.py @@ -161,6 +161,14 @@ def _conditional_broastcast_expand(tape): return (tape,), null_postprocessing +@qml.transform +def no_counts(tape): + """Throws an error on counts measurements.""" + if any(isinstance(mp, qml.measurements.CountsMP) for mp in tape.measurements): + raise NotImplementedError("The JAX-JIT interface doesn't support qml.counts.") + return (tape,), null_postprocessing + + @qml.transform def adjoint_state_measurements( tape: QuantumScript, device_vjp=False @@ -535,6 +543,8 @@ def preprocess( config = self._setup_execution_config(execution_config) transform_program = TransformProgram() + if config.interface == qml.math.Interface.JAX_JIT: + transform_program.add_transform(no_counts) transform_program.add_transform(validate_device_wires, self.wires, name=self.name) transform_program.add_transform( mid_circuit_measurements, device=self, mcm_config=config.mcm_config @@ -581,6 +591,13 @@ def _setup_execution_config(self, execution_config: ExecutionConfig) -> Executio """ updated_values = {} + jax_interfaces = {qml.math.Interface.JAX, qml.math.Interface.JAX_JIT} + updated_values["convert_to_numpy"] = ( + execution_config.interface not in jax_interfaces + or execution_config.gradient_method == "adjoint" + # need numpy to use caching, and need caching higher order derivatives + or execution_config.derivative_order > 1 + ) for option in execution_config.device_options: if option not in self._device_options: raise qml.DeviceError(f"device option {option} not present on {self}") @@ -616,7 +633,6 @@ def execute( execution_config: ExecutionConfig = DefaultExecutionConfig, ) -> Union[Result, ResultBatch]: self.reset_prng_key() - max_workers = execution_config.device_options.get("max_workers", self._max_workers) self._state_cache = {} if execution_config.use_device_jacobian_product else None interface = ( diff --git a/pennylane/devices/execution_config.py b/pennylane/devices/execution_config.py index 2cf78cee394..17c9363b76f 100644 --- a/pennylane/devices/execution_config.py +++ b/pennylane/devices/execution_config.py @@ -17,7 +17,7 @@ from dataclasses import dataclass, field from typing import Optional, Union -from pennylane.math import get_canonical_interface_name +from pennylane.math import Interface, get_canonical_interface_name from pennylane.transforms.core import TransformDispatcher @@ -87,7 +87,7 @@ class ExecutionConfig: device_options: Optional[dict] = None """Various options for the device executing a quantum circuit""" - interface: Optional[str] = None + interface: Interface = Interface.NUMPY """The machine learning framework to use""" derivative_order: int = 1 @@ -96,6 +96,13 @@ class ExecutionConfig: mcm_config: MCMConfig = field(default_factory=MCMConfig) """Configuration options for handling mid-circuit measurements""" + convert_to_numpy: bool = True + """Whether or not to convert parameters to numpy before execution. + + If ``False`` and using the jax-jit, no pure callback will occur and the device + execution itself will be jitted. + """ + def __post_init__(self): """ Validate the configured execution options. @@ -124,7 +131,7 @@ def __post_init__(self): ) if isinstance(self.mcm_config, dict): - self.mcm_config = MCMConfig(**self.mcm_config) + self.mcm_config = MCMConfig(**self.mcm_config) # pylint: disable=not-a-mapping elif not isinstance(self.mcm_config, MCMConfig): raise ValueError(f"Got invalid type {type(self.mcm_config)} for 'mcm_config'") diff --git a/pennylane/devices/qubit/sampling.py b/pennylane/devices/qubit/sampling.py index 06ae78b5708..527e3296a5c 100644 --- a/pennylane/devices/qubit/sampling.py +++ b/pennylane/devices/qubit/sampling.py @@ -580,6 +580,6 @@ def _sample_probs_jax(probs, shots, num_wires, is_state_batched, prng_key=None, _, key = jax_random_split(prng_key) samples = jax.random.choice(key, basis_states, shape=(shots,), p=probs) - powers_of_two = 1 << jnp.arange(num_wires, dtype=jnp.int64)[::-1] + powers_of_two = 1 << jnp.arange(num_wires, dtype=int)[::-1] states_sampled_base_ten = samples[..., None] & powers_of_two - return (states_sampled_base_ten > 0).astype(jnp.int64) + return (states_sampled_base_ten > 0).astype(int) diff --git a/pennylane/workflow/_setup_transform_program.py b/pennylane/workflow/_setup_transform_program.py index 5866c8f7bf4..67cfab372b9 100644 --- a/pennylane/workflow/_setup_transform_program.py +++ b/pennylane/workflow/_setup_transform_program.py @@ -117,7 +117,8 @@ def _setup_transform_program( # changing this set of conditions causes a bunch of tests to break. interface_data_supported = ( - resolved_execution_config.interface is Interface.NUMPY + (not resolved_execution_config.convert_to_numpy) + or resolved_execution_config.interface is Interface.NUMPY or resolved_execution_config.gradient_method == "backprop" or ( getattr(device, "short_name", "") == "default.mixed" diff --git a/pennylane/workflow/interfaces/jax.py b/pennylane/workflow/interfaces/jax.py index aaa2c55dd1e..a8f5b0bea32 100644 --- a/pennylane/workflow/interfaces/jax.py +++ b/pennylane/workflow/interfaces/jax.py @@ -186,7 +186,7 @@ def _to_jax(result: qml.typing.ResultBatch) -> qml.typing.ResultBatch: return result if isinstance(result, (list, tuple)): return tuple(_to_jax(r) for r in result) - return jnp.array(result) + return result if qml.math.get_interface(result) == "jax" else jnp.array(result) def _execute_wrapper(params, tapes, execute_fn, jpc) -> ResultBatch: diff --git a/pennylane/workflow/interfaces/jax_jit.py b/pennylane/workflow/interfaces/jax_jit.py index 3cd5779a5de..296afc4408f 100644 --- a/pennylane/workflow/interfaces/jax_jit.py +++ b/pennylane/workflow/interfaces/jax_jit.py @@ -59,7 +59,7 @@ def _to_jax(result: qml.typing.ResultBatch) -> qml.typing.ResultBatch: """ if isinstance(result, dict): - return {key: jnp.array(value) for key, value in result.items()} + return {key: _to_jax(value) for key, value in result.items()} if isinstance(result, (list, tuple)): return tuple(_to_jax(r) for r in result) return jnp.array(result) diff --git a/pennylane/workflow/run.py b/pennylane/workflow/run.py index 456d04f1ad7..9f1096ab32b 100644 --- a/pennylane/workflow/run.py +++ b/pennylane/workflow/run.py @@ -204,7 +204,7 @@ def _get_ml_boundary_execute( elif interface == Interface.TORCH: from .interfaces.torch import execute as ml_boundary - elif interface == Interface.JAX_JIT: + elif interface == Interface.JAX_JIT and resolved_execution_config.convert_to_numpy: from .interfaces.jax_jit import jax_jit_jvp_execute as ml_boundary else: # interface is jax diff --git a/tests/devices/default_qubit/test_default_qubit_native_mcm.py b/tests/devices/default_qubit/test_default_qubit_native_mcm.py index dc6dd8fb67e..42faaa10809 100644 --- a/tests/devices/default_qubit/test_default_qubit_native_mcm.py +++ b/tests/devices/default_qubit/test_default_qubit_native_mcm.py @@ -389,11 +389,7 @@ def func(x, y, z): results1 = func1(*params) jaxpr = str(jax.make_jaxpr(func)(*params)) - if diff_method == "best": - assert "pure_callback" in jaxpr - pytest.xfail("QNode with diff_method='best' cannot be compiled with jax.jit.") - else: - assert "pure_callback" not in jaxpr + assert "pure_callback" not in jaxpr func2 = jax.jit(func) results2 = func2(*params) diff --git a/tests/devices/default_qubit/test_default_qubit_preprocessing.py b/tests/devices/default_qubit/test_default_qubit_preprocessing.py index 40f45a9383e..59a9098f7ce 100644 --- a/tests/devices/default_qubit/test_default_qubit_preprocessing.py +++ b/tests/devices/default_qubit/test_default_qubit_preprocessing.py @@ -141,6 +141,32 @@ def circuit(x): assert dev.tracker.totals["execute_and_derivative_batches"] == 1 + @pytest.mark.parametrize("interface", ("jax", "jax-jit")) + def test_not_convert_to_numpy_with_jax(self, interface): + """Test that we will not convert to numpy when working with jax.""" + + dev = qml.device("default.qubit") + config = qml.devices.ExecutionConfig( + gradient_method=qml.gradients.param_shift, interface=interface + ) + processed = dev.setup_execution_config(config) + assert not processed.convert_to_numpy + + def test_convert_to_numpy_with_adjoint(self): + """Test that we will convert to numpy with adjoint.""" + config = qml.devices.ExecutionConfig(gradient_method="adjoint", interface="jax-jit") + dev = qml.device("default.qubit") + processed = dev.setup_execution_config(config) + assert processed.convert_to_numpy + + @pytest.mark.parametrize("interface", ("autograd", "torch", "tf")) + def test_convert_to_numpy_non_jax(self, interface): + """Test that other interfaces are still converted to numpy.""" + config = qml.devices.ExecutionConfig(gradient_method="adjoint", interface=interface) + dev = qml.device("default.qubit") + processed = dev.setup_execution_config(config) + assert processed.convert_to_numpy + # pylint: disable=too-few-public-methods class TestPreprocessing: diff --git a/tests/gradients/core/test_pulse_gradient.py b/tests/gradients/core/test_pulse_gradient.py index 5fd9bf34937..d0aba1f4582 100644 --- a/tests/gradients/core/test_pulse_gradient.py +++ b/tests/gradients/core/test_pulse_gradient.py @@ -1485,7 +1485,6 @@ def circuit(params): assert qml.math.allclose(j[0], e, atol=tol, rtol=0.0) jax.clear_caches() - @pytest.mark.xfail @pytest.mark.parametrize("num_split_times", [1, 2]) @pytest.mark.parametrize("time_interface", ["python", "numpy", "jax"]) def test_simple_qnode_jit(self, num_split_times, time_interface): diff --git a/tests/param_shift_dev.py b/tests/param_shift_dev.py index 12fd11eea16..7c7442161e2 100644 --- a/tests/param_shift_dev.py +++ b/tests/param_shift_dev.py @@ -34,6 +34,7 @@ def preprocess(self, execution_config=qml.devices.DefaultExecutionConfig): execution_config, use_device_jacobian_product=True ) program, config = super().preprocess(execution_config) + config = dataclasses.replace(config, convert_to_numpy=True) program.add_transform(qml.transform(qml.gradients.param_shift.expand_transform)) return program, config diff --git a/tests/templates/test_state_preparations/test_mottonen_state_prep.py b/tests/templates/test_state_preparations/test_mottonen_state_prep.py index 1a9f0aa6219..885def86e60 100644 --- a/tests/templates/test_state_preparations/test_mottonen_state_prep.py +++ b/tests/templates/test_state_preparations/test_mottonen_state_prep.py @@ -417,11 +417,13 @@ def circuit(state): @pytest.mark.jax -@pytest.mark.parametrize("shots, atol", [(None, 0.005), (1000000, 0.05)]) -def test_jacobians_with_and_without_jit_match(shots, atol, seed): +def test_jacobians_with_and_without_jit_match(seed): """Test that the Jacobian of the circuit is the same with and without jit.""" import jax + shots = None + atol = 0.005 + dev = qml.device("default.qubit", shots=shots, seed=seed) dev_no_shots = qml.device("default.qubit", shots=None) @@ -433,7 +435,7 @@ def circuit(coeffs): circuit_ps = qml.QNode(circuit, dev, diff_method="parameter-shift") circuit_exact = qml.QNode(circuit, dev_no_shots) - params = jax.numpy.array([0.5, 0.5, 0.5, 0.5]) + params = jax.numpy.array([0.5, 0.5, 0.5, 0.5], dtype=jax.numpy.float64) jac_exact_fn = jax.jacobian(circuit_exact) jac_fd_fn = jax.jacobian(circuit_fd) jac_fd_fn_jit = jax.jit(jac_fd_fn) diff --git a/tests/test_qnode.py b/tests/test_qnode.py index d79d162d5a4..e5a92989e7d 100644 --- a/tests/test_qnode.py +++ b/tests/test_qnode.py @@ -613,12 +613,13 @@ def func(x, y): assert tape.measurements == contents[3:] @pytest.mark.jax - def test_jit_counts_raises_error(self): + @pytest.mark.parametrize("dev_name", ("default.qubit", "reference.qubit")) + def test_jit_counts_raises_error(self, dev_name): """Test that returning counts in a quantum function with trainable parameters while jitting raises an error.""" import jax - dev = qml.device("default.qubit", wires=2, shots=5) + dev = qml.device(dev_name, wires=2, shots=5) def circuit1(param): qml.Hadamard(0) @@ -632,7 +633,7 @@ def circuit1(param): with pytest.raises( NotImplementedError, match="The JAX-JIT interface doesn't support qml.counts." ): - jitted_qnode1(0.123) + _ = jitted_qnode1(0.123) # Test with qnode decorator syntax @qml.qnode(dev) diff --git a/tests/workflow/interfaces/execute/test_jax.py b/tests/workflow/interfaces/execute/test_jax.py index 7a7c9468ca5..db056dcd0c0 100644 --- a/tests/workflow/interfaces/execute/test_jax.py +++ b/tests/workflow/interfaces/execute/test_jax.py @@ -693,10 +693,12 @@ def test_max_diff(self, tol): def cost_fn(x): ops = [qml.RX(x[0], 0), qml.RY(x[1], 1), qml.CNOT((0, 1))] - tape1 = qml.tape.QuantumScript(ops, [qml.var(qml.PauliZ(0) @ qml.PauliX(1))]) + tape1 = qml.tape.QuantumScript( + ops, [qml.var(qml.PauliZ(0) @ qml.PauliX(1))], shots=50000 + ) ops2 = [qml.RX(x[0], 0), qml.RY(x[0], 1), qml.CNOT((0, 1))] - tape2 = qml.tape.QuantumScript(ops2, [qml.probs(wires=1)]) + tape2 = qml.tape.QuantumScript(ops2, [qml.probs(wires=1)], shots=50000) result = execute([tape1, tape2], dev, diff_method=param_shift, max_diff=1) return result[0] + result[1][0] @@ -704,13 +706,13 @@ def cost_fn(x): res = cost_fn(params) x, y = params expected = 0.5 * (3 + jnp.cos(x) ** 2 * jnp.cos(2 * y)) - assert np.allclose(res, expected, atol=tol, rtol=0) + assert np.allclose(res, expected, atol=2e-2, rtol=0) res = jax.grad(cost_fn)(params) expected = jnp.array( [-jnp.cos(x) * jnp.cos(2 * y) * jnp.sin(x), -jnp.cos(x) ** 2 * jnp.sin(2 * y)] ) - assert np.allclose(res, expected, atol=tol, rtol=0) + assert np.allclose(res, expected, atol=2e-2, rtol=0) res = jax.jacobian(jax.grad(cost_fn))(params) expected = jnp.zeros([2, 2]) diff --git a/tests/workflow/interfaces/execute/test_jax_jit.py b/tests/workflow/interfaces/execute/test_jax_jit.py index ce6a9ef27f4..95b056cf476 100644 --- a/tests/workflow/interfaces/execute/test_jax_jit.py +++ b/tests/workflow/interfaces/execute/test_jax_jit.py @@ -886,14 +886,17 @@ def cost(x, y, device, interface, ek): class TestJitAllCounts: + @pytest.mark.parametrize( + "device_name", (pytest.param("default.qubit", marks=pytest.mark.xfail), "reference.qubit") + ) @pytest.mark.parametrize("counts_wires", (None, (0, 1))) - def test_jit_allcounts(self, counts_wires): + def test_jit_allcounts(self, device_name, counts_wires): """Test jitting with counts with all_outcomes == True.""" tape = qml.tape.QuantumScript( [qml.RX(0, 0), qml.I(1)], [qml.counts(wires=counts_wires, all_outcomes=True)], shots=50 ) - device = qml.device("default.qubit") + device = qml.device(device_name, wires=2) res = jax.jit(qml.execute, static_argnums=(1, 2))( (tape,), device, qml.gradients.param_shift @@ -904,7 +907,14 @@ def test_jit_allcounts(self, counts_wires): for val in ["01", "10", "11"]: assert qml.math.allclose(res[val], 0) - def test_jit_allcounts_broadcasting(self): + @pytest.mark.parametrize( + "device_name", + ( + pytest.param("default.qubit", marks=pytest.mark.xfail), + pytest.param("reference.qubit", marks=pytest.mark.xfail), + ), + ) + def test_jit_allcounts_broadcasting(self, device_name): """Test jitting with counts with all_outcomes == True.""" tape = qml.tape.QuantumScript( @@ -912,7 +922,7 @@ def test_jit_allcounts_broadcasting(self): [qml.counts(wires=(0, 1), all_outcomes=True)], shots=50, ) - device = qml.device("default.qubit") + device = qml.device(device_name, wires=2) res = jax.jit(qml.execute, static_argnums=(1, 2))( (tape,), device, qml.gradients.param_shift @@ -927,7 +937,6 @@ def test_jit_allcounts_broadcasting(self): assert qml.math.allclose(ri[val], 0) -@pytest.mark.xfail(reason="Need to figure out how to handle this case in a less ambiguous manner") def test_diff_method_None_jit(): """Test that jitted execution works when `diff_method=None`.""" diff --git a/tests/workflow/interfaces/qnode/test_jax_jit_qnode.py b/tests/workflow/interfaces/qnode/test_jax_jit_qnode.py index 544da89f73d..862806999c6 100644 --- a/tests/workflow/interfaces/qnode/test_jax_jit_qnode.py +++ b/tests/workflow/interfaces/qnode/test_jax_jit_qnode.py @@ -3185,7 +3185,8 @@ def test_complex64_return(self, diff_method): jax.config.update("jax_enable_x64", False) try: - tol = 2e-2 if diff_method == "finite-diff" else 1e-6 + # finite diff with float32 ... + tol = 5e-2 if diff_method == "finite-diff" else 1e-6 @jax.jit @qml.qnode(qml.device("default.qubit", wires=1), diff_method=diff_method) diff --git a/tests/workflow/interfaces/test_jacobian_products.py b/tests/workflow/interfaces/test_jacobian_products.py index 9a090d3af0a..8a3d6745686 100644 --- a/tests/workflow/interfaces/test_jacobian_products.py +++ b/tests/workflow/interfaces/test_jacobian_products.py @@ -136,7 +136,7 @@ def test_device_jacobians_repr(self): r" use_device_jacobian_product=None," r" gradient_method='adjoint', gradient_keyword_arguments={}," r" device_options={}, interface=, derivative_order=1," - r" mcm_config=MCMConfig(mcm_method=None, postselect_mode=None))>" + r" mcm_config=MCMConfig(mcm_method=None, postselect_mode=None), convert_to_numpy=True)>" ) assert repr(jpc) == expected @@ -155,7 +155,7 @@ def test_device_jacobian_products_repr(self): r" use_device_jacobian_product=None," r" gradient_method='adjoint', gradient_keyword_arguments={}, device_options={}," r" interface=, derivative_order=1," - r" mcm_config=MCMConfig(mcm_method=None, postselect_mode=None))>" + r" mcm_config=MCMConfig(mcm_method=None, postselect_mode=None), convert_to_numpy=True)>" ) assert repr(jpc) == expected diff --git a/tests/workflow/test_setup_transform_program.py b/tests/workflow/test_setup_transform_program.py index c81ed75ae2f..79207491014 100644 --- a/tests/workflow/test_setup_transform_program.py +++ b/tests/workflow/test_setup_transform_program.py @@ -140,9 +140,7 @@ def test_prune_dynamic_transform_warning_raised(): def test_interface_data_not_supported(): """Test that convert_to_numpy_parameters transform is correctly added.""" - config = ExecutionConfig() - config.interface = "autograd" - config.gradient_method = "adjoint" + config = ExecutionConfig(interface="autograd", gradient_method="adjoint") device = qml.device("default.qubit") user_transform_program = TransformProgram() @@ -154,10 +152,8 @@ def test_interface_data_not_supported(): def test_interface_data_supported(): """Test that convert_to_numpy_parameters transform is not added for these cases.""" - config = ExecutionConfig() + config = ExecutionConfig(interface="autograd", gradient_method=None) - config.interface = "autograd" - config.gradient_method = None device = qml.device("default.mixed", wires=1) user_transform_program = TransformProgram() @@ -165,10 +161,8 @@ def test_interface_data_supported(): assert qml.transforms.convert_to_numpy_parameters not in inner_tp - config = ExecutionConfig() + config = ExecutionConfig(interface="autograd", gradient_method="backprop") - config.interface = "autograd" - config.gradient_method = "backprop" device = qml.device("default.qubit") user_transform_program = TransformProgram() @@ -176,10 +170,8 @@ def test_interface_data_supported(): assert qml.transforms.convert_to_numpy_parameters not in inner_tp - config = ExecutionConfig() + config = ExecutionConfig(interface=None, gradient_method="backprop") - config.interface = None - config.gradient_method = "backprop" device = qml.device("default.qubit") user_transform_program = TransformProgram() @@ -187,6 +179,13 @@ def test_interface_data_supported(): assert qml.transforms.convert_to_numpy_parameters not in inner_tp + config = ExecutionConfig( + convert_to_numpy=False, interface="jax", gradient_method=qml.gradients.param_shift + ) + + _, inner_tp = _setup_transform_program(TransformProgram(), device, config) + assert qml.transforms.convert_to_numpy_parameters not in inner_tp + def test_cache_handling(): """Test that caching is handled correctly."""