Skip to content

Commit

Permalink
Refactor the base SMC kernel
Browse files Browse the repository at this point in the history
Base SMC is neatly divided in 3 steps:
- particle update
- particle weighting
- resampling
  • Loading branch information
rlouf committed Dec 20, 2022
1 parent 8069c1e commit a8ff4a4
Showing 1 changed file with 36 additions and 1 deletion.
37 changes: 36 additions & 1 deletion blackjax/smc/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import jax
import jax.numpy as jnp

from blackjax.types import PyTree
from blackjax.types import PRNGKey, PyTree


class SMCInfo(NamedTuple):
Expand All @@ -40,6 +40,41 @@ class SMCInfo(NamedTuple):
log_likelihood_increment: float


def base(
rng_key: PRNGKey,
particles: PyTree,
update: Callable,
weigh: Callable,
resample: Callable,
):
"""General SMC sampling step.
rng_key
particles
update
weigh
resample
"""

updating_key, resampling_key = jax.random.split(rng_key, 2)

num_particles = jax.tree_util.tree_flatten(particles)[0][0].shape[0]
keys = jax.random.split(updating_key, num_particles)
particles, update_info = update(keys, particles)

weights = weigh(particles)
weights, logp_increments = normalize(weights)
# Here normalize the weights and compute log_increments

resampling_idx = resample(weights, resampling_key)
particles = jax.tree_map(lambda x: x[resampling_idx], particles)

# class TemperedSMCInfo(NamedTuple):
# lambda: float
# smc_info: SMCInfo
return particles, SMCInfo(weights, resampling_idx, logp_increments, update_info)


def kernel(
mcmc_step_fn: Callable,
mcmc_init_fn: Callable,
Expand Down

0 comments on commit a8ff4a4

Please sign in to comment.