diff --git a/blackjax/smc/base.py b/blackjax/smc/base.py index 2e9289369..d649a5442 100644 --- a/blackjax/smc/base.py +++ b/blackjax/smc/base.py @@ -16,7 +16,7 @@ import jax import jax.numpy as jnp -from blackjax.types import PyTree +from blackjax.types import PRNGKey, PyTree class SMCInfo(NamedTuple): @@ -40,6 +40,41 @@ class SMCInfo(NamedTuple): log_likelihood_increment: float +def base( + rng_key: PRNGKey, + particles: PyTree, + update: Callable, + weigh: Callable, + resample: Callable, +): + """General SMC sampling step. + + rng_key + particles + update + weigh + resample + """ + + updating_key, resampling_key = jax.random.split(rng_key, 2) + + num_particles = jax.tree_util.tree_flatten(particles)[0][0].shape[0] + keys = jax.random.split(updating_key, num_particles) + particles, update_info = update(keys, particles) + + weights = weigh(particles) + weights, logp_increments = normalize(weights) + # Here normalize the weights and compute log_increments + + resampling_idx = resample(weights, resampling_key) + particles = jax.tree_map(lambda x: x[resampling_idx], particles) + + # class TemperedSMCInfo(NamedTuple): + # lambda: float + # smc_info: SMCInfo + return particles, SMCInfo(weights, resampling_idx, logp_increments, update_info) + + def kernel( mcmc_step_fn: Callable, mcmc_init_fn: Callable,