Skip to content

Commit

Permalink
Use jax.tree_map in place of jax.tree_multimap
Browse files Browse the repository at this point in the history
The latter is deprecated in jax and will be removed in jax-ml/jax#11382
  • Loading branch information
jakevdp committed Jul 20, 2022
1 parent 4a4e0a4 commit eedf442
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions blackjax/sgmcmc/diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,8 @@ def one_step(
) -> SGHMCState:
position, momentum, logprob_grad = state
noise = generate_gaussian_noise(rng_key, position)
position = jax.tree_util.tree_multimap(lambda x, p: x + p, position, momentum)
momentum = jax.tree_util.tree_multimap(
position = jax.tree_util.tree_map(lambda x, p: x + p, position, momentum)
momentum = jax.tree_util.tree_map(
lambda p, g, n: (1.0 - alpha) * p
+ step_size * g
+ jnp.sqrt(2 * step_size * (alpha - beta)) * n,
Expand Down

0 comments on commit eedf442

Please sign in to comment.