diff --git a/blackjax/mcmc/integrators.py b/blackjax/mcmc/integrators.py index 09946e9a3..e871b6211 100644 --- a/blackjax/mcmc/integrators.py +++ b/blackjax/mcmc/integrators.py @@ -15,11 +15,20 @@ from typing import Callable, NamedTuple import jax +import jax.numpy as jnp +from jax.flatten_util import ravel_pytree from blackjax.mcmc.metrics import EuclideanKineticEnergy from blackjax.types import ArrayTree -__all__ = ["mclachlan", "velocity_verlet", "yoshida"] +__all__ = [ + "mclachlan", + "velocity_verlet", + "yoshida", + "noneuclidean_leapfrog", + "noneuclidean_mclachlan", + "noneuclidean_yoshida", +] class IntegratorState(NamedTuple): @@ -36,213 +45,325 @@ class IntegratorState(NamedTuple): Integrator = Callable[[IntegratorState, float], IntegratorState] - - -def new_integrator_state(logdensity_fn, position, momentum): - logdensity, logdensity_grad = jax.value_and_grad(logdensity_fn)(position) - return IntegratorState(position, momentum, logdensity, logdensity_grad) - - -def velocity_verlet( - logdensity_fn: Callable, - kinetic_energy_fn: EuclideanKineticEnergy, -) -> Integrator: - """The velocity Verlet (or Verlet-Störmer) integrator. - - The velocity Verlet is a two-stage palindromic integrator :cite:p:`bou2018geometric` of the form - (a1, b1, a2, b1, a1) with a1 = 0. It is numerically stable for values of - the step size that range between 0 and 2 (when the mass matrix is the - identity). - - While the position (a1 = 0.5) and velocity Verlet are the most commonly used - in samplers, it is known in the numerical computation literature that the value - $a1 \approx 0.1932$ leads to a lower integration error :cite:p:`mclachlan1995numerical,schlick2010molecular`. The authors of :cite:p:`bou2018geometric` - show that the value $a1 \approx 0.21132$ leads to an even higher step acceptance - rate, up to 3 times higher than with the standard position verlet (p.22, Fig.4). - - By choosing the velocity verlet we avoid two computations of the gradient - of the kinetic energy. We are trading accuracy in exchange, and it is not - clear whether this is the right tradeoff. - +GeneralIntegrator = Callable[ + [IntegratorState, float], tuple[IntegratorState, ArrayTree] +] + + +def generalized_two_stage_integrator( + operator1: Callable, + operator2: Callable, + coefficients: list[float], + format_output_fn: Callable = lambda x: x, +): + """Generalized numerical integrator for solving ODEs. + + The generalized integrator performs numerical integration of a ODE system by + alernating between stage 1 and stage 2 updates. + The update scheme is decided by the coefficients, The scheme should be palindromic, + i.e. the coefficients of the update scheme should be symmetric with respect to the + middle of the scheme. + + For instance, for *any* differential equation of the form: + + .. math:: \\frac{d}{dt}f = (O_1+O_2)f + + The leapfrog operator can be seen as approximating :math:`e^{\\epsilon(O_1 + O_2)}` + by :math:`e^{\\epsilon O_1/2}e^{\\epsilon O_2}e^{\\epsilon O_1/2}`. + + In a standard Hamiltonian, the forms of :math:`e^{\\epsilon O_2}` and + :math:`e^{\\epsilon O_1}` are simple, but for other differential equations, + they may be more complex. + + Parameters + ---------- + operator1 + Stage 1 operator, a function that updates the momentum. + operator2 + Stage 2 operator, a function that updates the position. + coefficients + Coefficients of the integrator. + format_output_fn + Function that formats the output of the integrator. + + Returns + ------- + integrator + Integrator function. """ - a1 = 0 - b1 = 0.5 - a2 = 1 - 2 * a1 - logdensity_and_grad_fn = jax.value_and_grad(logdensity_fn) - kinetic_energy_grad_fn = jax.grad(kinetic_energy_fn) - - def one_step(state: IntegratorState, step_size: float) -> IntegratorState: + def one_step(state: IntegratorState, step_size: float): position, momentum, _, logdensity_grad = state - - momentum = jax.tree_util.tree_map( - lambda momentum, logdensity_grad: momentum - + b1 * step_size * logdensity_grad, + # auxiliary infomation generated during integration for diagnostics. It is + # updated by the operator1 and operator2 at each call. + momentum_update_info = None + position_update_info = None + for i, coef in enumerate(coefficients[:-1]): + if i % 2 == 0: + momentum, kinetic_grad, momentum_update_info = operator1( + momentum, + logdensity_grad, + step_size, + coef, + momentum_update_info, + is_last_call=False, + ) + else: + ( + position, + logdensity, + logdensity_grad, + position_update_info, + ) = operator2( + position, + kinetic_grad, + step_size, + coef, + position_update_info, + ) + # Separate the last steps to short circuit the computation of the kinetic_grad. + momentum, kinetic_grad, momentum_update_info = operator1( momentum, logdensity_grad, + step_size, + coefficients[-1], + momentum_update_info, + is_last_call=True, ) - - kinetic_grad = kinetic_energy_grad_fn(momentum) - position = jax.tree_util.tree_map( - lambda position, kinetic_grad: position + a2 * step_size * kinetic_grad, + return format_output_fn( position, - kinetic_grad, - ) - - logdensity, logdensity_grad = logdensity_and_grad_fn(position) - momentum = jax.tree_util.tree_map( - lambda momentum, logdensity_grad: momentum - + b1 * step_size * logdensity_grad, momentum, + logdensity, logdensity_grad, + kinetic_grad, + position_update_info, + momentum_update_info, ) - return IntegratorState(position, momentum, logdensity, logdensity_grad) - return one_step -def mclachlan( - logdensity_fn: Callable, - kinetic_energy_fn: Callable, -) -> Integrator: - """Two-stage palindromic symplectic integrator derived in :cite:p:`blanes2014numerical`. - - The integrator is of the form (b1, a1, b2, a1, b1). The choice of the parameters - determine both the bound on the integration error and the stability of the - method with respect to the value of `step_size`. The values used here are - the ones derived in :cite:p:`mclachlan1995numerical`; note that :cite:p:`blanes2014numerical` is more focused on stability - and derives different values. +def new_integrator_state(logdensity_fn, position, momentum): + logdensity, logdensity_grad = jax.value_and_grad(logdensity_fn)(position) + return IntegratorState(position, momentum, logdensity, logdensity_grad) - """ - b1 = 0.1932 - a1 = 0.5 - b2 = 1 - 2 * b1 +def euclidean_position_update_fn(logdensity_fn: Callable): logdensity_and_grad_fn = jax.value_and_grad(logdensity_fn) - kinetic_energy_grad_fn = jax.grad(kinetic_energy_fn) - - def one_step(state: IntegratorState, step_size: float) -> IntegratorState: - position, momentum, _, logdensity_grad = state - - momentum = jax.tree_util.tree_map( - lambda momentum, logdensity_grad: momentum - + b1 * step_size * logdensity_grad, - momentum, - logdensity_grad, - ) - - kinetic_grad = kinetic_energy_grad_fn(momentum) - position = jax.tree_util.tree_map( - lambda position, kinetic_grad: position + a1 * step_size * kinetic_grad, - position, - kinetic_grad, - ) - _, logdensity_grad = logdensity_and_grad_fn(position) - momentum = jax.tree_util.tree_map( - lambda momentum, logdensity_grad: momentum - + b2 * step_size * logdensity_grad, - momentum, - logdensity_grad, - ) - - kinetic_grad = kinetic_energy_grad_fn(momentum) - position = jax.tree_util.tree_map( - lambda position, kinetic_grad: position + a1 * step_size * kinetic_grad, + def update( + position: ArrayTree, + kinetic_grad: ArrayTree, + step_size: float, + coef: float, + auxiliary_info=None, + ): + del auxiliary_info + new_position = jax.tree_util.tree_map( + lambda x, grad: x + step_size * coef * grad, position, kinetic_grad, ) + logdensity, logdensity_grad = logdensity_and_grad_fn(new_position) + return new_position, logdensity, logdensity_grad, None - logdensity, logdensity_grad = logdensity_and_grad_fn(position) - momentum = jax.tree_util.tree_map( - lambda momentum, logdensity_grad: momentum - + b1 * step_size * logdensity_grad, - momentum, - logdensity_grad, - ) - - return IntegratorState(position, momentum, logdensity, logdensity_grad) - - return one_step + return update -def yoshida( - logdensity_fn: Callable, - kinetic_energy_fn: Callable, -) -> Integrator: - """Three stages palindromic symplectic integrator derived in :cite:p:`mclachlan1995numerical` - - The integrator is of the form (b1, a1, b2, a2, b2, a1, b1). The choice of - the parameters determine both the bound on the integration error and the - stability of the method with respect to the value of `step_size`. The - values used here are the ones derived in :cite:p:`mclachlan1995numerical` which guarantees a stability - interval length approximately equal to 4.67. - - """ - b1 = 0.11888010966548 - a1 = 0.29619504261126 - b2 = 0.5 - b1 - a2 = 1 - 2 * a1 - - logdensity_and_grad_fn = jax.value_and_grad(logdensity_fn) +def euclidean_momentum_update_fn(kinetic_energy_fn: EuclideanKineticEnergy): kinetic_energy_grad_fn = jax.grad(kinetic_energy_fn) - def one_step(state: IntegratorState, step_size: float) -> IntegratorState: - position, momentum, _, logdensity_grad = state - - momentum = jax.tree_util.tree_map( - lambda momentum, logdensity_grad: momentum - + b1 * step_size * logdensity_grad, + def update( + momentum: ArrayTree, + logdensity_grad: ArrayTree, + step_size: float, + coef: float, + auxiliary_info=None, + is_last_call=False, + ): + del auxiliary_info + new_momentum = jax.tree_util.tree_map( + lambda x, grad: x + step_size * coef * grad, momentum, logdensity_grad, ) + if is_last_call: + return new_momentum, None, None + kinetic_grad = kinetic_energy_grad_fn(new_momentum) + return new_momentum, kinetic_grad, None + + return update + + +def format_euclidean_state_output( + position, + momentum, + logdensity, + logdensity_grad, + kinetic_grad, + position_update_info, + momentum_update_info, +): + del kinetic_grad, position_update_info, momentum_update_info + return IntegratorState(position, momentum, logdensity, logdensity_grad) - kinetic_grad = kinetic_energy_grad_fn(momentum) - position = jax.tree_util.tree_map( - lambda position, kinetic_grad: position + a1 * step_size * kinetic_grad, - position, - kinetic_grad, - ) - _, logdensity_grad = logdensity_and_grad_fn(position) - momentum = jax.tree_util.tree_map( - lambda momentum, logdensity_grad: momentum - + b2 * step_size * logdensity_grad, - momentum, - logdensity_grad, - ) +def generate_euclidean_integrator(cofficients): + """Generate symplectic integrator for solving a Hamiltonian system. - kinetic_grad = kinetic_energy_grad_fn(momentum) - position = jax.tree_util.tree_map( - lambda position, kinetic_grad: position + a2 * step_size * kinetic_grad, - position, - kinetic_grad, - ) + The resulting integrator is volume-preserve and preserves the symplectic structure + of phase space. + """ - _, logdensity_grad = logdensity_and_grad_fn(position) - momentum = jax.tree_util.tree_map( - lambda momentum, logdensity_grad: momentum - + b2 * step_size * logdensity_grad, - momentum, - logdensity_grad, + def euclidean_integrator( + logdensity_fn: Callable, kinetic_energy_fn: EuclideanKineticEnergy + ) -> Integrator: + position_update_fn = euclidean_position_update_fn(logdensity_fn) + momentum_update_fn = euclidean_momentum_update_fn(kinetic_energy_fn) + one_step = generalized_two_stage_integrator( + momentum_update_fn, + position_update_fn, + cofficients, + format_output_fn=format_euclidean_state_output, ) + return one_step + + return euclidean_integrator + + +""" +The velocity Verlet (or Verlet-Störmer) integrator. + +The velocity Verlet is a two-stage palindromic integrator :cite:p:`bou2018geometric` +of the form (a1, b1, a2, b1, a1) with a1 = 0. It is numerically stable for values of +the step size that range between 0 and 2 (when the mass matrix is the identity). + +While the position (a1 = 0.5) and velocity Verlet are the most commonly used +in samplers, it is known in the numerical computation literature that the value +$a1 \approx 0.1932$ leads to a lower integration error :cite:p:`mclachlan1995numerical,schlick2010molecular`. +The authors of :cite:p:`bou2018geometric` show that the value $a1 \approx 0.21132$ +leads to an even higher step acceptance rate, up to 3 times higher +than with the standard position verlet (p.22, Fig.4). + +By choosing the velocity verlet we avoid two computations of the gradient +of the kinetic energy. We are trading accuracy in exchange, and it is not +clear whether this is the right tradeoff. +""" +velocity_verlet_cofficients = [0.5, 1.0, 0.5] +velocity_verlet = generate_euclidean_integrator(velocity_verlet_cofficients) + +""" +Two-stage palindromic symplectic integrator derived in :cite:p:`blanes2014numerical`. + +The integrator is of the form (b1, a1, b2, a1, b1). The choice of the parameters +determine both the bound on the integration error and the stability of the +method with respect to the value of `step_size`. The values used here are +the ones derived in :cite:p:`mclachlan1995numerical`; note that :cite:p:`blanes2014numerical` +is more focused on stability and derives different values. + +Also known as the minimal norm integrator. +""" +b1 = 0.1931833275037836 +a1 = 0.5 +b2 = 1 - 2 * b1 +mclachlan_cofficients = [b1, a1, b2, a1, b1] +mclachlan = generate_euclidean_integrator(mclachlan_cofficients) + +""" +Three stages palindromic symplectic integrator derived in :cite:p:`mclachlan1995numerical` + +The integrator is of the form (b1, a1, b2, a2, b2, a1, b1). The choice of +the parameters determine both the bound on the integration error and the +stability of the method with respect to the value of `step_size`. The +values used here are the ones derived in :cite:p:`mclachlan1995numerical` which +guarantees a stability interval length approximately equal to 4.67. +""" +b1 = 0.11888010966548 +a1 = 0.29619504261126 +b2 = 0.5 - b1 +a2 = 1 - 2 * a1 +yoshida_cofficients = [b1, a1, b2, a2, b2, a1, b1] +yoshida = generate_euclidean_integrator(yoshida_cofficients) + + +# Intergrators with non Euclidean updates +def _normalized_flatten_array(x, tol=1e-13): + norm = jnp.linalg.norm(x) + return jnp.where(norm > tol, x / norm, x), norm + + +def esh_dynamics_momentum_update_one_step( + momentum: ArrayTree, + logdensity_grad: ArrayTree, + step_size: float, + coef: float, + previous_kinetic_energy_change=None, + is_last_call=False, +): + """Momentum update based on Esh dynamics. + + The momentum updating map of the esh dynamics as derived in :cite:p:`steeg2021hamiltonian` + There are no exponentials e^delta, which prevents overflows when the gradient norm + is large. + """ - kinetic_grad = kinetic_energy_grad_fn(momentum) - position = jax.tree_util.tree_map( - lambda position, kinetic_grad: position + a1 * step_size * kinetic_grad, - position, - kinetic_grad, + flatten_grads, unravel_fn = ravel_pytree(logdensity_grad) + flatten_momentum, _ = ravel_pytree(momentum) + dims = flatten_momentum.shape[0] + normalized_gradient, gradient_norm = _normalized_flatten_array(flatten_grads) + momentum_proj = jnp.dot(flatten_momentum, normalized_gradient) + delta = step_size * coef * gradient_norm / (dims - 1) + zeta = jnp.exp(-delta) + new_momentum_raw = ( + normalized_gradient * (1 - zeta) * (1 + zeta + momentum_proj * (1 - zeta)) + + 2 * zeta * flatten_momentum + ) + new_momentum_normalized, _ = _normalized_flatten_array(new_momentum_raw) + next_momentum = unravel_fn(new_momentum_normalized) + kinetic_energy_change = ( + delta + - jnp.log(2) + + jnp.log(1 + momentum_proj + (1 - momentum_proj) * zeta**2) + ) + if previous_kinetic_energy_change is not None: + kinetic_energy_change += previous_kinetic_energy_change + if is_last_call: + kinetic_energy_change *= dims - 1 + return next_momentum, next_momentum, kinetic_energy_change + + +def format_noneuclidean_state_output( + position, + momentum, + logdensity, + logdensity_grad, + kinetic_grad, + position_update_info, + momentum_update_info, +): + del kinetic_grad, position_update_info + return ( + IntegratorState(position, momentum, logdensity, logdensity_grad), + momentum_update_info, + ) + + +def generate_noneuclidean_integrator(cofficients): + def noneuclidean_integrator( + logdensity_fn: Callable, *args, **kwargs + ) -> GeneralIntegrator: + position_update_fn = euclidean_position_update_fn(logdensity_fn) + one_step = generalized_two_stage_integrator( + esh_dynamics_momentum_update_one_step, + position_update_fn, + cofficients, + format_output_fn=format_noneuclidean_state_output, ) + return one_step - logdensity, logdensity_grad = logdensity_and_grad_fn(position) - momentum = jax.tree_util.tree_map( - lambda momentum, logdensity_grad: momentum - + b1 * step_size * logdensity_grad, - momentum, - logdensity_grad, - ) + return noneuclidean_integrator - return IntegratorState(position, momentum, logdensity, logdensity_grad) - return one_step +noneuclidean_leapfrog = generate_noneuclidean_integrator(velocity_verlet_cofficients) +noneuclidean_mclachlan = generate_noneuclidean_integrator(mclachlan_cofficients) +noneuclidean_yoshida = generate_noneuclidean_integrator(yoshida_cofficients) diff --git a/blackjax/util.py b/blackjax/util.py index 1a7ebcd09..a3a7226a6 100644 --- a/blackjax/util.py +++ b/blackjax/util.py @@ -82,6 +82,28 @@ def generate_gaussian_noise( return unravel_fn(mu + linear_map(sigma, sample)) +def generate_unit_vector( + rng_key: PRNGKey, + position: ArrayLikeTree, +) -> Array: + """Generate a random unit vector with output structure that match a given PyTree. + + Parameters + ---------- + rng_key: + The pseudo-random number generator key used to generate random numbers. + position: + PyTree that the structure the output should to match. + + Returns + ------- + Random unit vector that match the structure of position. + """ + p, unravel_fn = ravel_pytree(position) + sample = normal(rng_key, shape=p.shape, dtype=p.dtype) + return unravel_fn(sample / jnp.linalg.norm(sample)) + + def pytree_size(pytree: ArrayLikeTree) -> int: """Return the dimension of the flatten PyTree.""" return sum(jnp.size(value) for value in tree_leaves(pytree)) diff --git a/docs/refs.bib b/docs/refs.bib index f5015ccb9..1b6485809 100644 --- a/docs/refs.bib +++ b/docs/refs.bib @@ -360,3 +360,21 @@ @inproceedings{hoffman2021adaptive year={2021}, organization={PMLR} } + +@misc{steeg2021hamiltonian, + title={Hamiltonian Dynamics with Non-Newtonian Momentum for Rapid Sampling}, + author={Greg Ver Steeg and Aram Galstyan}, + year={2021}, + eprint={2111.02434}, + archivePrefix={arXiv}, + primaryClass={cs.LG} +} + +@misc{robnik2023microcanonical, + title={Microcanonical Hamiltonian Monte Carlo}, + author={Jakob Robnik and G. Bruno De Luca and Eva Silverstein and Uroš Seljak}, + year={2023}, + eprint={2212.08549}, + archivePrefix={arXiv}, + primaryClass={stat.CO} +} diff --git a/tests/mcmc/test_integrators.py b/tests/mcmc/test_integrators.py index 68f1dbd88..2f5020d00 100644 --- a/tests/mcmc/test_integrators.py +++ b/tests/mcmc/test_integrators.py @@ -3,9 +3,14 @@ import chex import jax import jax.numpy as jnp +import jax.scipy.stats as stats +import numpy as np from absl.testing import absltest, parameterized +from jax.flatten_util import ravel_pytree import blackjax.mcmc.integrators as integrators +from blackjax.mcmc.integrators import esh_dynamics_momentum_update_one_step +from blackjax.util import generate_unit_vector def HarmonicOscillator(inv_mass_matrix, k=1.0, m=1.0): @@ -47,13 +52,38 @@ def kinetic_energy(p): return neg_potential_energy, kinetic_energy -algorithms = { - "velocity_verlet": {"algorithm": integrators.velocity_verlet, "precision": 1e-4}, - "mclachlan": {"algorithm": integrators.mclachlan, "precision": 1e-5}, - "yoshida": {"algorithm": integrators.yoshida, "precision": 1e-6}, -} +def MultivariateNormal(inv_mass_matrix): + """Potential and kinetic energy for a multivariate normal distribution.""" + + def log_density(q): + q, _ = ravel_pytree(q) + return stats.multivariate_normal.logpdf(q, jnp.zeros_like(q), inv_mass_matrix) + + def kinetic_energy(p): + p, _ = ravel_pytree(p) + return 0.5 * p.T @ inv_mass_matrix @ p + + return log_density, kinetic_energy +mvnormal_position_init = { + "a": 0.0, + "b": jnp.asarray([1.0, 2.0, 3.0]), + "c": jnp.ones((2, 1)), +} +_, unravel_fn = ravel_pytree(mvnormal_position_init) +key0, key1 = jax.random.split(jax.random.key(52)) +mvnormal_momentum_init = unravel_fn(jax.random.normal(key0, (6,))) +a = jax.random.normal(key1, (6, 6)) +cov = jnp.matmul(a.T, a) +# Validated numerically +mvnormal_position_end = unravel_fn( + jnp.asarray([0.38887993, 0.85231394, 2.7879136, 3.0339851, 0.5856687, 1.9291426]) +) +mvnormal_momentum_end = unravel_fn( + jnp.asarray([0.46576163, 0.23854092, 1.2518811, -0.35647452, -0.742138, 1.2552949]) +) + examples = { "free_fall": { "model": FreeFall, @@ -85,6 +115,25 @@ def kinetic_energy(p): "p_final": {"x": 0.0, "y": 1.0}, "inv_mass_matrix": jnp.array([1.0, 1.0]), }, + "multivariate_normal": { + "model": MultivariateNormal, + "num_steps": 16, + "step_size": 0.005, + "q_init": mvnormal_position_init, + "p_init": mvnormal_momentum_init, + "q_final": mvnormal_position_end, + "p_final": mvnormal_momentum_end, + "inv_mass_matrix": cov, + }, +} + +algorithms = { + "velocity_verlet": {"algorithm": integrators.velocity_verlet, "precision": 1e-4}, + "mclachlan": {"algorithm": integrators.mclachlan, "precision": 1e-5}, + "yoshida": {"algorithm": integrators.yoshida, "precision": 1e-6}, + "noneuclidean_leapfrog": {"algorithm": integrators.noneuclidean_leapfrog}, + "noneuclidean_mclachlan": {"algorithm": integrators.noneuclidean_mclachlan}, + "noneuclidean_yoshida": {"algorithm": integrators.noneuclidean_yoshida}, } @@ -100,11 +149,20 @@ class IntegratorTest(chex.TestCase): @chex.all_variants(with_pmap=False) @parameterized.parameters( itertools.product( - ["free_fall", "harmonic_oscillator", "planetary_motion"], - ["velocity_verlet", "mclachlan", "yoshida"], + [ + "free_fall", + "harmonic_oscillator", + "planetary_motion", + "multivariate_normal", + ], + [ + "velocity_verlet", + "mclachlan", + "yoshida", + ], ) ) - def test_integrator(self, example_name, integrator_name): + def test_euclidean_integrator(self, example_name, integrator_name): integrator = algorithms[integrator_name] example = examples[example_name] @@ -120,6 +178,7 @@ def test_integrator(self, example_name, integrator_name): initial_state = integrators.IntegratorState( q, p, neg_potential(q), jax.grad(neg_potential)(q) ) + final_state = jax.lax.fori_loop( 0, example["num_steps"], @@ -137,6 +196,76 @@ def test_integrator(self, example_name, integrator_name): ) self.assertAlmostEqual(energy, new_energy, delta=integrator["precision"]) + @chex.all_variants(with_pmap=False) + @parameterized.parameters([3, 5]) + def test_esh_momentum_update(self, dims): + """ + Test the numerically efficient version of the momentum update currently + implemented match the naive implementation according to the Equation 16 in + :cite:p:`robnik2023microcanonical` + """ + step_size = 1e-3 + key0, key1 = jax.random.split(jax.random.key(62)) + gradient = jax.random.uniform(key0, shape=(dims,)) + momentum = jax.random.uniform(key1, shape=(dims,)) + momentum /= jnp.linalg.norm(momentum) + + # Navie implementation + gradient_norm = jnp.linalg.norm(gradient) + gradient_normalized = gradient / gradient_norm + delta = step_size * gradient_norm / (dims - 1) + next_momentum = ( + momentum + + gradient_normalized + * ( + jnp.sinh(delta) + + jnp.dot(gradient_normalized, momentum * (jnp.cosh(delta) - 1)) + ) + ) / (jnp.cosh(delta) + jnp.dot(gradient_normalized, momentum * jnp.sinh(delta))) + + # Efficient implementation + update_stable = self.variant(esh_dynamics_momentum_update_one_step) + next_momentum1, *_ = update_stable(momentum, gradient, step_size, 1.0) + np.testing.assert_array_almost_equal(next_momentum, next_momentum1) + + @chex.all_variants(with_pmap=False) + @parameterized.parameters( + [ + "noneuclidean_leapfrog", + "noneuclidean_mclachlan", + "noneuclidean_yoshida", + ], + ) + def test_noneuclidean_integrator(self, integrator_name): + integrator = algorithms[integrator_name] + cov = jnp.asarray([[1.0, 0.5], [0.5, 2.0]]) + logdensity_fn = lambda x: stats.multivariate_normal.logpdf( + x, jnp.zeros([2]), cov + ) + + step = self.variant(integrator["algorithm"](logdensity_fn)) + + rng = jax.random.key(4263456) + key0, key1 = jax.random.split(rng, 2) + position_init = jax.random.normal(key0, (2,)) + momentum_init = generate_unit_vector(key1, position_init) + step_size = 0.0001 + initial_state = integrators.new_integrator_state( + logdensity_fn, position_init, momentum_init + ) + + final_state, kinetic_energy_change = jax.lax.scan( + lambda state, _: step(state, step_size), + initial_state, + xs=None, + length=15, + ) + + # Check the conservation of energy. + potential_energy_change = final_state.logdensity - initial_state.logdensity + energy_change = kinetic_energy_change[-1] + potential_energy_change + self.assertAlmostEqual(energy_change, 0, delta=1e-3) + if __name__ == "__main__": absltest.main()