diff --git a/doc/development/plugins.rst b/doc/development/plugins.rst
index 4a15f2c7519..086773d9bab 100644
--- a/doc/development/plugins.rst
+++ b/doc/development/plugins.rst
@@ -472,12 +472,15 @@ pieces of functionality:
Note that these properties are only applicable to devices that provided derivatives or VJPs. If your device
does not provide derivatives, you can safely ignore these properties.
-The workflow options are ``use_device_gradient``, ``use_device_jacobian_product``, and ``grad_on_execution``.
+The workflow options are ``use_device_gradient``, ``use_device_jacobian_product``, ``grad_on_execution``,
+and ``convert_to_numpy``.
``use_device_gradient=True`` indicates that workflow should request derivatives from the device.
``grad_on_execution=True`` indicates a preference to use ``execute_and_compute_derivatives`` instead
-of ``execute`` followed by ``compute_derivatives``. Finally, ``use_device_jacobian_product`` indicates
+of ``execute`` followed by ``compute_derivatives``. ``use_device_jacobian_product`` indicates
a request to call ``compute_vjp`` instead of ``compute_derivatives``. Note that if ``use_device_jacobian_product``
-is ``True``, this takes precedence over calculating the full jacobian.
+is ``True``, this takes precedence over calculating the full jacobian. If the device can accept ML framework parameters, like
+jax, ``convert_to_numpy=False`` should be specified. Then the parameters will not be converted, and special
+interface-specific processing (like executing inside a ``jax.pure_callback`` when using ``jax.jit``) will be needed.
>>> config = qml.devices.ExecutionConfig(gradient_method="adjoint")
>>> processed_config = qml.device('default.qubit').setup_execution_config(config)
@@ -487,6 +490,8 @@ True
True
>>> processed_config.grad_on_execution
True
+>>> processed_config.convert_to_numpy
+True
Execution
---------
diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md
index 1810651f858..df4ac71dd44 100644
--- a/doc/releases/changelog-dev.md
+++ b/doc/releases/changelog-dev.md
@@ -6,6 +6,12 @@
Improvements ðŸ›
+* Finite shot and parameter-shift executions on `default.qubit` can now
+ be natively jitted end-to-end, leading to performance improvements.
+ Devices can now configure whether or not ML framework data is sent to them
+ via an `ExecutionConfig.convert_to_numpy` parameter.
+ [(#6788)](https://github.com/PennyLaneAI/pennylane/pull/6788)
+
* The coefficients of observables now have improved differentiability.
[(#6598)](https://github.com/PennyLaneAI/pennylane/pull/6598)
diff --git a/pennylane/devices/default_mixed.py b/pennylane/devices/default_mixed.py
index 9e18ba09249..d55f9eec623 100644
--- a/pennylane/devices/default_mixed.py
+++ b/pennylane/devices/default_mixed.py
@@ -1008,8 +1008,8 @@ def _setup_execution_config(self, execution_config: ExecutionConfig) -> Executio
"best",
}
updated_values["grad_on_execution"] = False
- if not execution_config.gradient_method in {"best", "backprop", None}:
- execution_config.interface = None
+ if execution_config.gradient_method not in {"best", "backprop", None}:
+ updated_values["interface"] = None
# Add device options
updated_values["device_options"] = dict(execution_config.device_options) # copy
diff --git a/pennylane/devices/default_qubit.py b/pennylane/devices/default_qubit.py
index fcdb25d2783..48f8ffbc686 100644
--- a/pennylane/devices/default_qubit.py
+++ b/pennylane/devices/default_qubit.py
@@ -161,6 +161,14 @@ def _conditional_broastcast_expand(tape):
return (tape,), null_postprocessing
+@qml.transform
+def no_counts(tape):
+ """Throws an error on counts measurements."""
+ if any(isinstance(mp, qml.measurements.CountsMP) for mp in tape.measurements):
+ raise NotImplementedError("The JAX-JIT interface doesn't support qml.counts.")
+ return (tape,), null_postprocessing
+
+
@qml.transform
def adjoint_state_measurements(
tape: QuantumScript, device_vjp=False
@@ -535,6 +543,8 @@ def preprocess(
config = self._setup_execution_config(execution_config)
transform_program = TransformProgram()
+ if config.interface == qml.math.Interface.JAX_JIT:
+ transform_program.add_transform(no_counts)
transform_program.add_transform(validate_device_wires, self.wires, name=self.name)
transform_program.add_transform(
mid_circuit_measurements, device=self, mcm_config=config.mcm_config
@@ -581,6 +591,13 @@ def _setup_execution_config(self, execution_config: ExecutionConfig) -> Executio
"""
updated_values = {}
+ jax_interfaces = {qml.math.Interface.JAX, qml.math.Interface.JAX_JIT}
+ updated_values["convert_to_numpy"] = (
+ execution_config.interface not in jax_interfaces
+ or execution_config.gradient_method == "adjoint"
+ # need numpy to use caching, and need caching higher order derivatives
+ or execution_config.derivative_order > 1
+ )
for option in execution_config.device_options:
if option not in self._device_options:
raise qml.DeviceError(f"device option {option} not present on {self}")
@@ -616,7 +633,6 @@ def execute(
execution_config: ExecutionConfig = DefaultExecutionConfig,
) -> Union[Result, ResultBatch]:
self.reset_prng_key()
-
max_workers = execution_config.device_options.get("max_workers", self._max_workers)
self._state_cache = {} if execution_config.use_device_jacobian_product else None
interface = (
diff --git a/pennylane/devices/execution_config.py b/pennylane/devices/execution_config.py
index 2cf78cee394..17c9363b76f 100644
--- a/pennylane/devices/execution_config.py
+++ b/pennylane/devices/execution_config.py
@@ -17,7 +17,7 @@
from dataclasses import dataclass, field
from typing import Optional, Union
-from pennylane.math import get_canonical_interface_name
+from pennylane.math import Interface, get_canonical_interface_name
from pennylane.transforms.core import TransformDispatcher
@@ -87,7 +87,7 @@ class ExecutionConfig:
device_options: Optional[dict] = None
"""Various options for the device executing a quantum circuit"""
- interface: Optional[str] = None
+ interface: Interface = Interface.NUMPY
"""The machine learning framework to use"""
derivative_order: int = 1
@@ -96,6 +96,13 @@ class ExecutionConfig:
mcm_config: MCMConfig = field(default_factory=MCMConfig)
"""Configuration options for handling mid-circuit measurements"""
+ convert_to_numpy: bool = True
+ """Whether or not to convert parameters to numpy before execution.
+
+ If ``False`` and using the jax-jit, no pure callback will occur and the device
+ execution itself will be jitted.
+ """
+
def __post_init__(self):
"""
Validate the configured execution options.
@@ -124,7 +131,7 @@ def __post_init__(self):
)
if isinstance(self.mcm_config, dict):
- self.mcm_config = MCMConfig(**self.mcm_config)
+ self.mcm_config = MCMConfig(**self.mcm_config) # pylint: disable=not-a-mapping
elif not isinstance(self.mcm_config, MCMConfig):
raise ValueError(f"Got invalid type {type(self.mcm_config)} for 'mcm_config'")
diff --git a/pennylane/devices/qubit/sampling.py b/pennylane/devices/qubit/sampling.py
index 06ae78b5708..527e3296a5c 100644
--- a/pennylane/devices/qubit/sampling.py
+++ b/pennylane/devices/qubit/sampling.py
@@ -580,6 +580,6 @@ def _sample_probs_jax(probs, shots, num_wires, is_state_batched, prng_key=None,
_, key = jax_random_split(prng_key)
samples = jax.random.choice(key, basis_states, shape=(shots,), p=probs)
- powers_of_two = 1 << jnp.arange(num_wires, dtype=jnp.int64)[::-1]
+ powers_of_two = 1 << jnp.arange(num_wires, dtype=int)[::-1]
states_sampled_base_ten = samples[..., None] & powers_of_two
- return (states_sampled_base_ten > 0).astype(jnp.int64)
+ return (states_sampled_base_ten > 0).astype(int)
diff --git a/pennylane/workflow/_setup_transform_program.py b/pennylane/workflow/_setup_transform_program.py
index 5866c8f7bf4..67cfab372b9 100644
--- a/pennylane/workflow/_setup_transform_program.py
+++ b/pennylane/workflow/_setup_transform_program.py
@@ -117,7 +117,8 @@ def _setup_transform_program(
# changing this set of conditions causes a bunch of tests to break.
interface_data_supported = (
- resolved_execution_config.interface is Interface.NUMPY
+ (not resolved_execution_config.convert_to_numpy)
+ or resolved_execution_config.interface is Interface.NUMPY
or resolved_execution_config.gradient_method == "backprop"
or (
getattr(device, "short_name", "") == "default.mixed"
diff --git a/pennylane/workflow/interfaces/jax.py b/pennylane/workflow/interfaces/jax.py
index aaa2c55dd1e..a8f5b0bea32 100644
--- a/pennylane/workflow/interfaces/jax.py
+++ b/pennylane/workflow/interfaces/jax.py
@@ -186,7 +186,7 @@ def _to_jax(result: qml.typing.ResultBatch) -> qml.typing.ResultBatch:
return result
if isinstance(result, (list, tuple)):
return tuple(_to_jax(r) for r in result)
- return jnp.array(result)
+ return result if qml.math.get_interface(result) == "jax" else jnp.array(result)
def _execute_wrapper(params, tapes, execute_fn, jpc) -> ResultBatch:
diff --git a/pennylane/workflow/interfaces/jax_jit.py b/pennylane/workflow/interfaces/jax_jit.py
index 3cd5779a5de..296afc4408f 100644
--- a/pennylane/workflow/interfaces/jax_jit.py
+++ b/pennylane/workflow/interfaces/jax_jit.py
@@ -59,7 +59,7 @@ def _to_jax(result: qml.typing.ResultBatch) -> qml.typing.ResultBatch:
"""
if isinstance(result, dict):
- return {key: jnp.array(value) for key, value in result.items()}
+ return {key: _to_jax(value) for key, value in result.items()}
if isinstance(result, (list, tuple)):
return tuple(_to_jax(r) for r in result)
return jnp.array(result)
diff --git a/pennylane/workflow/run.py b/pennylane/workflow/run.py
index 456d04f1ad7..9f1096ab32b 100644
--- a/pennylane/workflow/run.py
+++ b/pennylane/workflow/run.py
@@ -204,7 +204,7 @@ def _get_ml_boundary_execute(
elif interface == Interface.TORCH:
from .interfaces.torch import execute as ml_boundary
- elif interface == Interface.JAX_JIT:
+ elif interface == Interface.JAX_JIT and resolved_execution_config.convert_to_numpy:
from .interfaces.jax_jit import jax_jit_jvp_execute as ml_boundary
else: # interface is jax
diff --git a/tests/devices/default_qubit/test_default_qubit_native_mcm.py b/tests/devices/default_qubit/test_default_qubit_native_mcm.py
index dc6dd8fb67e..42faaa10809 100644
--- a/tests/devices/default_qubit/test_default_qubit_native_mcm.py
+++ b/tests/devices/default_qubit/test_default_qubit_native_mcm.py
@@ -389,11 +389,7 @@ def func(x, y, z):
results1 = func1(*params)
jaxpr = str(jax.make_jaxpr(func)(*params))
- if diff_method == "best":
- assert "pure_callback" in jaxpr
- pytest.xfail("QNode with diff_method='best' cannot be compiled with jax.jit.")
- else:
- assert "pure_callback" not in jaxpr
+ assert "pure_callback" not in jaxpr
func2 = jax.jit(func)
results2 = func2(*params)
diff --git a/tests/devices/default_qubit/test_default_qubit_preprocessing.py b/tests/devices/default_qubit/test_default_qubit_preprocessing.py
index 40f45a9383e..59a9098f7ce 100644
--- a/tests/devices/default_qubit/test_default_qubit_preprocessing.py
+++ b/tests/devices/default_qubit/test_default_qubit_preprocessing.py
@@ -141,6 +141,32 @@ def circuit(x):
assert dev.tracker.totals["execute_and_derivative_batches"] == 1
+ @pytest.mark.parametrize("interface", ("jax", "jax-jit"))
+ def test_not_convert_to_numpy_with_jax(self, interface):
+ """Test that we will not convert to numpy when working with jax."""
+
+ dev = qml.device("default.qubit")
+ config = qml.devices.ExecutionConfig(
+ gradient_method=qml.gradients.param_shift, interface=interface
+ )
+ processed = dev.setup_execution_config(config)
+ assert not processed.convert_to_numpy
+
+ def test_convert_to_numpy_with_adjoint(self):
+ """Test that we will convert to numpy with adjoint."""
+ config = qml.devices.ExecutionConfig(gradient_method="adjoint", interface="jax-jit")
+ dev = qml.device("default.qubit")
+ processed = dev.setup_execution_config(config)
+ assert processed.convert_to_numpy
+
+ @pytest.mark.parametrize("interface", ("autograd", "torch", "tf"))
+ def test_convert_to_numpy_non_jax(self, interface):
+ """Test that other interfaces are still converted to numpy."""
+ config = qml.devices.ExecutionConfig(gradient_method="adjoint", interface=interface)
+ dev = qml.device("default.qubit")
+ processed = dev.setup_execution_config(config)
+ assert processed.convert_to_numpy
+
# pylint: disable=too-few-public-methods
class TestPreprocessing:
diff --git a/tests/gradients/core/test_pulse_gradient.py b/tests/gradients/core/test_pulse_gradient.py
index 5fd9bf34937..d0aba1f4582 100644
--- a/tests/gradients/core/test_pulse_gradient.py
+++ b/tests/gradients/core/test_pulse_gradient.py
@@ -1485,7 +1485,6 @@ def circuit(params):
assert qml.math.allclose(j[0], e, atol=tol, rtol=0.0)
jax.clear_caches()
- @pytest.mark.xfail
@pytest.mark.parametrize("num_split_times", [1, 2])
@pytest.mark.parametrize("time_interface", ["python", "numpy", "jax"])
def test_simple_qnode_jit(self, num_split_times, time_interface):
diff --git a/tests/param_shift_dev.py b/tests/param_shift_dev.py
index 12fd11eea16..7c7442161e2 100644
--- a/tests/param_shift_dev.py
+++ b/tests/param_shift_dev.py
@@ -34,6 +34,7 @@ def preprocess(self, execution_config=qml.devices.DefaultExecutionConfig):
execution_config, use_device_jacobian_product=True
)
program, config = super().preprocess(execution_config)
+ config = dataclasses.replace(config, convert_to_numpy=True)
program.add_transform(qml.transform(qml.gradients.param_shift.expand_transform))
return program, config
diff --git a/tests/templates/test_state_preparations/test_mottonen_state_prep.py b/tests/templates/test_state_preparations/test_mottonen_state_prep.py
index 1a9f0aa6219..885def86e60 100644
--- a/tests/templates/test_state_preparations/test_mottonen_state_prep.py
+++ b/tests/templates/test_state_preparations/test_mottonen_state_prep.py
@@ -417,11 +417,13 @@ def circuit(state):
@pytest.mark.jax
-@pytest.mark.parametrize("shots, atol", [(None, 0.005), (1000000, 0.05)])
-def test_jacobians_with_and_without_jit_match(shots, atol, seed):
+def test_jacobians_with_and_without_jit_match(seed):
"""Test that the Jacobian of the circuit is the same with and without jit."""
import jax
+ shots = None
+ atol = 0.005
+
dev = qml.device("default.qubit", shots=shots, seed=seed)
dev_no_shots = qml.device("default.qubit", shots=None)
@@ -433,7 +435,7 @@ def circuit(coeffs):
circuit_ps = qml.QNode(circuit, dev, diff_method="parameter-shift")
circuit_exact = qml.QNode(circuit, dev_no_shots)
- params = jax.numpy.array([0.5, 0.5, 0.5, 0.5])
+ params = jax.numpy.array([0.5, 0.5, 0.5, 0.5], dtype=jax.numpy.float64)
jac_exact_fn = jax.jacobian(circuit_exact)
jac_fd_fn = jax.jacobian(circuit_fd)
jac_fd_fn_jit = jax.jit(jac_fd_fn)
diff --git a/tests/test_qnode.py b/tests/test_qnode.py
index d79d162d5a4..e5a92989e7d 100644
--- a/tests/test_qnode.py
+++ b/tests/test_qnode.py
@@ -613,12 +613,13 @@ def func(x, y):
assert tape.measurements == contents[3:]
@pytest.mark.jax
- def test_jit_counts_raises_error(self):
+ @pytest.mark.parametrize("dev_name", ("default.qubit", "reference.qubit"))
+ def test_jit_counts_raises_error(self, dev_name):
"""Test that returning counts in a quantum function with trainable parameters while
jitting raises an error."""
import jax
- dev = qml.device("default.qubit", wires=2, shots=5)
+ dev = qml.device(dev_name, wires=2, shots=5)
def circuit1(param):
qml.Hadamard(0)
@@ -632,7 +633,7 @@ def circuit1(param):
with pytest.raises(
NotImplementedError, match="The JAX-JIT interface doesn't support qml.counts."
):
- jitted_qnode1(0.123)
+ _ = jitted_qnode1(0.123)
# Test with qnode decorator syntax
@qml.qnode(dev)
diff --git a/tests/workflow/interfaces/execute/test_jax.py b/tests/workflow/interfaces/execute/test_jax.py
index 7a7c9468ca5..db056dcd0c0 100644
--- a/tests/workflow/interfaces/execute/test_jax.py
+++ b/tests/workflow/interfaces/execute/test_jax.py
@@ -693,10 +693,12 @@ def test_max_diff(self, tol):
def cost_fn(x):
ops = [qml.RX(x[0], 0), qml.RY(x[1], 1), qml.CNOT((0, 1))]
- tape1 = qml.tape.QuantumScript(ops, [qml.var(qml.PauliZ(0) @ qml.PauliX(1))])
+ tape1 = qml.tape.QuantumScript(
+ ops, [qml.var(qml.PauliZ(0) @ qml.PauliX(1))], shots=50000
+ )
ops2 = [qml.RX(x[0], 0), qml.RY(x[0], 1), qml.CNOT((0, 1))]
- tape2 = qml.tape.QuantumScript(ops2, [qml.probs(wires=1)])
+ tape2 = qml.tape.QuantumScript(ops2, [qml.probs(wires=1)], shots=50000)
result = execute([tape1, tape2], dev, diff_method=param_shift, max_diff=1)
return result[0] + result[1][0]
@@ -704,13 +706,13 @@ def cost_fn(x):
res = cost_fn(params)
x, y = params
expected = 0.5 * (3 + jnp.cos(x) ** 2 * jnp.cos(2 * y))
- assert np.allclose(res, expected, atol=tol, rtol=0)
+ assert np.allclose(res, expected, atol=2e-2, rtol=0)
res = jax.grad(cost_fn)(params)
expected = jnp.array(
[-jnp.cos(x) * jnp.cos(2 * y) * jnp.sin(x), -jnp.cos(x) ** 2 * jnp.sin(2 * y)]
)
- assert np.allclose(res, expected, atol=tol, rtol=0)
+ assert np.allclose(res, expected, atol=2e-2, rtol=0)
res = jax.jacobian(jax.grad(cost_fn))(params)
expected = jnp.zeros([2, 2])
diff --git a/tests/workflow/interfaces/execute/test_jax_jit.py b/tests/workflow/interfaces/execute/test_jax_jit.py
index ce6a9ef27f4..95b056cf476 100644
--- a/tests/workflow/interfaces/execute/test_jax_jit.py
+++ b/tests/workflow/interfaces/execute/test_jax_jit.py
@@ -886,14 +886,17 @@ def cost(x, y, device, interface, ek):
class TestJitAllCounts:
+ @pytest.mark.parametrize(
+ "device_name", (pytest.param("default.qubit", marks=pytest.mark.xfail), "reference.qubit")
+ )
@pytest.mark.parametrize("counts_wires", (None, (0, 1)))
- def test_jit_allcounts(self, counts_wires):
+ def test_jit_allcounts(self, device_name, counts_wires):
"""Test jitting with counts with all_outcomes == True."""
tape = qml.tape.QuantumScript(
[qml.RX(0, 0), qml.I(1)], [qml.counts(wires=counts_wires, all_outcomes=True)], shots=50
)
- device = qml.device("default.qubit")
+ device = qml.device(device_name, wires=2)
res = jax.jit(qml.execute, static_argnums=(1, 2))(
(tape,), device, qml.gradients.param_shift
@@ -904,7 +907,14 @@ def test_jit_allcounts(self, counts_wires):
for val in ["01", "10", "11"]:
assert qml.math.allclose(res[val], 0)
- def test_jit_allcounts_broadcasting(self):
+ @pytest.mark.parametrize(
+ "device_name",
+ (
+ pytest.param("default.qubit", marks=pytest.mark.xfail),
+ pytest.param("reference.qubit", marks=pytest.mark.xfail),
+ ),
+ )
+ def test_jit_allcounts_broadcasting(self, device_name):
"""Test jitting with counts with all_outcomes == True."""
tape = qml.tape.QuantumScript(
@@ -912,7 +922,7 @@ def test_jit_allcounts_broadcasting(self):
[qml.counts(wires=(0, 1), all_outcomes=True)],
shots=50,
)
- device = qml.device("default.qubit")
+ device = qml.device(device_name, wires=2)
res = jax.jit(qml.execute, static_argnums=(1, 2))(
(tape,), device, qml.gradients.param_shift
@@ -927,7 +937,6 @@ def test_jit_allcounts_broadcasting(self):
assert qml.math.allclose(ri[val], 0)
-@pytest.mark.xfail(reason="Need to figure out how to handle this case in a less ambiguous manner")
def test_diff_method_None_jit():
"""Test that jitted execution works when `diff_method=None`."""
diff --git a/tests/workflow/interfaces/qnode/test_jax_jit_qnode.py b/tests/workflow/interfaces/qnode/test_jax_jit_qnode.py
index 544da89f73d..862806999c6 100644
--- a/tests/workflow/interfaces/qnode/test_jax_jit_qnode.py
+++ b/tests/workflow/interfaces/qnode/test_jax_jit_qnode.py
@@ -3185,7 +3185,8 @@ def test_complex64_return(self, diff_method):
jax.config.update("jax_enable_x64", False)
try:
- tol = 2e-2 if diff_method == "finite-diff" else 1e-6
+ # finite diff with float32 ...
+ tol = 5e-2 if diff_method == "finite-diff" else 1e-6
@jax.jit
@qml.qnode(qml.device("default.qubit", wires=1), diff_method=diff_method)
diff --git a/tests/workflow/interfaces/test_jacobian_products.py b/tests/workflow/interfaces/test_jacobian_products.py
index 9a090d3af0a..8a3d6745686 100644
--- a/tests/workflow/interfaces/test_jacobian_products.py
+++ b/tests/workflow/interfaces/test_jacobian_products.py
@@ -136,7 +136,7 @@ def test_device_jacobians_repr(self):
r" use_device_jacobian_product=None,"
r" gradient_method='adjoint', gradient_keyword_arguments={},"
r" device_options={}, interface=, derivative_order=1,"
- r" mcm_config=MCMConfig(mcm_method=None, postselect_mode=None))>"
+ r" mcm_config=MCMConfig(mcm_method=None, postselect_mode=None), convert_to_numpy=True)>"
)
assert repr(jpc) == expected
@@ -155,7 +155,7 @@ def test_device_jacobian_products_repr(self):
r" use_device_jacobian_product=None,"
r" gradient_method='adjoint', gradient_keyword_arguments={}, device_options={},"
r" interface=, derivative_order=1,"
- r" mcm_config=MCMConfig(mcm_method=None, postselect_mode=None))>"
+ r" mcm_config=MCMConfig(mcm_method=None, postselect_mode=None), convert_to_numpy=True)>"
)
assert repr(jpc) == expected
diff --git a/tests/workflow/test_setup_transform_program.py b/tests/workflow/test_setup_transform_program.py
index c81ed75ae2f..79207491014 100644
--- a/tests/workflow/test_setup_transform_program.py
+++ b/tests/workflow/test_setup_transform_program.py
@@ -140,9 +140,7 @@ def test_prune_dynamic_transform_warning_raised():
def test_interface_data_not_supported():
"""Test that convert_to_numpy_parameters transform is correctly added."""
- config = ExecutionConfig()
- config.interface = "autograd"
- config.gradient_method = "adjoint"
+ config = ExecutionConfig(interface="autograd", gradient_method="adjoint")
device = qml.device("default.qubit")
user_transform_program = TransformProgram()
@@ -154,10 +152,8 @@ def test_interface_data_not_supported():
def test_interface_data_supported():
"""Test that convert_to_numpy_parameters transform is not added for these cases."""
- config = ExecutionConfig()
+ config = ExecutionConfig(interface="autograd", gradient_method=None)
- config.interface = "autograd"
- config.gradient_method = None
device = qml.device("default.mixed", wires=1)
user_transform_program = TransformProgram()
@@ -165,10 +161,8 @@ def test_interface_data_supported():
assert qml.transforms.convert_to_numpy_parameters not in inner_tp
- config = ExecutionConfig()
+ config = ExecutionConfig(interface="autograd", gradient_method="backprop")
- config.interface = "autograd"
- config.gradient_method = "backprop"
device = qml.device("default.qubit")
user_transform_program = TransformProgram()
@@ -176,10 +170,8 @@ def test_interface_data_supported():
assert qml.transforms.convert_to_numpy_parameters not in inner_tp
- config = ExecutionConfig()
+ config = ExecutionConfig(interface=None, gradient_method="backprop")
- config.interface = None
- config.gradient_method = "backprop"
device = qml.device("default.qubit")
user_transform_program = TransformProgram()
@@ -187,6 +179,13 @@ def test_interface_data_supported():
assert qml.transforms.convert_to_numpy_parameters not in inner_tp
+ config = ExecutionConfig(
+ convert_to_numpy=False, interface="jax", gradient_method=qml.gradients.param_shift
+ )
+
+ _, inner_tp = _setup_transform_program(TransformProgram(), device, config)
+ assert qml.transforms.convert_to_numpy_parameters not in inner_tp
+
def test_cache_handling():
"""Test that caching is handled correctly."""