Sampling breaks on JAX 0.2.0 #36
Raised an issue on the JAX repo: jax-ml/jax#4416
Just to follow up on this issue: the JAX team responded on the thread, saying in short that it's a deliberate change:
And then, as for what to do (emphasis added):
They've closed the issue (awaiting any further questions).
Thanks @lmmx! I haven't responded yet, as I didn't have time to look into it and didn't want to pollute their issue tracker. I'll reopen it if needed when I get to that. In the meantime I pinned JAX to the previous version so I can keep testing stuff.
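I think that the problem here stems from the fact that … If that is the issue, the solution is simple: we have to pass `kernel_factory` and `parameters` to the jitted function as static arguments:

```python
import functools
import jax

# kernel_factory (a Python callable) and parameters (e.g. a tuple of floats) are static:
# they must be hashable, and a new value triggers a recompilation.
@functools.partial(jax.jit, static_argnums=(1, 2))
def update_chain(rng_key, kernel_factory, parameters, chain_state):
    kernel = kernel_factory(*parameters)
    new_chain_state, info = kernel(rng_key, chain_state)
    return new_chain_state, info
```

And later, in the update loop:

```python
@functools.partial(jax.jit, static_argnums=(2, 3))
def update_loop(state, key, kernel_factory, parameters):
    keys = jax.random.split(key, num_chains)  # num_chains comes from the enclosing scope
    # Close over the static arguments so that vmap only maps over keys and chain states.
    step = lambda k, s: update_chain(k, kernel_factory, parameters, s)
    state, info = jax.vmap(step)(keys, state)
    return state, info, mcx_ravel_pytree((state, info))[0]
```

Which is slightly more verbose, but it would make sense to gather …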
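The static-argument approach proposed in this comment can be exercised end to end with a toy kernel. This is a minimal sketch under stated assumptions, not mcx's actual code: `random_walk_factory` is a hypothetical kernel factory, chain states are plain arrays stacked along axis 0, and `update_loop` is left un-jitted and derives the number of chains from the state's shape. It only shows that a Python callable and a hashable parameter tuple can be passed through `jax.jit` via `static_argnums` and combined with `jax.vmap`; it does not diagnose the JAX 0.2.0 error itself.

```python
import functools

import jax
import jax.numpy as jnp


def random_walk_factory(step_size):
    """Hypothetical kernel factory: builds a Gaussian random-walk kernel."""

    def kernel(rng_key, position):
        proposal = position + step_size * jax.random.normal(rng_key, position.shape)
        info = jnp.sum(jnp.abs(proposal - position))  # distance moved, as a toy "info"
        return proposal, info

    return kernel


@functools.partial(jax.jit, static_argnums=(1, 2))
def update_chain(rng_key, kernel_factory, parameters, chain_state):
    # kernel_factory (a function) and parameters (a tuple of floats) are static:
    # they are hashed into jit's cache key instead of being traced.
    kernel = kernel_factory(*parameters)
    return kernel(rng_key, chain_state)


def update_loop(state, key, kernel_factory, parameters):
    num_chains = state.shape[0]  # assumption: chains are stacked along axis 0
    keys = jax.random.split(key, num_chains)
    # Close over the static arguments so vmap only maps over keys and states.
    step = lambda k, s: update_chain(k, kernel_factory, parameters, s)
    return jax.vmap(step)(keys, state)


state = jnp.zeros((4, 2))  # 4 chains, each a 2-dimensional position
new_state, info = update_loop(state, jax.random.PRNGKey(0), random_walk_factory, (0.1,))
print(new_state.shape, info.shape)  # (4, 2) (4,)
```

Passing a different tuple such as `(0.0,)` is a different static value, so `update_chain` is retraced and recompiled for it; passing a JAX array as `parameters` would fail, because static arguments must be hashable.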
While everything runs fine on JAX v0.1.77, running sampling with JAX 0.2.0 raises the following error: