Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow device to configure conversion to numpy and use of pure_callback #6788

Open
wants to merge 13 commits into
base: master
Choose a base branch
from
9 changes: 7 additions & 2 deletions doc/development/plugins.rst
Original file line number Diff line number Diff line change
Expand Up @@ -470,12 +470,15 @@ pieces of functionality:
Note that these properties are only applicable to devices that provided derivatives or VJPs. If your device
does not provide derivatives, you can safely ignore these properties.

The workflow options are ``use_device_gradient``, ``use_device_jacobian_product``, and ``grad_on_execution``.
The workflow options are ``use_device_gradient``, ``use_device_jacobian_product``, ``grad_on_execution``,
and ``convert_to_numpy``.
``use_device_gradient=True`` indicates that workflow should request derivatives from the device.
``grad_on_execution=True`` indicates a preference to use ``execute_and_compute_derivatives`` instead
of ``execute`` followed by ``compute_derivatives``. Finally, ``use_device_jacobian_product`` indicates
a request to call ``compute_vjp`` instead of ``compute_derivatives``. Note that if ``use_device_jacobian_product``
is ``True``, this takes precedence over calculating the full jacobian.
is ``True``, this takes precedence over calculating the full jacobian. If can accept ML framework parameters, like
jax, they should specify ``convert_to_numpy=False``. Then the parameters will not be converted, and no ``jax.pure_callback``
will be used when jitting.

>>> config = qml.devices.ExecutionConfig(gradient_method="adjoint")
>>> processed_config = qml.device('default.qubit').setup_execution_config(config)
Expand All @@ -485,6 +488,8 @@ True
True
>>> processed_config.grad_on_execution
True
>>> processed_config.convert_to_numpy
True

Execution
---------
Expand Down
6 changes: 6 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,12 @@

<h3>Improvements 🛠</h3>

* Finite shot and parameter-shift executions on `default.qubit` can now
be natively jitted end-to-end, leading to performance improvements.
Devices can now configure whether or not ML framework data is sent to them
via a `ExecutionConfig.convert_to_numpy` parameter.
[(#6788)](https://github.com/PennyLaneAI/pennylane/pull/6788)

<h3>Breaking changes 💔</h3>

<h3>Deprecations 👋</h3>
Expand Down
4 changes: 2 additions & 2 deletions pennylane/devices/default_mixed.py
Original file line number Diff line number Diff line change
Expand Up @@ -1008,8 +1008,8 @@ def _setup_execution_config(self, execution_config: ExecutionConfig) -> Executio
"best",
}
updated_values["grad_on_execution"] = False
if not execution_config.gradient_method in {"best", "backprop", None}:
execution_config.interface = None
if execution_config.gradient_method not in {"best", "backprop", None}:
updated_values["interface"] = None

# Add device options
updated_values["device_options"] = dict(execution_config.device_options) # copy
Expand Down
18 changes: 17 additions & 1 deletion pennylane/devices/default_qubit.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,14 @@ def _conditional_broastcast_expand(tape):
return (tape,), null_postprocessing


@qml.transform
def no_counts(tape):
"""Throws an error on counts measurements."""
if any(isinstance(mp, qml.measurements.CountsMP) for mp in tape.measurements):
raise NotImplementedError("The JAX-JIT interface doesn't support qml.counts.")
astralcai marked this conversation as resolved.
Show resolved Hide resolved
return (tape,), null_postprocessing


@qml.transform
def adjoint_state_measurements(
tape: QuantumScript, device_vjp=False
Expand Down Expand Up @@ -535,6 +543,8 @@ def preprocess(
config = self._setup_execution_config(execution_config)
transform_program = TransformProgram()

if config.interface == qml.math.Interface.JAX_JIT:
transform_program.add_transform(no_counts)
transform_program.add_transform(validate_device_wires, self.wires, name=self.name)
transform_program.add_transform(
mid_circuit_measurements, device=self, mcm_config=config.mcm_config
Expand Down Expand Up @@ -581,6 +591,13 @@ def _setup_execution_config(self, execution_config: ExecutionConfig) -> Executio
"""
updated_values = {}

jax_interfaces = {qml.math.Interface.JAX, qml.math.Interface.JAX_JIT}
updated_values["convert_to_numpy"] = (
execution_config.interface not in jax_interfaces
or execution_config.gradient_method == "adjoint"
astralcai marked this conversation as resolved.
Show resolved Hide resolved
# need numpy to use caching, and need caching higher order derivatives
or execution_config.derivative_order > 1
)
for option in execution_config.device_options:
if option not in self._device_options:
raise qml.DeviceError(f"device option {option} not present on {self}")
Expand Down Expand Up @@ -616,7 +633,6 @@ def execute(
execution_config: ExecutionConfig = DefaultExecutionConfig,
) -> Union[Result, ResultBatch]:
self.reset_prng_key()

max_workers = execution_config.device_options.get("max_workers", self._max_workers)
self._state_cache = {} if execution_config.use_device_jacobian_product else None
interface = (
Expand Down
13 changes: 10 additions & 3 deletions pennylane/devices/execution_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from dataclasses import dataclass, field
from typing import Optional, Union

from pennylane.math import get_canonical_interface_name
from pennylane.math import Interface, get_canonical_interface_name
from pennylane.transforms.core import TransformDispatcher


Expand Down Expand Up @@ -87,7 +87,7 @@ class ExecutionConfig:
device_options: Optional[dict] = None
"""Various options for the device executing a quantum circuit"""

interface: Optional[str] = None
interface: Interface = Interface.NUMPY
"""The machine learning framework to use"""

derivative_order: int = 1
Expand All @@ -96,6 +96,13 @@ class ExecutionConfig:
mcm_config: MCMConfig = field(default_factory=MCMConfig)
"""Configuration options for handling mid-circuit measurements"""

convert_to_numpy: bool = True
"""Whether or not to convert parameters to numpy before execution.

If ``False`` and using the jax-jit, no pure callback will occur and the device
execution itself will be jitted.
"""

def __post_init__(self):
"""
Validate the configured execution options.
Expand Down Expand Up @@ -124,7 +131,7 @@ def __post_init__(self):
)

if isinstance(self.mcm_config, dict):
self.mcm_config = MCMConfig(**self.mcm_config)
self.mcm_config = MCMConfig(**self.mcm_config) # pylint: disable=not-a-mapping

elif not isinstance(self.mcm_config, MCMConfig):
raise ValueError(f"Got invalid type {type(self.mcm_config)} for 'mcm_config'")
Expand Down
4 changes: 2 additions & 2 deletions pennylane/devices/qubit/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -580,6 +580,6 @@ def _sample_probs_jax(probs, shots, num_wires, is_state_batched, prng_key=None,
_, key = jax_random_split(prng_key)
samples = jax.random.choice(key, basis_states, shape=(shots,), p=probs)

powers_of_two = 1 << jnp.arange(num_wires, dtype=jnp.int64)[::-1]
powers_of_two = 1 << jnp.arange(num_wires, dtype=int)[::-1]
states_sampled_base_ten = samples[..., None] & powers_of_two
return (states_sampled_base_ten > 0).astype(jnp.int64)
return (states_sampled_base_ten > 0).astype(int)
3 changes: 2 additions & 1 deletion pennylane/workflow/_setup_transform_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,8 @@ def _setup_transform_program(

# changing this set of conditions causes a bunch of tests to break.
interface_data_supported = (
resolved_execution_config.interface is Interface.NUMPY
(not resolved_execution_config.convert_to_numpy)
or resolved_execution_config.interface is Interface.NUMPY
or resolved_execution_config.gradient_method == "backprop"
or (
getattr(device, "short_name", "") == "default.mixed"
Expand Down
2 changes: 1 addition & 1 deletion pennylane/workflow/interfaces/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def _to_jax(result: qml.typing.ResultBatch) -> qml.typing.ResultBatch:
return result
if isinstance(result, (list, tuple)):
return tuple(_to_jax(r) for r in result)
return jnp.array(result)
return result if qml.math.get_interface(result) == "jax" else jnp.array(result)


def _execute_wrapper(params, tapes, execute_fn, jpc) -> ResultBatch:
Expand Down
2 changes: 1 addition & 1 deletion pennylane/workflow/interfaces/jax_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def _to_jax(result: qml.typing.ResultBatch) -> qml.typing.ResultBatch:

"""
if isinstance(result, dict):
return {key: jnp.array(value) for key, value in result.items()}
return {key: _to_jax(value) for key, value in result.items()}
if isinstance(result, (list, tuple)):
return tuple(_to_jax(r) for r in result)
return jnp.array(result)
Expand Down
2 changes: 1 addition & 1 deletion pennylane/workflow/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def _get_ml_boundary_execute(
elif interface == Interface.TORCH:
from .interfaces.torch import execute as ml_boundary

elif interface == Interface.JAX_JIT:
elif interface == Interface.JAX_JIT and resolved_execution_config.convert_to_numpy:
from .interfaces.jax_jit import jax_jit_jvp_execute as ml_boundary

else: # interface is jax
Expand Down
6 changes: 1 addition & 5 deletions tests/devices/default_qubit/test_default_qubit_native_mcm.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,11 +389,7 @@ def func(x, y, z):
results1 = func1(*params)

jaxpr = str(jax.make_jaxpr(func)(*params))
if diff_method == "best":
assert "pure_callback" in jaxpr
pytest.xfail("QNode with diff_method='best' cannot be compiled with jax.jit.")
else:
assert "pure_callback" not in jaxpr
assert "pure_callback" not in jaxpr

func2 = jax.jit(func)
results2 = func2(*params)
Expand Down
26 changes: 26 additions & 0 deletions tests/devices/default_qubit/test_default_qubit_preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,32 @@ def circuit(x):

assert dev.tracker.totals["execute_and_derivative_batches"] == 1

@pytest.mark.parametrize("interface", ("jax", "jax-jit"))
def test_not_convert_to_numpy_with_jax(self, interface):
"""Test that we will not convert to numpy when working with jax."""

dev = qml.device("default.qubit")
config = qml.devices.ExecutionConfig(
gradient_method=qml.gradients.param_shift, interface=interface
)
processed = dev.setup_execution_config(config)
assert not processed.convert_to_numpy

def test_convert_to_numpy_with_adjoint(self):
"""Test that we will convert to numpy with adjoint."""
config = qml.devices.ExecutionConfig(gradient_method="adjoint", interface="jax-jit")
dev = qml.device("default.qubit")
processed = dev.setup_execution_config(config)
assert processed.convert_to_numpy

@pytest.mark.parametrize("interface", ("autograd", "torch", "tf"))
def test_convert_to_numpy_non_jax(self, interface):
"""Test that other interfaces are still converted to numpy."""
config = qml.devices.ExecutionConfig(gradient_method="adjoint", interface=interface)
dev = qml.device("default.qubit")
processed = dev.setup_execution_config(config)
assert processed.convert_to_numpy
Comment on lines +155 to +168
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you include "jax" as an interface in the testing?

Also, I'm curious if converting to numpy with adjoint has negative effects on performance.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do think we could make adjoint jittable and get some nice speed boosts, but i think that would need to be a follow on task.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, the jax interface is tested in test_not_convert_to_numpy_with_jax.



# pylint: disable=too-few-public-methods
class TestPreprocessing:
Expand Down
1 change: 0 additions & 1 deletion tests/gradients/core/test_pulse_gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -1485,7 +1485,6 @@ def circuit(params):
assert qml.math.allclose(j[0], e, atol=tol, rtol=0.0)
jax.clear_caches()

@pytest.mark.xfail
@pytest.mark.parametrize("num_split_times", [1, 2])
@pytest.mark.parametrize("time_interface", ["python", "numpy", "jax"])
def test_simple_qnode_jit(self, num_split_times, time_interface):
Expand Down
1 change: 1 addition & 0 deletions tests/param_shift_dev.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def preprocess(self, execution_config=qml.devices.DefaultExecutionConfig):
execution_config, use_device_jacobian_product=True
)
program, config = super().preprocess(execution_config)
config = dataclasses.replace(config, convert_to_numpy=True)
program.add_transform(qml.transform(qml.gradients.param_shift.expand_transform))
return program, config

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -417,11 +417,13 @@ def circuit(state):


@pytest.mark.jax
@pytest.mark.parametrize("shots, atol", [(None, 0.005), (1000000, 0.05)])
def test_jacobians_with_and_without_jit_match(shots, atol, seed):
def test_jacobians_with_and_without_jit_match(seed):
"""Test that the Jacobian of the circuit is the same with and without jit."""
import jax

shots = None
atol = 0.005

dev = qml.device("default.qubit", shots=shots, seed=seed)
dev_no_shots = qml.device("default.qubit", shots=None)

Expand All @@ -433,7 +435,7 @@ def circuit(coeffs):
circuit_ps = qml.QNode(circuit, dev, diff_method="parameter-shift")
circuit_exact = qml.QNode(circuit, dev_no_shots)

params = jax.numpy.array([0.5, 0.5, 0.5, 0.5])
params = jax.numpy.array([0.5, 0.5, 0.5, 0.5], dtype=jax.numpy.float64)
jac_exact_fn = jax.jacobian(circuit_exact)
jac_fd_fn = jax.jacobian(circuit_fd)
jac_fd_fn_jit = jax.jit(jac_fd_fn)
Expand Down
7 changes: 4 additions & 3 deletions tests/test_qnode.py
Original file line number Diff line number Diff line change
Expand Up @@ -687,12 +687,13 @@ def func(x, y):
assert tape.measurements == contents[3:]

@pytest.mark.jax
def test_jit_counts_raises_error(self):
@pytest.mark.parametrize("dev_name", ("default.qubit", "reference.qubit"))
def test_jit_counts_raises_error(self, dev_name):
"""Test that returning counts in a quantum function with trainable parameters while
jitting raises an error."""
import jax

dev = qml.device("default.qubit", wires=2, shots=5)
dev = qml.device(dev_name, wires=2, shots=5)

def circuit1(param):
qml.Hadamard(0)
Expand All @@ -706,7 +707,7 @@ def circuit1(param):
with pytest.raises(
NotImplementedError, match="The JAX-JIT interface doesn't support qml.counts."
):
jitted_qnode1(0.123)
_ = jitted_qnode1(0.123)

# Test with qnode decorator syntax
@qml.qnode(dev)
Expand Down
10 changes: 6 additions & 4 deletions tests/workflow/interfaces/execute/test_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -693,24 +693,26 @@ def test_max_diff(self, tol):

def cost_fn(x):
ops = [qml.RX(x[0], 0), qml.RY(x[1], 1), qml.CNOT((0, 1))]
tape1 = qml.tape.QuantumScript(ops, [qml.var(qml.PauliZ(0) @ qml.PauliX(1))])
tape1 = qml.tape.QuantumScript(
ops, [qml.var(qml.PauliZ(0) @ qml.PauliX(1))], shots=50000
)

ops2 = [qml.RX(x[0], 0), qml.RY(x[0], 1), qml.CNOT((0, 1))]
tape2 = qml.tape.QuantumScript(ops2, [qml.probs(wires=1)])
tape2 = qml.tape.QuantumScript(ops2, [qml.probs(wires=1)], shots=50000)

result = execute([tape1, tape2], dev, diff_method=param_shift, max_diff=1)
return result[0] + result[1][0]

res = cost_fn(params)
x, y = params
expected = 0.5 * (3 + jnp.cos(x) ** 2 * jnp.cos(2 * y))
assert np.allclose(res, expected, atol=tol, rtol=0)
assert np.allclose(res, expected, atol=2e-2, rtol=0)
mudit2812 marked this conversation as resolved.
Show resolved Hide resolved

res = jax.grad(cost_fn)(params)
expected = jnp.array(
[-jnp.cos(x) * jnp.cos(2 * y) * jnp.sin(x), -jnp.cos(x) ** 2 * jnp.sin(2 * y)]
)
assert np.allclose(res, expected, atol=tol, rtol=0)
assert np.allclose(res, expected, atol=2e-2, rtol=0)

res = jax.jacobian(jax.grad(cost_fn))(params)
expected = jnp.zeros([2, 2])
Expand Down
19 changes: 14 additions & 5 deletions tests/workflow/interfaces/execute/test_jax_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -886,14 +886,17 @@ def cost(x, y, device, interface, ek):

class TestJitAllCounts:

@pytest.mark.parametrize(
"device_name", (pytest.param("default.qubit", marks=pytest.mark.xfail), "reference.qubit")
)
@pytest.mark.parametrize("counts_wires", (None, (0, 1)))
def test_jit_allcounts(self, counts_wires):
def test_jit_allcounts(self, device_name, counts_wires):
"""Test jitting with counts with all_outcomes == True."""

tape = qml.tape.QuantumScript(
[qml.RX(0, 0), qml.I(1)], [qml.counts(wires=counts_wires, all_outcomes=True)], shots=50
)
device = qml.device("default.qubit")
device = qml.device(device_name, wires=2)

res = jax.jit(qml.execute, static_argnums=(1, 2))(
(tape,), device, qml.gradients.param_shift
Expand All @@ -904,15 +907,22 @@ def test_jit_allcounts(self, counts_wires):
for val in ["01", "10", "11"]:
assert qml.math.allclose(res[val], 0)

def test_jit_allcounts_broadcasting(self):
@pytest.mark.parametrize(
"device_name",
(
pytest.param("default.qubit", marks=pytest.mark.xfail),
pytest.param("reference.qubit", marks=pytest.mark.xfail),
),
)
def test_jit_allcounts_broadcasting(self, device_name):
"""Test jitting with counts with all_outcomes == True."""

tape = qml.tape.QuantumScript(
[qml.RX(np.array([0.0, 0.0]), 0)],
[qml.counts(wires=(0, 1), all_outcomes=True)],
shots=50,
)
device = qml.device("default.qubit")
device = qml.device(device_name, wires=2)

res = jax.jit(qml.execute, static_argnums=(1, 2))(
(tape,), device, qml.gradients.param_shift
Expand All @@ -927,7 +937,6 @@ def test_jit_allcounts_broadcasting(self):
assert qml.math.allclose(ri[val], 0)


@pytest.mark.xfail(reason="Need to figure out how to handle this case in a less ambiguous manner")
def test_diff_method_None_jit():
"""Test that jitted execution works when `diff_method=None`."""

Expand Down
3 changes: 2 additions & 1 deletion tests/workflow/interfaces/qnode/test_jax_jit_qnode.py
Original file line number Diff line number Diff line change
Expand Up @@ -3185,7 +3185,8 @@ def test_complex64_return(self, diff_method):
jax.config.update("jax_enable_x64", False)

try:
tol = 2e-2 if diff_method == "finite-diff" else 1e-6
# finite diff with float32 ...
tol = 5e-2 if diff_method == "finite-diff" else 1e-6

@jax.jit
@qml.qnode(qml.device("default.qubit", wires=1), diff_method=diff_method)
Expand Down
Loading
Loading