Refactor the SMC kernels #279

Merged
merged 3 commits into from
Jan 17, 2023

@rlouf rlouf commented Sep 16, 2022

There are a few things that I find unsatisfactory with the SMC base kernel:

  1. The use of kernel_factory prevents users from passing different parameters at different iterations. It forces adaptation to happen inside the SMC kernel, which we don't want. We should stick to the "sample, then update the parameters" paradigm sketched in Refactor the adaptation kernels #276.
  2. The internals of the base kernel could be better organised, following this pseudocode:
new_particles, weights = vectorized_fn(particles)  # scatter computation
sampled_particles = sample_fn(new_particles, weights)  # gather results

I think this decomposition is what was conceptually missing to properly integrate #117.

  3. The current code uses vmap, but it should be possible to use pmap as well.
  4. The naming of num_mcmc_iterations, mcmc_iter, etc. is inconsistent.
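
To illustrate the vmap/pmap point, here is a minimal sketch in which the batching transform is just a parameter. The names `log_weight` and `batch` are hypothetical and not part of BlackJAX. Since pmap maps the leading axis over devices, the example uses as many particles as there are local devices so that it runs on any machine.

```python
import jax
import jax.numpy as jnp


def log_weight(particle):
    # Hypothetical per-particle computation:
    # standard normal log-density, up to an additive constant.
    return -0.5 * jnp.sum(particle**2)


def batch(fn, transform=jax.vmap):
    """Batch `fn` over the leading axis with the chosen JAX transform."""
    return transform(fn)


# pmap places one slice of the leading axis on each device, so that
# axis cannot be larger than the number of available devices.
num_devices = jax.local_device_count()
particles = jnp.ones((num_devices, 3))

vmapped = batch(log_weight, jax.vmap)(particles)
pmapped = batch(log_weight, jax.pmap)(particles)
```

Both calls compute the same values; only the placement of the computation changes, which is why the base kernel would not need to know which transform was used.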

The implementation of the SMC base kernel should follow the formalism presented in this book.


rlouf commented Oct 7, 2022

A few thoughts:

  • Base SMC should be implemented in a very general form, smc_step_fn(rng_key, particles, particle_update_fn, particle_weighing_fn, resampling_fn), where particle_update_fn and particle_weighing_fn are batched using vmap or pmap.
  • This entails a nice hierarchy for MCMC: mcmc_step_fn < inference_loop < SMC(inference_loop).
  • Batched MCMC is a special case of SMC with no resampling step.
  • A sampling loop for MCMC is a special case of SMC with one particle and no resampling step.

Tentative design for SMC with MCMC steps

from typing import Callable, Dict

import jax
import blackjax

logprob_fn: Callable
mcmc_init: Callable
mcmc_step: Callable
num_mcmc_steps: int
mcmc_parameters: Dict

def update_particle(rng_key, position):
    
    def one_step(state, rng_key):
        # This can contain *anything*, not just an MCMC kernel.
        #
        # We can even do adaptation here if we plug in the states
        # correctly between here and SMC: particles would need to be
        # `state`, and we'd need to pass info.
        state, _ = mcmc_step(rng_key, state, logprob_fn, **mcmc_parameters)
        return state, state

    keys = jax.random.split(rng_key, num_mcmc_steps)
    state = mcmc_init(position, logprob_fn)
    last_state, states = jax.lax.scan(one_step, state, keys)
    # Can be a while loop 🔁 for adaptive schemes

    return last_state.position
    # Waste-free version 🚯
    # return states.position
    # Can be nicely combined with progressive HMC sampling

# This could also be a SMC step 🙃
update = jax.vmap(update_particle)
weigh = jax.vmap(logprob_fn)
resample = blackjax.smc.resampling.stratified 

new_particles, info = smc.step(
    rng_key,
    particles,
    update,
    weigh,
    resample,
)

Here smc.step only does some basic plumbing; it is there to enforce this general structure. This form allows parameter adaptation for the MCMC kernels.
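
A minimal sketch of what such plumbing could look like, assuming the update, weigh, resample order of the pseudocode in the first comment. This is not BlackJAX's implementation: `smc_step`, `multinomial`, `no_resampling`, `move`, and the toy `logprob_fn` are all hypothetical, and the resampler signature `(rng_key, weights, num_samples)` is an assumption.

```python
import jax
import jax.numpy as jnp


def smc_step(rng_key, particles, update_fn, weigh_fn, resample_fn):
    """Hypothetical plumbing: update, weigh, then resample."""
    num_particles = particles.shape[0]
    update_key, resample_key = jax.random.split(rng_key)
    keys = jax.random.split(update_key, num_particles)

    new_particles = update_fn(keys, particles)  # batched move
    log_weights = weigh_fn(new_particles)       # batched log-weights
    weights = jax.nn.softmax(log_weights)       # stable normalisation
    idx = resample_fn(resample_key, weights, num_particles)
    return new_particles[idx], weights


def multinomial(rng_key, weights, num_samples):
    """Stand-in resampler that returns ancestor indices."""
    return jax.random.choice(
        rng_key, weights.shape[0], shape=(num_samples,), p=weights
    )


def no_resampling(rng_key, weights, num_samples):
    """Keep every particle, so the step reduces to batched MCMC."""
    return jnp.arange(num_samples)


def move(rng_key, position):
    # Toy update: a Gaussian jitter in place of a real MCMC kernel.
    return position + 0.1 * jax.random.normal(rng_key, position.shape)


def logprob_fn(position):
    # Toy target: standard normal, up to an additive constant.
    return -0.5 * jnp.sum(position**2)


update = jax.vmap(move)
weigh = jax.vmap(logprob_fn)
particles = jax.random.normal(jax.random.PRNGKey(0), (50, 2))
key = jax.random.PRNGKey(1)

resampled, weights = smc_step(key, particles, update, weigh, multinomial)
batched, _ = smc_step(key, particles, update, weigh, no_resampling)
```

Swapping `multinomial` for `no_resampling` shows the "batched MCMC is a special case of SMC" point: the same step function runs, and only the resampling argument changes.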

How general is this design?

  • Check how robust the design is to $SMC^2$.
  • Make sure Waste-free SMC fits.
  • Show an example of tuning MCMC with SMC (new adaptation algorithm).
  • Check how robust the design is to SMC-ABC.
  • Implement Ensemble MCMC with this.
  • How does particle cascade fit in this?


codecov bot commented Oct 9, 2022

Codecov Report

Merging #279 (ac9f75d) into main (a4ee853) will decrease coverage by 0.01%.
The diff coverage is 100.00%.

❗ Current head ac9f75d differs from pull request most recent head c7e1f5e. Consider uploading reports for the commit c7e1f5e to get more accurate results

@@            Coverage Diff             @@
##             main     #279      +/-   ##
==========================================
- Coverage   99.16%   99.16%   -0.01%     
==========================================
  Files          48       48              
  Lines        1923     1919       -4     
==========================================
- Hits         1907     1903       -4     
  Misses         16       16              
Impacted Files                       Coverage Δ
blackjax/kernels.py                  99.19% <ø> (-0.02%) ⬇️
blackjax/smc/adaptive_tempered.py    100.00% <100.00%> (ø)
blackjax/smc/base.py                 100.00% <100.00%> (ø)
blackjax/smc/ess.py                  100.00% <100.00%> (ø)
blackjax/smc/resampling.py           100.00% <100.00%> (ø)
blackjax/smc/tempered.py             100.00% <100.00%> (ø)


rlouf mentioned this pull request Oct 24, 2022
rlouf added the labels enhancement (New feature or request), sampler (Issue related to samplers), refactoring (Change that adds no functionality but improves code quality), and smc (Sequential Monte Carlo samplers) Oct 27, 2022
rlouf mentioned this pull request Dec 20, 2022
rlouf force-pushed the refactor-smc branch 3 times, most recently from 242102b to 79c4730, December 22, 2022 05:05

rlouf commented Dec 22, 2022

@AdrienCorenflos could you please take a look, especially at blackjax.smc.step and blackjax.smc.tempered, to see how the base is used?

Left to do in this PR:

  • Further simplify the Tempered SMC tests
  • Refactor the ESS tests
  • Use log for ESS by default (and logsumexp for numerical stability)
  • Make sure that the "Waste-free" version works
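
For the log-ESS item, a minimal sketch of the computation in log space, assuming the usual definition ESS = (Σ w)² / Σ w². The function name `log_ess` is hypothetical and may not match what blackjax/smc/ess.py exposes.

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def log_ess(log_weights):
    """Log effective sample size from unnormalised log-weights.

    log ESS = 2 * log(sum w) - log(sum w^2), with both sums evaluated
    through logsumexp so the weights are never exponentiated directly.
    """
    return 2.0 * logsumexp(log_weights) - logsumexp(2.0 * log_weights)


# Uniform weights: the ESS equals the number of particles.
uniform = jnp.zeros(100)

# One dominant particle: the ESS collapses to about one.
degenerate = jnp.array([0.0, -1e4, -1e4, -1e4])

# Large log-weights: exp(1000.0) overflows, but the log-space formula is fine.
large = jnp.full(100, 1000.0)

ess_uniform = jnp.exp(log_ess(uniform))
ess_degenerate = jnp.exp(log_ess(degenerate))
ess_large = jnp.exp(log_ess(large))
```

Working with unnormalised log-weights also removes the need for a separate normalisation pass, since the ratio is invariant to a constant shift of the log-weights.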

rlouf marked this pull request as ready for review December 22, 2022 05:09
rlouf requested a review from junpenglao December 22, 2022 05:09
rlouf force-pushed the refactor-smc branch 4 times, most recently from a46d762 to 412661b, January 13, 2023 14:21

rlouf commented Jan 13, 2023

I am about done with this. The new interface makes it possible to use the Waste-Free version of SMC: one just has to design the appropriate "update" function and pass the desired number of samples to extract during resampling. I added a test that demonstrates this.
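
A simplified sketch of the waste-free idea, not the test added in this PR: resample M < N ancestors, run a short chain of P = N / M steps from each, and keep every state so that the update returns N particles. The random-walk kernel, the multinomial resampling, and all names below are assumptions made for illustration.

```python
from functools import partial

import jax
import jax.numpy as jnp


def random_walk_step(rng_key, position):
    """Hypothetical Metropolis random-walk step targeting a standard normal."""
    proposal_key, accept_key = jax.random.split(rng_key)
    proposal = position + 0.5 * jax.random.normal(proposal_key, position.shape)
    log_ratio = 0.5 * (jnp.sum(position**2) - jnp.sum(proposal**2))
    accept = jnp.log(jax.random.uniform(accept_key)) < log_ratio
    return jnp.where(accept, proposal, position)


def waste_free_update(rng_key, position, num_steps):
    """Run a short chain and return every state, not only the last one."""

    def one_step(state, key):
        state = random_walk_step(key, state)
        return state, state

    keys = jax.random.split(rng_key, num_steps)
    _, states = jax.lax.scan(one_step, position, keys)
    return states  # shape (num_steps, dim)


num_particles, num_resampled, dim = 100, 20, 2  # N and M
num_steps = num_particles // num_resampled      # P = N / M

init_key, resample_key, update_key = jax.random.split(jax.random.PRNGKey(0), 3)
particles = jax.random.normal(init_key, (num_particles, dim))
weights = jnp.full(num_particles, 1.0 / num_particles)

# Resample only M < N ancestors (multinomial here, as a stand-in).
idx = jax.random.choice(
    resample_key, num_particles, shape=(num_resampled,), p=weights
)
ancestors = particles[idx]

# Each ancestor yields P states, so the update returns M * P = N particles.
keys = jax.random.split(update_key, num_resampled)
update = jax.vmap(partial(waste_free_update, num_steps=num_steps))
chains = update(keys, ancestors)  # shape (M, P, dim)
new_particles = chains.reshape(num_particles, dim)
```

The particle count is restored by the update function rather than by the resampler, which is why the step only needs to accept a separate number of resampled particles.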

I am waiting for #441 to decide whether we move all the content of smc to a new /meta directory and rename base.py -> smc.py.

This was linked to issues Jan 13, 2023

rlouf commented Jan 16, 2023

This is ready for review. We'll do the repo re-organisation in a separate PR so it is easily reversible.

rlouf and others added 2 commits January 17, 2023 14:46
The number of particles we resample in the SMC step is currently equal
to the number of weights passed to the resampling function. In this PR
we allow the caller to ask for a different number of particles. This
makes it possible to build Waste-Free SMC kernels by resampling M < N
particles and building the update function so that it returns N particles.
rlouf force-pushed the refactor-smc branch 2 times, most recently from ac9f75d to c7e1f5e, January 17, 2023 14:20
rlouf merged commit ca6c46a into blackjax-devs:main Jan 17, 2023
rlouf deleted the refactor-smc branch January 17, 2023 14:26