Skip to content

Commit

Permalink
Remove kernel_factory from SMC base kernel
Browse files Browse the repository at this point in the history
  • Loading branch information
rlouf committed Sep 16, 2022
1 parent 82c2f1c commit aae442b
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 28 deletions.
45 changes: 27 additions & 18 deletions blackjax/smc/base.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import functools as ft
from typing import Callable, NamedTuple, Tuple

import jax
Expand Down Expand Up @@ -27,18 +28,18 @@ class SMCInfo(NamedTuple):


def kernel(
mcmc_kernel_factory: Callable,
mcmc_state_generator: Callable,
mcmc_kernel: Callable,
mcmc_init: Callable,
resampling_fn: Callable,
num_mcmc_iterations: int,
num_mcmc_steps: int,
):
"""Build a generic SMC kernel.
In Feynman-Kac equivalent terms, the algo goes roughly as follows:
```
M_t = mcmc_kernel_factory(potential_fn)
for i in range(num_mcmc_iterations):
M_t = mcmc_kernel(logprob_fn, **parameters)
for i in range(num_mcmc_steps):
x_t^i = M_t(..., x_t^i)
G_t = log_weights_fn
log_weights = G_t(x_t)
Expand All @@ -49,14 +50,14 @@ def kernel(
Parameters
----------
mcmc_kernel_factory: Callable
A function of the Markov potential that returns a mcmc_kernel.
mcmc_state_generator: Callable
A function that creates a new mcmc state from a position and a potential.
mcmc_kernel: Callable
A MCMC kernel that generates a new sample from a give state.
mcmc_init: Callable
Creates a new MCMC state from a position.
resampling_fn: Callable
A function that resamples the particles generated by the MCMC kernel,
based of previously computed weights.
num_mcmc_iterations: int
num_mcmc_steps: int
Number of iterations of the MCMC kernel
Returns
Expand All @@ -72,9 +73,18 @@ def one_step(
particles: PyTree,
logprob_fn: Callable,
log_weight_fn: Callable,
mcmc_parameters: dict,
) -> Tuple[PyTree, SMCInfo]:
"""
We could write this in a much better way?
particles = vmap(f)(particles, *parameters)
weights = vmap(logweightfn)(weights, particles)
resampled_weights = sample(particles, weights)
Plus the problem is that you may want to parallelize in a different way later
Parameters
----------
rng_key: DeviceArray[int],
Expand All @@ -85,6 +95,8 @@ def one_step(
Log probability function we wish to sample from.
log_weight_fn: Callable
A function that represents the Feynman-Kac log potential at time t.
mcmc_parameters: dict
A dictionary that contains the parameters of the MCMC kernel.
Returns
-------
Expand All @@ -97,20 +109,17 @@ def one_step(
num_particles = jax.tree_flatten(particles)[0][0].shape[0]
scan_key, resampling_key = jax.random.split(rng_key, 2)

# First advance the particles using the MCMC kernel
mcmc_kernel = mcmc_kernel_factory(logprob_fn)
applied_mcmc_kernel = ft.partial(mcmc_kernel, **mcmc_parameters)

def mcmc_body_fn(curr_particles, curr_key):
keys = jax.random.split(curr_key, num_particles)
new_particles, _ = jax.vmap(mcmc_kernel, in_axes=(0, 0))(
keys, curr_particles
new_particles, _ = jax.vmap(applied_mcmc_kernel, in_axes=(0, 0, None))(
keys, curr_particles, logprob_fn
)
return new_particles, None

mcmc_state = jax.vmap(mcmc_state_generator, in_axes=(0, None))(
particles, logprob_fn
)
keys = jax.random.split(scan_key, num_mcmc_iterations)
mcmc_state = jax.vmap(mcmc_init, in_axes=(0, None))(particles, logprob_fn)
keys = jax.random.split(scan_key, num_mcmc_steps)
proposed_states, _ = jax.lax.scan(mcmc_body_fn, mcmc_state, keys)
proposed_particles = proposed_states.position

Expand Down
22 changes: 12 additions & 10 deletions tests/test_smc.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,33 +30,35 @@ def setUp(self):
@parameterized.parameters([500, 1000, 5000])
def test_smc(self, N):

mcmc_factory = lambda logprob_fn: blackjax.hmc(
logprob_fn,
step_size=1e-2,
inverse_mass_matrix=jnp.eye(1),
num_integration_steps=50,
).step
mcmc_parameters = {
"step_size": 1e-2,
"inverse_mass_matrix": jnp.eye(1),
"num_integration_steps": 50,
}

specialized_log_weights_fn = lambda tree: log_weights_fn(tree, 1.0)

kernel = base.kernel(
mcmc_factory, blackjax.mcmc.hmc.init, resampling.systematic, 1000
smc_kernel = base.kernel(
blackjax.mcmc.hmc.kernel(),
blackjax.mcmc.hmc.init,
resampling.systematic,
1000,
)

# Don't use exactly the invariant distribution for the MCMC kernel
init_particles = 0.25 + np.random.randn(N)

updated_particles, _ = self.variant(
functools.partial(
kernel,
smc_kernel,
logprob_fn=kernel_logprob_fn,
log_weight_fn=specialized_log_weights_fn,
mcmc_parameters=mcmc_parameters,
)
)(self.key, init_particles)

expected_mean = 0.5
expected_std = np.sqrt(0.5)

np.testing.assert_allclose(
expected_mean, updated_particles.mean(), rtol=1e-2, atol=1e-1
)
Expand Down

0 comments on commit aae442b

Please sign in to comment.