From aae442b8d32c02a3ab718391718d32e11e25bbd4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Fri, 16 Sep 2022 21:30:48 +0200 Subject: [PATCH] Remove `kernel_factory` from SMC base kernel --- blackjax/smc/base.py | 45 ++++++++++++++++++++++++++------------------ tests/test_smc.py | 22 ++++++++++++---------- 2 files changed, 39 insertions(+), 28 deletions(-) diff --git a/blackjax/smc/base.py b/blackjax/smc/base.py index 9db15809c..e3ba53be9 100644 --- a/blackjax/smc/base.py +++ b/blackjax/smc/base.py @@ -1,3 +1,4 @@ +import functools as ft from typing import Callable, NamedTuple, Tuple import jax @@ -27,18 +28,18 @@ class SMCInfo(NamedTuple): def kernel( - mcmc_kernel_factory: Callable, - mcmc_state_generator: Callable, + mcmc_kernel: Callable, + mcmc_init: Callable, resampling_fn: Callable, - num_mcmc_iterations: int, + num_mcmc_steps: int, ): """Build a generic SMC kernel. In Feynman-Kac equivalent terms, the algo goes roughly as follows: ``` - M_t = mcmc_kernel_factory(potential_fn) - for i in range(num_mcmc_iterations): + M_t = mcmc_kernel(logprob_fn, **parameters) + for i in range(num_mcmc_steps): x_t^i = M_t(..., x_t^i) G_t = log_weights_fn log_weights = G_t(x_t) @@ -49,14 +50,14 @@ def kernel( Parameters ---------- - mcmc_kernel_factory: Callable - A function of the Markov potential that returns a mcmc_kernel. - mcmc_state_generator: Callable - A function that creates a new mcmc state from a position and a potential. + mcmc_kernel: Callable + A MCMC kernel that generates a new sample from a give state. + mcmc_init: Callable + Creates a new MCMC state from a position. resampling_fn: Callable A function that resamples the particles generated by the MCMC kernel, based of previously computed weights. - num_mcmc_iterations: int + num_mcmc_steps: int Number of iterations of the MCMC kernel Returns @@ -72,9 +73,18 @@ def one_step( particles: PyTree, logprob_fn: Callable, log_weight_fn: Callable, + mcmc_parameters: dict, ) -> Tuple[PyTree, SMCInfo]: """ + We could write this in a much better way? + + particles = vmap(f)(particles, *parameters) + weights = vmap(logweightfn)(weights, particles) + resampled_weights = sample(particles, weights) + + Plus the problem is that you may want to parallelize in a different way later + Parameters ---------- rng_key: DeviceArray[int], @@ -85,6 +95,8 @@ def one_step( Log probability function we wish to sample from. log_weight_fn: Callable A function that represents the Feynman-Kac log potential at time t. + mcmc_parameters: dict + A dictionary that contains the parameters of the MCMC kernel. Returns ------- @@ -97,20 +109,17 @@ def one_step( num_particles = jax.tree_flatten(particles)[0][0].shape[0] scan_key, resampling_key = jax.random.split(rng_key, 2) - # First advance the particles using the MCMC kernel - mcmc_kernel = mcmc_kernel_factory(logprob_fn) + applied_mcmc_kernel = ft.partial(mcmc_kernel, **mcmc_parameters) def mcmc_body_fn(curr_particles, curr_key): keys = jax.random.split(curr_key, num_particles) - new_particles, _ = jax.vmap(mcmc_kernel, in_axes=(0, 0))( - keys, curr_particles + new_particles, _ = jax.vmap(applied_mcmc_kernel, in_axes=(0, 0, None))( + keys, curr_particles, logprob_fn ) return new_particles, None - mcmc_state = jax.vmap(mcmc_state_generator, in_axes=(0, None))( - particles, logprob_fn - ) - keys = jax.random.split(scan_key, num_mcmc_iterations) + mcmc_state = jax.vmap(mcmc_init, in_axes=(0, None))(particles, logprob_fn) + keys = jax.random.split(scan_key, num_mcmc_steps) proposed_states, _ = jax.lax.scan(mcmc_body_fn, mcmc_state, keys) proposed_particles = proposed_states.position diff --git a/tests/test_smc.py b/tests/test_smc.py index 16f0d569d..0f9a5b830 100644 --- a/tests/test_smc.py +++ b/tests/test_smc.py @@ -30,17 +30,19 @@ def setUp(self): @parameterized.parameters([500, 1000, 5000]) def test_smc(self, N): - mcmc_factory = lambda logprob_fn: blackjax.hmc( - logprob_fn, - step_size=1e-2, - inverse_mass_matrix=jnp.eye(1), - num_integration_steps=50, - ).step + mcmc_parameters = { + "step_size": 1e-2, + "inverse_mass_matrix": jnp.eye(1), + "num_integration_steps": 50, + } specialized_log_weights_fn = lambda tree: log_weights_fn(tree, 1.0) - kernel = base.kernel( - mcmc_factory, blackjax.mcmc.hmc.init, resampling.systematic, 1000 + smc_kernel = base.kernel( + blackjax.mcmc.hmc.kernel(), + blackjax.mcmc.hmc.init, + resampling.systematic, + 1000, ) # Don't use exactly the invariant distribution for the MCMC kernel @@ -48,15 +50,15 @@ def test_smc(self, N): updated_particles, _ = self.variant( functools.partial( - kernel, + smc_kernel, logprob_fn=kernel_logprob_fn, log_weight_fn=specialized_log_weights_fn, + mcmc_parameters=mcmc_parameters, ) )(self.key, init_particles) expected_mean = 0.5 expected_std = np.sqrt(0.5) - np.testing.assert_allclose( expected_mean, updated_particles.mean(), rtol=1e-2, atol=1e-1 )