Skip to content

Commit

Permalink
Refactor the base SMC step function
Browse files Browse the repository at this point in the history
  • Loading branch information
rlouf committed Dec 22, 2022
1 parent cc5d3f5 commit 242102b
Show file tree
Hide file tree
Showing 6 changed files with 248 additions and 131 deletions.
28 changes: 14 additions & 14 deletions blackjax/kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,27 +65,25 @@ def __new__( # type: ignore[misc]
cls,
logprior_fn: Callable,
loglikelihood_fn: Callable,
mcmc_algorithm: MCMCSamplingAlgorithm,
mcmc_step_fn: Callable,
mcmc_init_fn: Callable,
mcmc_parameters: Dict,
resampling_fn: Callable,
target_ess: float,
root_solver: Callable = smc.solver.dichotomy,
use_log_ess: bool = True,
mcmc_iter: int = 10,
num_mcmc_steps: int = 10,
) -> MCMCSamplingAlgorithm:
def kernel_factory(logprob_fn):
return mcmc_algorithm(logprob_fn, **mcmc_parameters).step

step = cls.kernel(
logprior_fn,
loglikelihood_fn,
kernel_factory,
mcmc_algorithm.init,
mcmc_step_fn,
mcmc_init_fn,
resampling_fn,
target_ess,
root_solver,
use_log_ess,
mcmc_iter,
)

def init_fn(position: PyTree):
Expand All @@ -95,6 +93,8 @@ def step_fn(rng_key: PRNGKey, state):
return step(
rng_key,
state,
num_mcmc_steps,
mcmc_parameters,
)

return MCMCSamplingAlgorithm(init_fn, step_fn)
Expand All @@ -117,21 +117,19 @@ def __new__( # type: ignore[misc]
cls,
logprior_fn: Callable,
loglikelihood_fn: Callable,
mcmc_algorithm: MCMCSamplingAlgorithm,
mcmc_step_fn: Callable,
mcmc_init_fn: Callable,
mcmc_parameters: Dict,
resampling_fn: Callable,
mcmc_iter: int = 10,
num_mcmc_steps: int = 10,
) -> MCMCSamplingAlgorithm:
def kernel_factory(logprob_fn):
return mcmc_algorithm(logprob_fn, **mcmc_parameters).step

step = cls.kernel(
logprior_fn,
loglikelihood_fn,
kernel_factory,
mcmc_algorithm.init,
mcmc_step_fn,
mcmc_init_fn,
resampling_fn,
mcmc_iter,
)

def init_fn(position: PyTree):
Expand All @@ -141,7 +139,9 @@ def step_fn(rng_key: PRNGKey, state, lmbda):
return step(
rng_key,
state,
num_mcmc_steps,
lmbda,
mcmc_parameters,
)

return MCMCSamplingAlgorithm(init_fn, step_fn) # type: ignore[arg-type]
Expand Down
19 changes: 9 additions & 10 deletions blackjax/smc/adaptive_tempered.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,12 @@
def kernel(
logprior_fn: Callable,
loglikelihood_fn: Callable,
mcmc_kernel_factory: Callable,
make_mcmc_state: Callable,
mcmc_step_fn: Callable,
mcmc_init_fn: Callable,
resampling_fn: Callable,
target_ess: float,
root_solver: Callable = solver.dichotomy,
use_log_ess: bool = True,
mcmc_iter: int = 10,
) -> Callable:
r"""Build a Tempered SMC step using an adaptive schedule.
Expand All @@ -60,8 +59,6 @@ def kernel(
use_log_ess: bool, optional
Use ESS in log space to solve for delta, default is `True`.
This is usually more stable when using gradient based solvers.
mcmc_iter: int
Number of iterations in the MCMC chain.
Returns
-------
Expand Down Expand Up @@ -89,17 +86,19 @@ def compute_delta(state: tempered.TemperedSMCState) -> float:
kernel = tempered.kernel(
logprior_fn,
loglikelihood_fn,
mcmc_kernel_factory,
make_mcmc_state,
mcmc_step_fn,
mcmc_init_fn,
resampling_fn,
mcmc_iter,
)

def one_step(
rng_key: PRNGKey, state: tempered.TemperedSMCState
rng_key: PRNGKey,
state: tempered.TemperedSMCState,
num_mcmc_steps: int,
mcmc_parameters: dict,
) -> Tuple[tempered.TemperedSMCState, base.SMCInfo]:
delta = compute_delta(state)
lmbda = delta + state.lmbda
return kernel(rng_key, state, lmbda)
return kernel(rng_key, state, num_mcmc_steps, lmbda, mcmc_parameters)

return one_step
147 changes: 116 additions & 31 deletions blackjax/smc/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import jax
import jax.numpy as jnp

from blackjax.types import PyTree
from blackjax.types import PRNGKey, PyTree


class SMCInfo(NamedTuple):
Expand All @@ -35,26 +35,103 @@ class SMCInfo(NamedTuple):
"""

weights: jnp.ndarray
proposals: PyTree
ancestors: jnp.ndarray
log_likelihood_increment: float
update_info: NamedTuple


def step(
rng_key: PRNGKey,
particles: PyTree,
update_fn: Callable,
weigh_fn: Callable,
resample_fn: Callable,
):
"""General SMC sampling step.
To paraphrase [1]_, SMC samplers are particle algorithms that are able to
track a sequence of probability measures $\\matbb{P}_t(\\mathrm{d}\theta)$
linked by the recursion:
.. math::
\\mathbb{P}_{t+1}(\\mathrm{d}\theta) = \\ell_t G_t(\theta) \\mathbb{P}_t(\\mathrm{d}\theta)
We also assume that we are able to construct markov kernels $M_{t+1}$ that
leaves $\\mathbb{P}_t$ invariant. `update_fn` here corresponds to the Markov
kernel $M_{t+1}$, and `weigh_fn` corresponds to the potential function
$G_t$.
We first use `update_fn` to generate new particles from the current ones,
weigh these particles using `weigh_fn` and resample them with `resample_fn`.
The `update_fn` and `weigh_fn` functions must be batched by the called either
using `jax.vmap` or `jax.pmap`.
In Feynman-Kac terms, the algorithm goes roughly as follows:
.. code::
M_t: update_fn
G_t: weigh_fn
R_t: resample_fn
x_{t+1} = M_t(x_t)
weights = G_t(x_{t+1})
idx = R_t(weights)
Parameters
----------
rng_key
Key used to generate pseudo-random numbers.
particles
Array that contains the current positions of the particles.
update_fn
Function that takes an array of keys and particles and returns
new particles.
weigh_fn
Function that assigns a weight to the particles.
resample_fn
Function that resamples the particles.
Returns
-------
new_particles
An array that contains the new particles generated by this SMC step.
info
An `SMCInfo` object that contains extra information about the SMC
transition.
"""

updating_key, resampling_key = jax.random.split(rng_key, 2)

num_particles = jax.tree_util.tree_flatten(particles)[0][0].shape[0]
keys = jax.random.split(updating_key, num_particles)
particles, update_info = update_fn(keys, particles)

weights = weigh_fn(particles)
weights, logp_increments = _normalize(weights)

resampling_idx = resample_fn(weights, resampling_key)
particles = jax.tree_map(lambda x: x[resampling_idx], particles)

return particles, SMCInfo(weights, resampling_idx, logp_increments, update_info)


def kernel(
mcmc_kernel_factory: Callable,
mcmc_state_generator: Callable,
mcmc_step_fn: Callable,
mcmc_init_fn: Callable,
resampling_fn: Callable,
num_mcmc_iterations: int,
):
"""Build a generic SMC kernel.
In Feynman-Kac equivalent terms, the algo goes roughly as follows:
```
M_t = mcmc_kernel_factory(potential_fn)
for i in range(num_mcmc_iterations):
x_t^i = M_t(..., x_t^i)
M_t = mcmc_kernel
G_t = log_weights_fn
for i in range(num_mcmc_steps):
x_t^i = M_t(..., x_t^i, logprob_fn, **parameters)
log_weights = G_t(x_t)
idx = resample(log_weights)
x_t = x_t[idx]
Expand All @@ -63,15 +140,13 @@ def kernel(
Parameters
----------
mcmc_kernel_factory: Callable
A function of the Markov potential that returns a mcmc_kernel.
mcmc_state_generator: Callable
A function that creates a new mcmc state from a position and a potential.
mcmc_step_fn: Callable
A MCMC step function that generates a new sample from a give state.
mcmc_init_fn: Callable
Creates a new MCMC state from a position.
resampling_fn: Callable
A function that resamples the particles generated by the MCMC kernel,
based of previously computed weights.
num_mcmc_iterations: int
Number of iterations of the MCMC kernel
Returns
-------
Expand All @@ -86,6 +161,8 @@ def one_step(
particles: PyTree,
logprob_fn: Callable,
log_weight_fn: Callable,
num_mcmc_steps: int,
mcmc_parameters: dict,
) -> Tuple[PyTree, SMCInfo]:
"""Take one step with the SMC kernel.
Expand All @@ -99,6 +176,10 @@ def one_step(
Log probability function we wish to sample from.
log_weight_fn: Callable
A function that represents the Feynman-Kac log potential at time t.
num_mcmc_steps: int
Number of iterations of the MCMC kernel
mcmc_parameters: dict
A dictionary that contains the parameters of the MCMC kernel.
Returns
-------
Expand All @@ -108,28 +189,32 @@ def one_step(
Additional information on the SMC step
"""
num_particles = jax.tree_util.tree_flatten(particles)[0][0].shape[0]
scan_key, resampling_key = jax.random.split(rng_key, 2)
update_key, resampling_key = jax.random.split(rng_key, 2)

# First advance the particles using the MCMC kernel
mcmc_kernel = mcmc_kernel_factory(logprob_fn)
# TODO: Consider asking the caller to provide the particle_update_fn
# instead
def mcmc_update_particle(rng_key, position):
state = mcmc_init_fn(position, logprob_fn)

def mcmc_body_fn(curr_particles, curr_key):
keys = jax.random.split(curr_key, num_particles)
new_particles, _ = jax.vmap(mcmc_kernel, in_axes=(0, 0))(
keys, curr_particles
)
return new_particles, None
def body_fn(state, rng_key):
new_state, _ = mcmc_step_fn(
rng_key, state, logprob_fn, **mcmc_parameters
)
return new_state, new_state

mcmc_state = jax.vmap(mcmc_state_generator, in_axes=(0, None))(
particles, logprob_fn
)
keys = jax.random.split(scan_key, num_mcmc_iterations)
proposed_states, _ = jax.lax.scan(mcmc_body_fn, mcmc_state, keys)
proposed_particles = proposed_states.position
keys = jax.random.split(rng_key, num_mcmc_steps)
last_state, _ = jax.lax.scan(body_fn, state, keys)
return last_state.position

# Resample the particles depending on their respective weights
# Update the particles (parallel)
num_particles = jax.tree_util.tree_flatten(particles)[0][0].shape[0]
keys = jax.random.split(update_key, num_particles)
proposed_particles = jax.vmap(mcmc_update_particle)(keys, particles)

# Compute the particles' respective weight (parallel)
log_weights = jax.vmap(log_weight_fn, in_axes=(0,))(proposed_particles)

# Resample the particles (sync)
weights, log_likelihood_increment = _normalize(log_weights)
resampling_index = resampling_fn(weights, resampling_key)
particles = jax.tree_map(lambda x: x[resampling_index], proposed_particles)
Expand Down
Loading

0 comments on commit 242102b

Please sign in to comment.