diff --git a/blackjax/mcmc/ghmc.py b/blackjax/mcmc/ghmc.py index a04ce0641..5f8ab89a7 100644 --- a/blackjax/mcmc/ghmc.py +++ b/blackjax/mcmc/ghmc.py @@ -16,6 +16,7 @@ import jax import jax.numpy as jnp +from jax.flatten_util import ravel_pytree import blackjax.mcmc.hmc as hmc import blackjax.mcmc.integrators as integrators @@ -129,8 +130,8 @@ def kernel( """ - flat_inverse_scale = jax.flatten_util.ravel_pytree(momentum_inverse_scale)[0] - momentum_generator, kinetic_energy_fn, _ = metrics.gaussian_euclidean( + flat_inverse_scale = ravel_pytree(momentum_inverse_scale)[0] + momentum_generator, kinetic_energy_fn, *_ = metrics.gaussian_euclidean( flat_inverse_scale**2 ) @@ -248,6 +249,10 @@ def as_top_level_api( A PyTree of the same structure as the target PyTree (position) with the values used for as a step size for each dimension of the target space in the velocity verlet integrator. + momentum_inverse_scale + Pytree with the same structure as the targeted position variable + specifying the per dimension inverse scaling transformation applied + to the persistent momentum variable prior to the integration step. alpha The value defining the persistence of the momentum variable. delta diff --git a/blackjax/mcmc/metrics.py b/blackjax/mcmc/metrics.py index 1368a8441..4e079714b 100644 --- a/blackjax/mcmc/metrics.py +++ b/blackjax/mcmc/metrics.py @@ -30,13 +30,13 @@ """ from typing import Callable, NamedTuple, Optional, Protocol, Union +import jax import jax.numpy as jnp import jax.scipy as jscipy from jax.flatten_util import ravel_pytree -from jax.scipy import stats as sp_stats -from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey -from blackjax.util import generate_gaussian_noise +from blackjax.types import Array, ArrayLikeTree, ArrayTree, Numeric, PRNGKey +from blackjax.util import generate_gaussian_noise, linear_map __all__ = ["default_metric", "gaussian_euclidean", "gaussian_riemannian"] @@ -44,7 +44,7 @@ class KineticEnergy(Protocol): def __call__( self, momentum: ArrayLikeTree, position: Optional[ArrayLikeTree] = None - ) -> float: + ) -> Numeric: ... @@ -60,10 +60,18 @@ def __call__( ... +class Scale(Protocol): + def __call__( + self, position: ArrayLikeTree, element: ArrayLikeTree, inv: ArrayLikeTree + ) -> ArrayLikeTree: + ... + + class Metric(NamedTuple): sample_momentum: Callable[[PRNGKey, ArrayLikeTree], ArrayLikeTree] kinetic_energy: KineticEnergy check_turning: CheckTurning + scale: Scale MetricTypes = Union[Metric, Array, Callable[[ArrayLikeTree], Array]] @@ -128,46 +136,19 @@ def gaussian_euclidean( itself given the values of the momentum along the trajectory. """ - ndim = jnp.ndim(inverse_mass_matrix) # type: ignore[arg-type] - shape = jnp.shape(inverse_mass_matrix)[:1] # type: ignore[arg-type] - - if ndim == 1: # diagonal mass matrix - mass_matrix_sqrt = jnp.sqrt(jnp.reciprocal(inverse_mass_matrix)) - matmul = jnp.multiply - - elif ndim == 2: - # inverse mass matrix can be factored into L*L.T. We want the cholesky - # factor (inverse of L.T) of the mass matrix. - L = jscipy.linalg.cholesky(inverse_mass_matrix, lower=True) - identity = jnp.identity(shape[0]) - mass_matrix_sqrt = jscipy.linalg.solve_triangular( - L, identity, lower=True, trans=True - ) - # Note that mass_matrix_sqrt is a upper triangular matrix here, with - # jscipy.linalg.inv(mass_matrix_sqrt @ mass_matrix_sqrt.T) - # == inverse_mass_matrix - # An alternative is to compute directly the cholesky factor of the inverse mass - # matrix - # mass_matrix_sqrt = jscipy.linalg.cholesky( - # jscipy.linalg.inv(inverse_mass_matrix), lower=True) - # which the result would instead be a lower triangular matrix. - matmul = jnp.matmul - - else: - raise ValueError( - "The mass matrix has the wrong number of dimensions:" - f" expected 1 or 2, got {ndim}." - ) + mass_matrix_sqrt, inv_mass_matrix_sqrt, diag = _format_covariance( + inverse_mass_matrix, is_inv=True + ) def momentum_generator(rng_key: PRNGKey, position: ArrayLikeTree) -> ArrayTree: return generate_gaussian_noise(rng_key, position, sigma=mass_matrix_sqrt) def kinetic_energy( momentum: ArrayLikeTree, position: Optional[ArrayLikeTree] = None - ) -> float: + ) -> Numeric: del position momentum, _ = ravel_pytree(momentum) - velocity = matmul(inverse_mass_matrix, momentum) + velocity = linear_map(inverse_mass_matrix, momentum) kinetic_energy_val = 0.5 * jnp.dot(velocity, momentum) return kinetic_energy_val @@ -196,8 +177,8 @@ def is_turning( m_right, _ = ravel_pytree(momentum_right) m_sum, _ = ravel_pytree(momentum_sum) - velocity_left = matmul(inverse_mass_matrix, m_left) - velocity_right = matmul(inverse_mass_matrix, m_right) + velocity_left = linear_map(inverse_mass_matrix, m_left) + velocity_right = linear_map(inverse_mass_matrix, m_right) # rho = m_sum rho = m_sum - (m_right + m_left) / 2 @@ -205,7 +186,37 @@ def is_turning( turning_at_right = jnp.dot(velocity_right, rho) <= 0 return turning_at_left | turning_at_right - return Metric(momentum_generator, kinetic_energy, is_turning) + def scale( + position: ArrayLikeTree, element: ArrayLikeTree, inv: ArrayLikeTree + ) -> ArrayLikeTree: + """Scale elements by the mass matrix. + + Parameters + ---------- + position + The current position. Not used in this metric. + elements + Elements to scale + invs + Whether to scale the elements by the inverse mass matrix or the mass matrix. + If True, the element is scaled by the inverse square root mass matrix, i.e., elem <- (M^{1/2})^{-1} elem. + Same pytree structure as `elements`. + + Returns + ------- + scaled_elements + The scaled elements. + """ + + ravelled_element, unravel_fn = ravel_pytree(element) + scaled = jax.lax.cond( + inv, + lambda: linear_map(inv_mass_matrix_sqrt, ravelled_element), + lambda: linear_map(mass_matrix_sqrt, ravelled_element), + ) + return unravel_fn(scaled) + + return Metric(momentum_generator, kinetic_energy, is_turning, scale) def gaussian_riemannian( @@ -213,22 +224,13 @@ def gaussian_riemannian( ) -> Metric: def momentum_generator(rng_key: PRNGKey, position: ArrayLikeTree) -> ArrayLikeTree: mass_matrix = mass_matrix_fn(position) - ndim = jnp.ndim(mass_matrix) - if ndim == 1: - mass_matrix_sqrt = jnp.sqrt(mass_matrix) - elif ndim == 2: - mass_matrix_sqrt = jscipy.linalg.cholesky(mass_matrix, lower=True) - else: - raise ValueError( - "The mass matrix has the wrong number of dimensions:" - f" expected 1 or 2, got {jnp.ndim(mass_matrix)}." - ) + mass_matrix_sqrt, *_ = _format_covariance(mass_matrix, is_inv=False) return generate_gaussian_noise(rng_key, position, sigma=mass_matrix_sqrt) def kinetic_energy( momentum: ArrayLikeTree, position: Optional[ArrayLikeTree] = None - ) -> float: + ) -> Numeric: if position is None: raise ValueError( "A Reinmannian kinetic energy function must be called with the " @@ -238,18 +240,11 @@ def kinetic_energy( momentum, _ = ravel_pytree(momentum) mass_matrix = mass_matrix_fn(position) - ndim = jnp.ndim(mass_matrix) - if ndim == 1: - return -jnp.sum(sp_stats.norm.logpdf(momentum, 0.0, jnp.sqrt(mass_matrix))) - elif ndim == 2: - return -sp_stats.multivariate_normal.logpdf( - momentum, jnp.zeros_like(momentum), mass_matrix - ) - else: - raise ValueError( - "The mass matrix has the wrong number of dimensions:" - f" expected 1 or 2, got {jnp.ndim(mass_matrix)}." - ) + sqrt_mass_matrix, inv_sqrt_mass_matrix, diag = _format_covariance( + mass_matrix, is_inv=False + ) + + return _energy(momentum, 0, sqrt_mass_matrix, inv_sqrt_mass_matrix.T, diag) def is_turning( momentum_left: ArrayLikeTree, @@ -283,4 +278,69 @@ def is_turning( # turning_at_right = jnp.dot(velocity_right, rho) <= 0 # return turning_at_left | turning_at_right - return Metric(momentum_generator, kinetic_energy, is_turning) + def scale( + position: ArrayLikeTree, element: ArrayLikeTree, inv: ArrayLikeTree + ) -> ArrayLikeTree: + """Scale elements by the mass matrix. + + Parameters + ---------- + position + The current position. + + Returns + ------- + scaled_elements + The scaled elements. + """ + mass_matrix = mass_matrix_fn(position) + mass_matrix_sqrt, inv_mass_matrix_sqrt, diag = _format_covariance( + mass_matrix, is_inv=False + ) + ravelled_element, unravel_fn = ravel_pytree(element) + scaled = jax.lax.cond( + inv, + lambda: linear_map(inv_mass_matrix_sqrt, ravelled_element), + lambda: linear_map(mass_matrix_sqrt, ravelled_element), + ) + return unravel_fn(scaled) + + return Metric(momentum_generator, kinetic_energy, is_turning, scale) + + +def _format_covariance(cov: Array, is_inv): + ndim = jnp.ndim(cov) + if ndim == 1: + cov_sqrt = jnp.sqrt(cov) + inv_cov_sqrt = 1 / cov_sqrt + diag = lambda x: x + if is_inv: + inv_cov_sqrt, cov_sqrt = cov_sqrt, inv_cov_sqrt + elif ndim == 2: + identity = jnp.identity(cov.shape[0]) + if is_inv: + inv_cov_sqrt = jscipy.linalg.cholesky(cov, lower=True) + cov_sqrt = jscipy.linalg.solve_triangular( + inv_cov_sqrt, identity, lower=True, trans=True + ) + else: + cov_sqrt = jscipy.linalg.cholesky(cov, lower=False).T + inv_cov_sqrt = jscipy.linalg.solve_triangular( + cov_sqrt, identity, lower=True, trans=True + ) + + diag = lambda x: jnp.diag(x) + + else: + raise ValueError( + "The mass matrix has the wrong number of dimensions:" + f" expected 1 or 2, got {jnp.ndim(cov)}." + ) + return cov_sqrt, inv_cov_sqrt, diag + + +def _energy(x, mean, cov_sqrt, inv_cov_sqrt, diag): + d = x.shape[0] + z = linear_map(inv_cov_sqrt, x - mean) + const = jnp.sum(jnp.log(diag(cov_sqrt))) + d / 2 * jnp.log(2 * jnp.pi) + return 0.5 * jnp.sum(z**2) + const diff --git a/blackjax/mcmc/periodic_orbital.py b/blackjax/mcmc/periodic_orbital.py index 61625a0b8..b996205e8 100644 --- a/blackjax/mcmc/periodic_orbital.py +++ b/blackjax/mcmc/periodic_orbital.py @@ -172,7 +172,7 @@ def kernel( """ - momentum_generator, kinetic_energy_fn, _ = metrics.gaussian_euclidean( + momentum_generator, kinetic_energy_fn, *_ = metrics.gaussian_euclidean( inverse_mass_matrix ) bijection_fn = bijection(logdensity_fn, kinetic_energy_fn) diff --git a/blackjax/types.py b/blackjax/types.py index 5a3b59f07..4b23fcd22 100644 --- a/blackjax/types.py +++ b/blackjax/types.py @@ -43,3 +43,7 @@ class WelfordAlgorithmState(NamedTuple): #: JAX PRNGKey PRNGKey = jax.Array + +#: JAX Scalar types +Scalar = Union[float, int] +Numeric = Union[jax.Array, Scalar] diff --git a/tests/mcmc/test_metrics.py b/tests/mcmc/test_metrics.py index f806a375c..0791f3cb1 100644 --- a/tests/mcmc/test_metrics.py +++ b/tests/mcmc/test_metrics.py @@ -8,6 +8,90 @@ from blackjax.mcmc import metrics +class CovarianceFormattingTest(chex.TestCase): + def setUp(self): + super().setUp() + self.key = random.key(0) + self.dtype = "float32" + + @parameterized.named_parameters( + {"testcase_name": "0d", "shape": (), "is_inv": False}, + {"testcase_name": "0d_inv", "shape": (), "is_inv": True}, + {"testcase_name": "3d", "shape": (1, 2, 3), "is_inv": False}, + {"testcase_name": "3d_inv", "shape": (1, 2, 3), "is_inv": True}, + ) + def test_invalid(self, shape, is_inv): + """Test formatting raises error for invalid shapes""" + mass_matrix = jnp.zeros(shape=shape) + with self.assertRaisesRegex( + ValueError, "The mass matrix has the wrong number of dimensions" + ): + metrics._format_covariance(mass_matrix, is_inv) + + @parameterized.named_parameters( + {"testcase_name": "inv", "is_inv": True}, + {"testcase_name": "no_inv", "is_inv": False}, + ) + def test_dim_1(self, is_inv): + """Test formatting for 1D mass matrix""" + mass_matrix = jnp.asarray([1 / 4], dtype=self.dtype) + mass_matrix_sqrt, inv_mass_matrix_sqrt, diag = metrics._format_covariance( + mass_matrix, is_inv + ) + if is_inv: + chex.assert_trees_all_close(inv_mass_matrix_sqrt, mass_matrix**0.5) + chex.assert_trees_all_close(mass_matrix_sqrt, mass_matrix**-0.5) + else: + chex.assert_trees_all_close(mass_matrix_sqrt, mass_matrix**0.5) + chex.assert_trees_all_close(inv_mass_matrix_sqrt, mass_matrix**-0.5) + + chex.assert_trees_all_close(diag(mass_matrix), mass_matrix) + + @parameterized.named_parameters( + {"testcase_name": "inv", "is_inv": True}, + {"testcase_name": "no_inv", "is_inv": False}, + ) + def test_dim_2(self, is_inv): + """Test formatting for 2D mass matrix""" + mass_matrix = jnp.asarray([[2 / 3, 0.5], [0.5, 3 / 4]], dtype=self.dtype) + mass_matrix_sqrt, inv_mass_matrix_sqrt, diag = metrics._format_covariance( + mass_matrix, is_inv + ) + if is_inv: + chex.assert_trees_all_close( + mass_matrix_sqrt @ mass_matrix_sqrt.T, linalg.inv(mass_matrix) + ) + chex.assert_trees_all_close( + inv_mass_matrix_sqrt @ inv_mass_matrix_sqrt.T, mass_matrix + ) + + else: + chex.assert_trees_all_close( + mass_matrix_sqrt @ mass_matrix_sqrt.T, mass_matrix + ) + chex.assert_trees_all_close( + inv_mass_matrix_sqrt @ inv_mass_matrix_sqrt.T, linalg.inv(mass_matrix) + ) + + def test_dim2_inv_and_not_inv_agree(self): + mass_matrix = jnp.asarray([[2 / 3, 0.5], [0.5, 3 / 4]], dtype=self.dtype) + mass_matrix_sqrt, inv_mass_matrix_sqrt, _ = metrics._format_covariance( + mass_matrix, False + ) + mass_matrix_sqrt_inv, inv_mass_matrix_sqrt_inv, _ = metrics._format_covariance( + linalg.inv(mass_matrix), True + ) + + chex.assert_trees_all_close( + mass_matrix_sqrt @ mass_matrix_sqrt.T, + mass_matrix_sqrt_inv @ mass_matrix_sqrt_inv.T, + ) + chex.assert_trees_all_close( + inv_mass_matrix_sqrt @ inv_mass_matrix_sqrt.T, + inv_mass_matrix_sqrt_inv @ inv_mass_matrix_sqrt_inv.T, + ) + + class GaussianEuclideanMetricsTest(chex.TestCase): def setUp(self): super().setUp() @@ -30,7 +114,9 @@ def test_gaussian_euclidean_ndim_invalid(self, shape): def test_gaussian_euclidean_dim_1(self): """Test Gaussian Euclidean Function with ndim 1""" inverse_mass_matrix = jnp.asarray([1 / 4], dtype=self.dtype) - momentum, kinetic_energy, _ = metrics.gaussian_euclidean(inverse_mass_matrix) + momentum, kinetic_energy, _, scale = metrics.gaussian_euclidean( + inverse_mass_matrix + ) arbitrary_position = jnp.asarray([12345], dtype=self.dtype) momentum_val = self.variant(momentum)(self.key, arbitrary_position) @@ -45,18 +131,30 @@ def test_gaussian_euclidean_dim_1(self): assert momentum_val == expected_momentum_val assert kinetic_energy_val == expected_kinetic_energy_val + inv_scaled_momentum = scale(arbitrary_position, momentum_val, True) + scaled_momentum = scale(arbitrary_position, momentum_val, False) + + expected_scaled_momentum = momentum_val / jnp.sqrt(inverse_mass_matrix) + expected_inv_scaled_momentum = momentum_val * jnp.sqrt(inverse_mass_matrix) + + chex.assert_trees_all_close(inv_scaled_momentum, expected_inv_scaled_momentum) + chex.assert_trees_all_close(scaled_momentum, expected_scaled_momentum) + @chex.all_variants(with_pmap=False) def test_gaussian_euclidean_dim_2(self): """Test Gaussian Euclidean Function with ndim 2""" inverse_mass_matrix = jnp.asarray( - [[1 / 9, 0.5], [0.5, 1 / 4]], dtype=self.dtype + [[2 / 3, 0.5], [0.5, 3 / 4]], dtype=self.dtype + ) + momentum, kinetic_energy, _, scale = metrics.gaussian_euclidean( + inverse_mass_matrix ) - momentum, kinetic_energy, _ = metrics.gaussian_euclidean(inverse_mass_matrix) arbitrary_position = jnp.asarray([12345, 23456], dtype=self.dtype) momentum_val = self.variant(momentum)(self.key, arbitrary_position) - L_inv = linalg.cholesky(linalg.inv(inverse_mass_matrix), lower=True) + L_inv = linalg.inv(linalg.cholesky(inverse_mass_matrix, lower=False)) + expected_momentum_val = L_inv @ random.normal(self.key, shape=(2,)) kinetic_energy_val = self.variant(kinetic_energy)(momentum_val) @@ -66,6 +164,15 @@ def test_gaussian_euclidean_dim_2(self): np.testing.assert_allclose(expected_momentum_val, momentum_val) np.testing.assert_allclose(kinetic_energy_val, expected_kinetic_energy_val) + inv_scaled_momentum = scale(arbitrary_position, momentum_val, True) + scaled_momentum = scale(arbitrary_position, momentum_val, False) + + expected_inv_scaled_momentum = jnp.linalg.inv(L_inv).T @ momentum_val + expected_scaled_momentum = L_inv @ momentum_val + + chex.assert_trees_all_close(inv_scaled_momentum, expected_inv_scaled_momentum) + chex.assert_trees_all_close(scaled_momentum, expected_scaled_momentum) + class GaussianRiemannianMetricsTest(chex.TestCase): def setUp(self): @@ -99,7 +206,9 @@ def test_gaussian_riemannian_value_errors(self, shape): def test_gaussian_riemannian_dim_1(self): inverse_mass_matrix = jnp.asarray([1 / 4], dtype=self.dtype) mass_matrix = jnp.asarray([4.0], dtype=self.dtype) - momentum, kinetic_energy, _ = metrics.gaussian_riemannian(lambda _: mass_matrix) + momentum, kinetic_energy, _, scale = metrics.gaussian_riemannian( + lambda _: mass_matrix + ) arbitrary_position = jnp.asarray([12345], dtype=self.dtype) momentum_val = self.variant(momentum)(self.key, arbitrary_position) @@ -114,16 +223,26 @@ def test_gaussian_riemannian_dim_1(self): expected_kinetic_energy_val = 0.5 * velocity * momentum_val expected_kinetic_energy_val += 0.5 * jnp.sum(jnp.log(2 * jnp.pi * mass_matrix)) - assert momentum_val == expected_momentum_val - assert kinetic_energy_val == expected_kinetic_energy_val + np.testing.assert_allclose(expected_momentum_val, momentum_val) + np.testing.assert_allclose(kinetic_energy_val, expected_kinetic_energy_val) + + inv_scaled_momentum = scale(arbitrary_position, momentum_val, True) + scaled_momentum = scale(arbitrary_position, momentum_val, False) + expected_scaled_momentum = momentum_val / jnp.sqrt(inverse_mass_matrix) + expected_inv_scaled_momentum = momentum_val * jnp.sqrt(inverse_mass_matrix) + + chex.assert_trees_all_close(inv_scaled_momentum, expected_inv_scaled_momentum) + chex.assert_trees_all_close(scaled_momentum, expected_scaled_momentum) @chex.all_variants(with_pmap=False) - def test_gaussian_euclidean_dim_2(self): + def test_gaussian_riemannian_dim_2(self): inverse_mass_matrix = jnp.asarray( - [[1 / 9, 0.5], [0.5, 1 / 4]], dtype=self.dtype + [[2 / 3, 0.5], [0.5, 3 / 4]], dtype=self.dtype ) mass_matrix = jnp.linalg.inv(inverse_mass_matrix) - momentum, kinetic_energy, _ = metrics.gaussian_riemannian(lambda _: mass_matrix) + momentum, kinetic_energy, _, scale = metrics.gaussian_riemannian( + lambda _: mass_matrix + ) arbitrary_position = jnp.asarray([12345, 23456], dtype=self.dtype) momentum_val = self.variant(momentum)(self.key, arbitrary_position) @@ -131,6 +250,10 @@ def test_gaussian_euclidean_dim_2(self): L_inv = linalg.cholesky(linalg.inv(inverse_mass_matrix), lower=True) expected_momentum_val = L_inv @ random.normal(self.key, shape=(2,)) + sqrt_mass_matrix, inv_sqrt_mass_matrix, _ = metrics._format_covariance( + inverse_mass_matrix, True + ) + kinetic_energy_val = self.variant(kinetic_energy)( momentum_val, position=arbitrary_position ) @@ -142,6 +265,14 @@ def test_gaussian_euclidean_dim_2(self): np.testing.assert_allclose(expected_momentum_val, momentum_val) np.testing.assert_allclose(kinetic_energy_val, expected_kinetic_energy_val) + inv_scaled_momentum = scale(arbitrary_position, momentum_val, True) + scaled_momentum = scale(arbitrary_position, momentum_val, False) + expected_inv_scaled_momentum = jnp.linalg.inv(L_inv).T @ momentum_val + expected_scaled_momentum = L_inv @ momentum_val + + chex.assert_trees_all_close(inv_scaled_momentum, expected_inv_scaled_momentum) + chex.assert_trees_all_close(scaled_momentum, expected_scaled_momentum) + if __name__ == "__main__": absltest.main() diff --git a/tests/mcmc/test_trajectory.py b/tests/mcmc/test_trajectory.py index c8a5aa908..e93280400 100644 --- a/tests/mcmc/test_trajectory.py +++ b/tests/mcmc/test_trajectory.py @@ -32,6 +32,7 @@ def test_dynamic_progressive_integration_divergence( momentum_generator, kinetic_energy_fn, uturn_check_fn, + _, ) = metrics.gaussian_euclidean(inverse_mass_matrix) integrator = integrators.velocity_verlet(logdensity_fn, kinetic_energy_fn) @@ -83,6 +84,7 @@ def logdensity_fn(x): momentum_generator, kinetic_energy_fn, uturn_check_fn, + _, ) = metrics.gaussian_euclidean(inverse_mass_matrix) integrator = integrators.velocity_verlet(logdensity_fn, kinetic_energy_fn) @@ -211,6 +213,7 @@ def logdensity_fn(x): momentum_generator, kinetic_energy_fn, uturn_check_fn, + _, ) = metrics.gaussian_euclidean(inverse_mass_matrix) integrator = integrators.velocity_verlet(logdensity_fn, kinetic_energy_fn) @@ -266,7 +269,7 @@ def test_static_integration_variable_num_steps(self): ( momentum_generator, kinetic_energy_fn, - _, + *_, ) = metrics.gaussian_euclidean(inverse_mass_matrix) initial_state = integrators.new_integrator_state( logdensity_fn, position, momentum_generator(rng_key, position) diff --git a/tests/mcmc/test_uturn.py b/tests/mcmc/test_uturn.py index 3dc730565..7f9f597d6 100644 --- a/tests/mcmc/test_uturn.py +++ b/tests/mcmc/test_uturn.py @@ -20,7 +20,7 @@ class UTurnTest(chex.TestCase): ) def test_is_iterative_turning(self, checkpoint_idxs, expected_turning): inverse_mass_matrix = jnp.ones(1) - _, _, is_turning = gaussian_euclidean(inverse_mass_matrix) + _, _, is_turning, _ = gaussian_euclidean(inverse_mass_matrix) _, _, is_iterative_turning = iterative_uturn_numpyro(is_turning) momentum = 1.0