diff --git a/blackjax/kernels.py b/blackjax/kernels.py index 21618ae69..1518abc57 100644 --- a/blackjax/kernels.py +++ b/blackjax/kernels.py @@ -1026,7 +1026,7 @@ def init_fn(position: PyTree): return cls.init(position, logdensity_fn) def step_fn(rng_key: PRNGKey, state): - return kernel(rng_key, random_step, state, logdensity_fn) + return kernel(rng_key, state, logdensity_fn, random_step) return MCMCSamplingAlgorithm(init_fn, step_fn) @@ -1085,9 +1085,9 @@ def init_fn(position: PyTree): def step_fn(rng_key: PRNGKey, state): return kernel( rng_key, + state, logdensity_fn, proposal_generator, - state, proposal_logdensity_fn, ) @@ -1143,7 +1143,7 @@ def init_fn(position: PyTree): return cls.init(position, logdensity_fn) def step_fn(rng_key: PRNGKey, state): - return kernel(rng_key, proposal_distribution, state, logdensity_fn) + return kernel(rng_key, state, logdensity_fn, proposal_distribution) return MCMCSamplingAlgorithm(init_fn, step_fn) diff --git a/blackjax/mcmc/random_walk.py b/blackjax/mcmc/random_walk.py index 7f8009723..f0cd40fc2 100644 --- a/blackjax/mcmc/random_walk.py +++ b/blackjax/mcmc/random_walk.py @@ -159,7 +159,7 @@ def build_additive_step(): """ def kernel( - rng_key: PRNGKey, random_step: Callable, state: RWState, logdensity_fn: Callable + rng_key: PRNGKey, state: RWState, logdensity_fn: Callable, random_step: Callable ) -> Tuple[RWState, RWInfo]: def proposal_generator(key_proposal, position): move_proposal = random_step(key_proposal, position) @@ -167,7 +167,7 @@ def proposal_generator(key_proposal, position): return new_position inner_kernel = build_rmh() - return inner_kernel(rng_key, logdensity_fn, proposal_generator, state) + return inner_kernel(rng_key, state, logdensity_fn, proposal_generator) return kernel @@ -187,9 +187,9 @@ def build_irmh() -> Callable: def kernel( rng_key: PRNGKey, - proposal_distribution: Callable, state: RWState, logdensity_fn: Callable, + proposal_distribution: Callable, ) -> Tuple[RWState, RWInfo]: """ Parameters @@ -203,7 +203,7 @@ def proposal_generator(rng_key: PRNGKey, position: PyTree): return proposal_distribution(rng_key) inner_kernel = build_rmh() - return inner_kernel(rng_key, logdensity_fn, proposal_generator, state) + return inner_kernel(rng_key, state, logdensity_fn, proposal_generator) return kernel @@ -220,9 +220,9 @@ def build_rmh(): def kernel( rng_key: PRNGKey, + state: RWState, logdensity_fn: Callable, transition_generator: Callable, - state: RWState, proposal_logdensity_fn: Optional[Callable] = None, ) -> Tuple[RWState, RWInfo]: """Move the chain by one step using the Rosenbluth Metropolis Hastings diff --git a/tests/mcmc/test_random_walk_without_chex.py b/tests/mcmc/test_random_walk_without_chex.py index e0e7bee98..7ae431fa4 100644 --- a/tests/mcmc/test_random_walk_without_chex.py +++ b/tests/mcmc/test_random_walk_without_chex.py @@ -38,9 +38,9 @@ def test_logdensity_accepts(position): new_state, _ = step( rng_key, - random_step, RWState(position=initial_position, logdensity=1.0), test_logdensity_accepts, + random_step, ) np.testing.assert_allclose(new_state.position, jnp.array([60.0])) @@ -69,9 +69,9 @@ def test_logdensity_accepts(position): for previous_position in [initial_position, other_position]: new_state, _ = step( rng_key, - proposal_distribution, RWState(position=previous_position, logdensity=1.0), test_logdensity_accepts, + proposal_distribution, ) np.testing.assert_allclose(new_state.position, jnp.array([10.0])) diff --git a/tests/smc/test_kernel_compatibility.py b/tests/smc/test_kernel_compatibility.py new file mode 100644 index 000000000..1fd880ec6 --- /dev/null +++ b/tests/smc/test_kernel_compatibility.py @@ -0,0 +1,108 @@ +import unittest + +import jax +from jax import numpy as jnp +from jax.scipy.stats import multivariate_normal + +import blackjax +from blackjax import adaptive_tempered_smc +from blackjax.mcmc.random_walk import normal + + +class SMCAndMCMCIntegrationTest(unittest.TestCase): + """ + An integration test that verifies which MCMC can be used as + SMC mutation step kernels. + """ + + def setUp(self): + super().setUp() + self.key = jax.random.PRNGKey(42) + self.initial_particles = jax.random.multivariate_normal( + self.key, jnp.zeros(2), jnp.eye(2), (3,) + ) + + def check_compatible(self, mcmc_step_fn, mcmc_init_fn, mcmc_parameters): + """ + Runs one SMC step + """ + init, kernel = adaptive_tempered_smc( + self.prior_log_prob, + self.loglikelihood, + mcmc_step_fn, + mcmc_init_fn, + mcmc_parameters=mcmc_parameters, + resampling_fn=self.resampling_fn, + target_ess=0.5, + root_solver=self.root_solver, + num_mcmc_steps=1, + ) + kernel(self.key, init(self.initial_particles)) + + def test_compatible_with_rwm(self): + self.check_compatible( + blackjax.additive_step_random_walk.build_kernel(), + blackjax.additive_step_random_walk.init, + {"random_step": normal(1.0)}, + ) + + def test_compatible_with_rmh(self): + self.check_compatible( + blackjax.rmh.build_kernel(), + blackjax.rmh.init, + { + "transition_generator": lambda a, b: blackjax.mcmc.random_walk.normal( + 1.0 + )(a, b) + }, + ) + + def test_compatible_with_hmc(self): + self.check_compatible( + blackjax.hmc.kernel(), + blackjax.hmc.init, + { + "step_size": 0.3, + "inverse_mass_matrix": jnp.array([1]), + "num_integration_steps": 1, + }, + ) + + def test_compatible_with_irmh(self): + self.check_compatible( + blackjax.irmh.build_kernel(), + blackjax.irmh.init, + { + "proposal_distribution": lambda key: jnp.array([1.0, 1.0]) + + jax.random.normal(key) + }, + ) + + def test_compatible_with_nuts(self): + self.check_compatible( + blackjax.nuts.kernel(), + blackjax.nuts.init, + {"step_size": 1e-10, "inverse_mass_matrix": jnp.eye(2)}, + ) + + def test_compatible_with_mala(self): + self.check_compatible( + blackjax.mala.kernel(), blackjax.mala.init, {"step_size": 1e-10} + ) + + @staticmethod + def prior_log_prob(x): + d = x.shape[0] + return multivariate_normal.logpdf(x, jnp.zeros((d,)), jnp.eye(d)) + + @staticmethod + def loglikelihood(x): + return -5 * jnp.sum(jnp.square(x**2 - 1)) + + @staticmethod + def root_solver(fun, _delta0, min_delta, max_delta, eps=1e-4, max_iter=100): + return 0.8 + + @staticmethod + def resampling_fn(rng_key, weights: jax.Array, num_samples: int): + return jnp.array([0, 1, 2])