JAX pulse simulation compilation bug #122

Closed
DanPuzzuoli opened this issue Aug 26, 2022 · 2 comments · Fixed by #125
Labels
bug Something isn't working

Comments

Collaborator

DanPuzzuoli commented Aug 26, 2022

Information

  • Qiskit Dynamics version: 0.3.0
  • Python version: 3.10
  • Operating system: Mac OS X

What is the current behavior?

The automatic jit compilation routine for simulating pulse schedules with JAX raises errors under certain conditions.

Steps to reproduce the problem

When using method='jax_odeint', specifying the t_eval kwarg causes the error to be raised. When using method='jax_expm', the error is raised regardless of whether t_eval is specified.

The example below can probably be reduced further in terms of lines of code, but here is a small reproducing example:

```python
import numpy as np
from qiskit import pulse

import jax
jax.config.update("jax_enable_x64", True)

# tell JAX we are using CPU
jax.config.update('jax_platform_name', 'cpu')

# use JAX as the default backend for qiskit_dynamics Arrays
from qiskit_dynamics.array import Array
Array.set_default_backend('jax')

from qiskit_dynamics import Solver

Z = np.array([[1., 0.], [0., -1.]])
X = np.array([[0., 1.], [1., 0.]])

v = 5.
anharm = -0.33
r = 0.02
dt = 0.222222222222222 # dt for pulse schedules

# static part
static_hamiltonian = 2 * np.pi * v * Z
# drive term
drive_hamiltonian = 2 * np.pi * r * X

solver = Solver(
    static_hamiltonian=static_hamiltonian,
    hamiltonian_operators=[drive_hamiltonian],
    rotating_frame=static_hamiltonian,
    hamiltonian_channels=['d0'],
    channel_carrier_freqs={'d0': v},
    dt=dt
)

with pulse.build(name='x') as schedule:
    pulse.set_frequency(v, channel=pulse.DriveChannel(0))
    pulse.play(pulse.Drag(300, 1., 100, 0.), pulse.DriveChannel(0))
```

This works:

```python
solver.solve(
    signals=schedule,
    y0=np.eye(2, dtype=complex),
    t_span=[0, 300 * dt],
    method='jax_odeint',
    atol=1e-10,
    rtol=1e-10
)
```

However, this does not:

```python
solver.solve(
    signals=schedule,
    y0=np.eye(2, dtype=complex),
    t_span=[0, 300 * dt],
    t_eval=[0, 150 * dt, 300 * dt],
    method='jax_odeint',
    atol=1e-10,
    rtol=1e-10
)
```

Similarly, this does not:

```python
solver.solve(
    signals=schedule,
    y0=np.eye(2, dtype=complex),
    t_span=[0, 300 * dt],
    t_eval=[0, 150 * dt, 300 * dt],
    method='jax_expm',
    max_dt=dt
)
```

What is the expected behavior?

The above calls should work and produce correct results.

Suggested solutions

The two errors are raised in different places, but the underlying cause is the same: the function that is compiled internally (located here) takes t_span as an argument, but in both cases (if t_eval is specified, or if a fixed-step solver is used) t_span cannot be compiled over. That is, the internal logic needs the concrete values of t_span, so it cannot be passed to jax.jit as a traced argument.
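A minimal plain-JAX sketch of this failure mode follows. The helper `fixed_step_times` is hypothetical (it is not the internally compiled function in Qiskit Dynamics); it stands in for any logic whose loop counts or array sizes are derived from the values of t_span. Evaluated eagerly it works, but under jax.jit t_span is an abstract tracer with no concrete value, so JAX raises a concretization error (a subclass of TypeError).

```python
import jax
import jax.numpy as jnp

def fixed_step_times(t_span, dt):
    # Hypothetical helper: the number of steps is derived from the *values* of t_span.
    t0, t1 = t_span[0], t_span[1]
    # Python int() needs a concrete value; under jax.jit, t0 and t1 are abstract tracers.
    n_steps = int(jnp.ceil((t1 - t0) / dt))
    return t0 + dt * jnp.arange(n_steps + 1)

t_span = jnp.array([0.0, 1.0])

# Eager evaluation works because t_span holds concrete values.
eager_times = fixed_step_times(t_span, 0.25)
print("eager:", eager_times)

# Under jax.jit, the output length would depend on traced values, so tracing fails.
jit_error = None
try:
    jax.jit(fixed_step_times)(t_span, 0.25)
except TypeError as err:  # JAX's concretization errors subclass TypeError
    jit_error = err
print("jit failed with:", type(jit_error).__name__)
```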

Some imperfect options:

  • Set t_span as a static argument in the internally compiled function. This is an extremely easy option, but it limits the usefulness of compilation to cases where all t_span values are the same, since the function is recompiled for every new t_span value. It will still apply to a lot of cases, but I think this is an option of last resort.
  • Figure out how to make both of the failure cases compilable. I think the first one is possible: the only reason it wasn't written to be compilable originally was the difficulty, but in principle I think it can be done. This option is still imperfect, however, as I seem to remember there being fundamental issues with making the second one compilable (this needs checking).
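The first option (a static t_span) can be sketched in plain JAX as below. `fixed_step_times` is again a hypothetical stand-in, and the `trace_count` counter exists only to make recompilation visible. Static arguments must be hashable (a tuple rather than a list or array), and each new t_span value triggers a fresh trace and compile, which is what limits the usefulness of this option.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # incremented only when JAX (re)traces the function

def fixed_step_times(t_span, dt):
    # Hypothetical helper: step count depends on the values of t_span.
    global trace_count
    trace_count += 1  # Python side effect: runs at trace time, not on cached calls
    t0, t1 = t_span
    n_steps = int(round((t1 - t0) / dt))
    return t0 + dt * jnp.arange(n_steps + 1)

# Treat t_span and dt as static: they must be hashable Python values,
# and their concrete values are baked into the compiled function.
jitted = jax.jit(fixed_step_times, static_argnums=(0, 1))

times_a = jitted((0.0, 1.0), 0.25)   # first trace and compile
times_a2 = jitted((0.0, 1.0), 0.25)  # same static values: cached, no retrace
times_b = jitted((0.0, 2.0), 0.25)   # new t_span value: traced and compiled again

print(times_a.shape, times_b.shape, "traces:", trace_count)
```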
@DanPuzzuoli DanPuzzuoli added the bug Something isn't working label Aug 26, 2022
Collaborator Author

Upon further inspection, I think it may be possible to make both failure modes JAX-compilable. I will do further testing of this idea, then implement it in a PR if it works out.

Collaborator Author

So it seems like:

  • For method='jax_odeint', the internal utilities for handling t_span and t_eval can be updated in a fairly straightforward way so that we can compile over t_span and t_eval.
  • diffrax methods already work fine for this (after a minor tweak).
  • For the JAX fixed-step solvers (built into Qiskit Dynamics), making the internal logic compilable with respect to t_span and t_eval is going to be fairly non-trivial. The main problem is that the number of iterations depends on the values of t_span and t_eval. One simple way to make them compilable is to use JAX while loops, but this would break the ability to reverse-mode differentiate these solvers, so it doesn't seem like an option. I think the only way to make them both compilable AND reverse-mode differentiable is probably to use while loops AND write a custom vjp rule (this is what is done in, e.g., jax_odeint).
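The trade-off in the last point can be seen in a small standalone sketch. The two forward-Euler solvers for dy/dt = a * y below are hypothetical illustrations in plain JAX, not the fixed-step solvers in Qiskit Dynamics. `euler_while` computes its step count from a traced t_span, so it compiles under jax.jit and supports forward-mode differentiation, but jax.grad raises an error when transposing the while loop. `euler_scan` uses lax.scan with a static step count and does support reverse mode.

```python
import jax
import jax.numpy as jnp
from jax import lax

def euler_while(a, y0, t_span, dt):
    # Step count comes from (possibly traced) t_span values: compilable over t_span.
    n_steps = jnp.ceil((t_span[1] - t_span[0]) / dt).astype(jnp.int32)

    def cond(state):
        i, _ = state
        return i < n_steps

    def body(state):
        i, y = state
        return i + 1, y + dt * a * y  # one forward-Euler step

    _, y_final = lax.while_loop(cond, body, (0, y0))
    return y_final

def euler_scan(a, y0, n_steps, dt):
    # Step count is a static Python int: reverse-mode differentiable.
    def step(y, _):
        return y + dt * a * y, None

    y_final, _ = lax.scan(step, y0, None, length=n_steps)
    return y_final

t_span = jnp.array([0.0, 1.0])

# Compiles with t_span as a traced argument; the loop length is decided at run time.
y_while = jax.jit(euler_while)(1.0, 1.0, t_span, 0.25)

# Forward-mode differentiation through lax.while_loop is supported...
dy_da_fwd = jax.jacfwd(euler_while)(1.0, 1.0, t_span, 0.25)

# ...but reverse mode is not: JAX raises a ValueError when transposing the loop.
grad_error = None
try:
    jax.grad(euler_while)(1.0, 1.0, t_span, 0.25)
except ValueError as err:
    grad_error = err

# With a static number of steps, lax.scan supports reverse mode.
dy_da_rev = jax.grad(euler_scan)(1.0, 1.0, 4, 0.25)

print(y_while, dy_da_fwd, dy_da_rev, grad_error is not None)
```

Supporting both compilation over t_span and reverse mode would need something beyond either function, such as the while-loop-plus-custom-vjp approach mentioned above.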

For the last point, I don't personally think it's worth spending the time on this right now: it will require a non-trivial rewrite of the fixed-step solvers template, and (I'm guessing) these solvers are unlikely to be used in the majority of pulse simulation use cases. So, I think a reasonable course of action is:

  • Modify the check in Solver.solve that triggers the auto-JAX compilation pathway so that it only happens for jax_odeint and diffrax solvers (thereby avoiding this issue for the others).
  • I'll also open another GitHub issue for updating the fixed-step solvers so they can be compiled over t_span and t_eval while preserving reverse-mode auto-differentiation. It can remain an open issue that someone solves at some point if it ever becomes necessary.
