diff --git a/blackjax/sgmcmc/diffusion.py b/blackjax/sgmcmc/diffusion.py index 5495ad328..da04cce48 100644 --- a/blackjax/sgmcmc/diffusion.py +++ b/blackjax/sgmcmc/diffusion.py @@ -63,8 +63,8 @@ def one_step( ) -> SGHMCState: position, momentum, logprob_grad = state noise = generate_gaussian_noise(rng_key, position) - position = jax.tree_util.tree_multimap(lambda x, p: x + p, position, momentum) - momentum = jax.tree_util.tree_multimap( + position = jax.tree_util.tree_map(lambda x, p: x + p, position, momentum) + momentum = jax.tree_util.tree_map( lambda p, g, n: (1.0 - alpha) * p + step_size * g + jnp.sqrt(2 * step_size * (alpha - beta)) * n,