Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow device to configure conversion to numpy and use of pure_callback #6788

Merged
merged 15 commits into from
Jan 15, 2025
Merged
5 changes: 5 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@

<h3>Improvements 🛠</h3>

* Devices can now configure whether or not the data is converted to numpy and `jax.pure_callback`
is used by the new `ExecutionConfig.convert_to_numpy` property. Finite shot executions
on `default.qubit` can now be jitted end-to-end, even with parameter shift.
[(#6788)](https://github.com/PennyLaneAI/pennylane/pull/6788)

albi3ro marked this conversation as resolved.
Show resolved Hide resolved
<h3>Breaking changes 💔</h3>

<h3>Deprecations 👋</h3>
Expand Down
4 changes: 2 additions & 2 deletions pennylane/devices/default_mixed.py
Original file line number Diff line number Diff line change
Expand Up @@ -1008,8 +1008,8 @@ def _setup_execution_config(self, execution_config: ExecutionConfig) -> Executio
"best",
}
updated_values["grad_on_execution"] = False
if not execution_config.gradient_method in {"best", "backprop", None}:
execution_config.interface = None
if execution_config.gradient_method not in {"best", "backprop", None}:
updated_values["interface"] = None

# Add device options
updated_values["device_options"] = dict(execution_config.device_options) # copy
Expand Down
17 changes: 16 additions & 1 deletion pennylane/devices/default_qubit.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,14 @@ def _conditional_broastcast_expand(tape):
return (tape,), null_postprocessing


@qml.transform
def no_counts(tape):
"""Throws an error on counts measurements."""
if any(isinstance(mp, qml.measurements.CountsMP) for mp in tape.measurements):
raise NotImplementedError("The JAX-JIT interface doesn't support qml.counts.")
astralcai marked this conversation as resolved.
Show resolved Hide resolved
return (tape,), null_postprocessing


@qml.transform
def adjoint_state_measurements(
tape: QuantumScript, device_vjp=False
Expand Down Expand Up @@ -535,6 +543,8 @@ def preprocess(
config = self._setup_execution_config(execution_config)
transform_program = TransformProgram()

if config.interface == qml.math.Interface.JAX_JIT:
transform_program.add_transform(no_counts)
transform_program.add_transform(validate_device_wires, self.wires, name=self.name)
transform_program.add_transform(
mid_circuit_measurements, device=self, mcm_config=config.mcm_config
Expand Down Expand Up @@ -581,6 +591,12 @@ def _setup_execution_config(self, execution_config: ExecutionConfig) -> Executio
"""
updated_values = {}

updated_values["convert_to_numpy"] = (
execution_config.interface.value not in {"jax", "jax-jit"}
albi3ro marked this conversation as resolved.
Show resolved Hide resolved
or execution_config.gradient_method == "adjoint"
astralcai marked this conversation as resolved.
Show resolved Hide resolved
# need numpy to use caching, and need caching higher order derivatives
or execution_config.derivative_order > 1
)
for option in execution_config.device_options:
if option not in self._device_options:
raise qml.DeviceError(f"device option {option} not present on {self}")
Expand Down Expand Up @@ -616,7 +632,6 @@ def execute(
execution_config: ExecutionConfig = DefaultExecutionConfig,
) -> Union[Result, ResultBatch]:
self.reset_prng_key()

max_workers = execution_config.device_options.get("max_workers", self._max_workers)
self._state_cache = {} if execution_config.use_device_jacobian_product else None
interface = (
Expand Down
13 changes: 10 additions & 3 deletions pennylane/devices/execution_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from dataclasses import dataclass, field
from typing import Optional, Union

from pennylane.math import get_canonical_interface_name
from pennylane.math import Interface, get_canonical_interface_name
from pennylane.transforms.core import TransformDispatcher


Expand Down Expand Up @@ -87,7 +87,7 @@ class ExecutionConfig:
device_options: Optional[dict] = None
"""Various options for the device executing a quantum circuit"""

interface: Optional[str] = None
interface: Interface = Interface.NUMPY
"""The machine learning framework to use"""

derivative_order: int = 1
Expand All @@ -96,6 +96,13 @@ class ExecutionConfig:
mcm_config: MCMConfig = field(default_factory=MCMConfig)
"""Configuration options for handling mid-circuit measurements"""

convert_to_numpy: bool = True
"""Whether or not to convert parameters to numpy before execution.

If ``False`` and using the jax-jit, no pure callback will occur and the device
execution itself will be jitted.
"""

def __post_init__(self):
"""
Validate the configured execution options.
Expand Down Expand Up @@ -124,7 +131,7 @@ def __post_init__(self):
)

if isinstance(self.mcm_config, dict):
self.mcm_config = MCMConfig(**self.mcm_config)
self.mcm_config = MCMConfig(**self.mcm_config) # pylint: disable=not-a-mapping

elif not isinstance(self.mcm_config, MCMConfig):
raise ValueError(f"Got invalid type {type(self.mcm_config)} for 'mcm_config'")
Expand Down
4 changes: 2 additions & 2 deletions pennylane/devices/qubit/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -580,6 +580,6 @@ def _sample_probs_jax(probs, shots, num_wires, is_state_batched, prng_key=None,
_, key = jax_random_split(prng_key)
samples = jax.random.choice(key, basis_states, shape=(shots,), p=probs)

powers_of_two = 1 << jnp.arange(num_wires, dtype=jnp.int64)[::-1]
powers_of_two = 1 << jnp.arange(num_wires, dtype=int)[::-1]
states_sampled_base_ten = samples[..., None] & powers_of_two
return (states_sampled_base_ten > 0).astype(jnp.int64)
return (states_sampled_base_ten > 0).astype(int)
3 changes: 2 additions & 1 deletion pennylane/workflow/_setup_transform_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,8 @@ def _setup_transform_program(

# changing this set of conditions causes a bunch of tests to break.
interface_data_supported = (
resolved_execution_config.interface is Interface.NUMPY
(not resolved_execution_config.convert_to_numpy)
or resolved_execution_config.interface is Interface.NUMPY
or resolved_execution_config.gradient_method == "backprop"
or (
getattr(device, "short_name", "") == "default.mixed"
Expand Down
2 changes: 1 addition & 1 deletion pennylane/workflow/interfaces/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def _to_jax(result: qml.typing.ResultBatch) -> qml.typing.ResultBatch:
return result
if isinstance(result, (list, tuple)):
return tuple(_to_jax(r) for r in result)
return jnp.array(result)
return result if qml.math.get_interface(result) == "jax" else jnp.array(result)


def _execute_wrapper(params, tapes, execute_fn, jpc) -> ResultBatch:
Expand Down
2 changes: 1 addition & 1 deletion pennylane/workflow/interfaces/jax_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def _to_jax(result: qml.typing.ResultBatch) -> qml.typing.ResultBatch:

"""
if isinstance(result, dict):
return {key: jnp.array(value) for key, value in result.items()}
return {key: _to_jax(value) for key, value in result.items()}
if isinstance(result, (list, tuple)):
return tuple(_to_jax(r) for r in result)
return jnp.array(result)
Expand Down
2 changes: 1 addition & 1 deletion pennylane/workflow/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def _get_ml_boundary_execute(
elif interface == Interface.TORCH:
from .interfaces.torch import execute as ml_boundary

elif interface == Interface.JAX_JIT:
elif interface == Interface.JAX_JIT and resolved_execution_config.convert_to_numpy:
from .interfaces.jax_jit import jax_jit_jvp_execute as ml_boundary

else: # interface is jax
Expand Down
6 changes: 1 addition & 5 deletions tests/devices/default_qubit/test_default_qubit_native_mcm.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,11 +389,7 @@ def func(x, y, z):
results1 = func1(*params)

jaxpr = str(jax.make_jaxpr(func)(*params))
if diff_method == "best":
assert "pure_callback" in jaxpr
pytest.xfail("QNode with diff_method='best' cannot be compiled with jax.jit.")
else:
assert "pure_callback" not in jaxpr
assert "pure_callback" not in jaxpr

func2 = jax.jit(func)
results2 = func2(*params)
Expand Down
26 changes: 26 additions & 0 deletions tests/devices/default_qubit/test_default_qubit_preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,32 @@ def circuit(x):

assert dev.tracker.totals["execute_and_derivative_batches"] == 1

@pytest.mark.parametrize("interface", ("jax", "jax-jit"))
def test_not_convert_to_numpy_with_jax(self, interface):
"""Test that we will not convert to numpy when working with jax."""

dev = qml.device("default.qubit")
config = qml.devices.ExecutionConfig(
gradient_method=qml.gradients.param_shift, interface=interface
)
processed = dev.setup_execution_config(config)
assert not processed.convert_to_numpy

def test_convert_to_numpy_with_adjoint(self):
"""Test that we will convert to numpy with adjoint."""
config = qml.devices.ExecutionConfig(gradient_method="adjoint", interface="jax-jit")
dev = qml.device("default.qubit")
processed = dev.setup_execution_config(config)
assert processed.convert_to_numpy

@pytest.mark.parametrize("interface", ("autograd", "torch", "tf"))
def test_convert_to_numpy_non_jax(self, interface):
"""Test that other interfaces are still converted to numpy."""
config = qml.devices.ExecutionConfig(gradient_method="adjoint", interface=interface)
dev = qml.device("default.qubit")
processed = dev.setup_execution_config(config)
assert processed.convert_to_numpy
albi3ro marked this conversation as resolved.
Show resolved Hide resolved


# pylint: disable=too-few-public-methods
class TestPreprocessing:
Expand Down
1 change: 0 additions & 1 deletion tests/gradients/core/test_pulse_gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -1485,7 +1485,6 @@ def circuit(params):
assert qml.math.allclose(j[0], e, atol=tol, rtol=0.0)
jax.clear_caches()

@pytest.mark.xfail
@pytest.mark.parametrize("num_split_times", [1, 2])
@pytest.mark.parametrize("time_interface", ["python", "numpy", "jax"])
def test_simple_qnode_jit(self, num_split_times, time_interface):
Expand Down
1 change: 1 addition & 0 deletions tests/param_shift_dev.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def preprocess(self, execution_config=qml.devices.DefaultExecutionConfig):
execution_config, use_device_jacobian_product=True
)
program, config = super().preprocess(execution_config)
config = dataclasses.replace(config, convert_to_numpy=True)
program.add_transform(qml.transform(qml.gradients.param_shift.expand_transform))
return program, config

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -417,7 +417,7 @@ def circuit(state):


@pytest.mark.jax
@pytest.mark.parametrize("shots, atol", [(None, 0.005), (1000000, 0.05)])
@pytest.mark.parametrize("shots, atol", [(None, 0.005), (1000000, 0.1)])
albi3ro marked this conversation as resolved.
Show resolved Hide resolved
def test_jacobians_with_and_without_jit_match(shots, atol, seed):
"""Test that the Jacobian of the circuit is the same with and without jit."""
import jax
Expand All @@ -433,7 +433,7 @@ def circuit(coeffs):
circuit_ps = qml.QNode(circuit, dev, diff_method="parameter-shift")
circuit_exact = qml.QNode(circuit, dev_no_shots)

params = jax.numpy.array([0.5, 0.5, 0.5, 0.5])
params = jax.numpy.array([0.5, 0.5, 0.5, 0.5], dtype=jax.numpy.float64)
jac_exact_fn = jax.jacobian(circuit_exact)
jac_fd_fn = jax.jacobian(circuit_fd)
jac_fd_fn_jit = jax.jit(jac_fd_fn)
Expand Down
8 changes: 5 additions & 3 deletions tests/test_qnode.py
Original file line number Diff line number Diff line change
Expand Up @@ -687,12 +687,13 @@ def func(x, y):
assert tape.measurements == contents[3:]

@pytest.mark.jax
def test_jit_counts_raises_error(self):
@pytest.mark.parametrize("dev_name", ("default.qubit", "reference.qubit"))
def test_jit_counts_raises_error(self, dev_name):
"""Test that returning counts in a quantum function with trainable parameters while
jitting raises an error."""
import jax

dev = qml.device("default.qubit", wires=2, shots=5)
dev = qml.device(dev_name, wires=2, shots=5)

def circuit1(param):
qml.Hadamard(0)
Expand All @@ -706,7 +707,8 @@ def circuit1(param):
with pytest.raises(
NotImplementedError, match="The JAX-JIT interface doesn't support qml.counts."
):
jitted_qnode1(0.123)
out = jitted_qnode1(0.123)
print(out)
albi3ro marked this conversation as resolved.
Show resolved Hide resolved

# Test with qnode decorator syntax
@qml.qnode(dev)
Expand Down
10 changes: 6 additions & 4 deletions tests/workflow/interfaces/execute/test_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -693,24 +693,26 @@ def test_max_diff(self, tol):

def cost_fn(x):
ops = [qml.RX(x[0], 0), qml.RY(x[1], 1), qml.CNOT((0, 1))]
tape1 = qml.tape.QuantumScript(ops, [qml.var(qml.PauliZ(0) @ qml.PauliX(1))])
tape1 = qml.tape.QuantumScript(
ops, [qml.var(qml.PauliZ(0) @ qml.PauliX(1))], shots=50000
)

ops2 = [qml.RX(x[0], 0), qml.RY(x[0], 1), qml.CNOT((0, 1))]
tape2 = qml.tape.QuantumScript(ops2, [qml.probs(wires=1)])
tape2 = qml.tape.QuantumScript(ops2, [qml.probs(wires=1)], shots=50000)

result = execute([tape1, tape2], dev, diff_method=param_shift, max_diff=1)
return result[0] + result[1][0]

res = cost_fn(params)
x, y = params
expected = 0.5 * (3 + jnp.cos(x) ** 2 * jnp.cos(2 * y))
assert np.allclose(res, expected, atol=tol, rtol=0)
assert np.allclose(res, expected, atol=2e-2, rtol=0)
mudit2812 marked this conversation as resolved.
Show resolved Hide resolved

res = jax.grad(cost_fn)(params)
expected = jnp.array(
[-jnp.cos(x) * jnp.cos(2 * y) * jnp.sin(x), -jnp.cos(x) ** 2 * jnp.sin(2 * y)]
)
assert np.allclose(res, expected, atol=tol, rtol=0)
assert np.allclose(res, expected, atol=2e-2, rtol=0)

res = jax.jacobian(jax.grad(cost_fn))(params)
expected = jnp.zeros([2, 2])
Expand Down
19 changes: 14 additions & 5 deletions tests/workflow/interfaces/execute/test_jax_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -886,14 +886,17 @@ def cost(x, y, device, interface, ek):

class TestJitAllCounts:

@pytest.mark.parametrize(
"device_name", (pytest.param("default.qubit", marks=pytest.mark.xfail), "reference.qubit")
)
@pytest.mark.parametrize("counts_wires", (None, (0, 1)))
def test_jit_allcounts(self, counts_wires):
def test_jit_allcounts(self, device_name, counts_wires):
"""Test jitting with counts with all_outcomes == True."""

tape = qml.tape.QuantumScript(
[qml.RX(0, 0), qml.I(1)], [qml.counts(wires=counts_wires, all_outcomes=True)], shots=50
)
device = qml.device("default.qubit")
device = qml.device(device_name, wires=2)

res = jax.jit(qml.execute, static_argnums=(1, 2))(
(tape,), device, qml.gradients.param_shift
Expand All @@ -904,15 +907,22 @@ def test_jit_allcounts(self, counts_wires):
for val in ["01", "10", "11"]:
assert qml.math.allclose(res[val], 0)

def test_jit_allcounts_broadcasting(self):
@pytest.mark.parametrize(
"device_name",
(
pytest.param("default.qubit", marks=pytest.mark.xfail),
pytest.param("reference.qubit", marks=pytest.mark.xfail),
),
)
def test_jit_allcounts_broadcasting(self, device_name):
"""Test jitting with counts with all_outcomes == True."""

tape = qml.tape.QuantumScript(
[qml.RX(np.array([0.0, 0.0]), 0)],
[qml.counts(wires=(0, 1), all_outcomes=True)],
shots=50,
)
device = qml.device("default.qubit")
device = qml.device(device_name, wires=2)

res = jax.jit(qml.execute, static_argnums=(1, 2))(
(tape,), device, qml.gradients.param_shift
Expand All @@ -927,7 +937,6 @@ def test_jit_allcounts_broadcasting(self):
assert qml.math.allclose(ri[val], 0)


@pytest.mark.xfail(reason="Need to figure out how to handle this case in a less ambiguous manner")
def test_diff_method_None_jit():
"""Test that jitted execution works when `diff_method=None`."""

Expand Down
3 changes: 2 additions & 1 deletion tests/workflow/interfaces/qnode/test_jax_jit_qnode.py
Original file line number Diff line number Diff line change
Expand Up @@ -3185,7 +3185,8 @@ def test_complex64_return(self, diff_method):
jax.config.update("jax_enable_x64", False)

try:
tol = 2e-2 if diff_method == "finite-diff" else 1e-6
# finite diff with float32 ...
tol = 5e-2 if diff_method == "finite-diff" else 1e-6

@jax.jit
@qml.qnode(qml.device("default.qubit", wires=1), diff_method=diff_method)
Expand Down
4 changes: 2 additions & 2 deletions tests/workflow/interfaces/test_jacobian_products.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def test_device_jacobians_repr(self):
r" use_device_jacobian_product=None,"
r" gradient_method='adjoint', gradient_keyword_arguments={},"
r" device_options={}, interface=<Interface.NUMPY: 'numpy'>, derivative_order=1,"
r" mcm_config=MCMConfig(mcm_method=None, postselect_mode=None))>"
r" mcm_config=MCMConfig(mcm_method=None, postselect_mode=None), convert_to_numpy=True)>"
)

assert repr(jpc) == expected
Expand All @@ -155,7 +155,7 @@ def test_device_jacobian_products_repr(self):
r" use_device_jacobian_product=None,"
r" gradient_method='adjoint', gradient_keyword_arguments={}, device_options={},"
r" interface=<Interface.NUMPY: 'numpy'>, derivative_order=1,"
r" mcm_config=MCMConfig(mcm_method=None, postselect_mode=None))>"
r" mcm_config=MCMConfig(mcm_method=None, postselect_mode=None), convert_to_numpy=True)>"
)

assert repr(jpc) == expected
Expand Down
Loading
Loading