-
Notifications
You must be signed in to change notification settings - Fork 108
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Enable shared mcmc parameters with tempered smc #694
Conversation
blackjax/smc/tempered.py
Outdated
shared_mcmc_parameters = { | ||
k: v[0, ...] for k, v in mcmc_parameters.items() if v.shape[0] == 1 | ||
} | ||
unshared_mcmc_parameters = { | ||
k: v for k, v in mcmc_parameters.items() if v.shape[0] != 1 | ||
} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please put them in a single for loop.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM overall - is it still draft?
The draft was mostly for if the tests should be modified - removed it |
@@ -281,7 +281,7 @@ def test_with_adaptive_tempered(self): | |||
|
|||
def parameter_update(state, info): | |||
return extend_params( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
calling extend params shouldn't be needed anymore
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm still using extend_params since it needs to convert to an array and add a leading dimension of length one. I could have added a separate helper but it'd do the same thing as extend_params but just without the jnp.repeat part
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it is worth to modify extend_params
into something simple without the jnp.repeat
.
Would this work?
def extend_params(params: ArrayLikeTree):
return jax.tree.map(lambda a: jnp.expand_dims(a, 0), params)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That would work, perhaps with potentially a jnp.asarray()
included. The concern I have with modifying extend_params
is I'd assume there's still a use case for it with initializing unshared parameters?
A current use case is for the remaining tests that are using duplicated parameters for testing the unshared case.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@ciguaran Is there any other use case besides dimension matching initially?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There's none.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd assume there's still a use case for it with initializing unshared parameters -> that is a good concern although I think that in most cases we wouldn't initialize by "extending" with copies. We would do it by sampling from some probability distribution. So you can modify it as Junpeng suggested.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ah, that makes sense. thanks
tests/smc/test_tempered_smc.py
Outdated
}, | ||
hmc_parameters_list = [ | ||
extend_params( | ||
num_particles if extend else 1, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd add a comment highlighting that here you are testing that extending with a copy is the same as having a 1 dimension parameter. Otherwise it may be tricky to get why this is even hapening.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you!
* Update README.md (blackjax-devs#638) * Update README.md Update citation. * Update README.md * Indexing the notebook showing how to reproduce the GIF. (blackjax-devs#640) Co-authored-by: Junpeng Lao <[email protected]> * Bump python version (blackjax-devs#645) * Bump python version * update bool inverse * SMC: allow each mutation kernel to have different parameters. (blackjax-devs#649) * vmaping over parameters in base * switch from mcmc_factory to just passing in parameters * pre-commit and typing * CRU and docs improvement * pre-commit * code review updates * pre-commit * rename test * Migrate from deprecated `host_callback` to `io_callback` (blackjax-devs#651) * Migrate from deprecated `host_callback` to `io_callback` Co-Authored-By: George Necula <[email protected]> * Format file * Fix bug * Fix MALA transition energy (blackjax-devs#653) * Fix MALA transition energy * Use a different logic. * Change variable names (blackjax-devs#654) * Replace iterative RNG split and carry with `jax.random.fold_in` (blackjax-devs#656) * Replace iterative RNG split and carry with `jax.random.fold_in` * revert unintended change * file formatting * change `jax.tree_map` to `jax.tree.map` * revert unintended file * fiddle with rng_key * seed again * Removal of Algorithm classes. (blackjax-devs#657) * more * removing export * removal of classes, tests passing * linter * fix on test * linter * removing parametrization on test * code review updates * exporting as_top_level_api in dynamic_hmc * linter * code review update: replace imports * Fix deprecated call to jnp.clip (blackjax-devs#664) * Update jax version requirements (blackjax-devs#666) Fix blackjax-devs#665 * Make tests pass on `aarch64-linux` (blackjax-devs#671) * Enable fitlering of AdaptationInfo (blackjax-devs#674) * enable AdaptationInfo filtering * revert progress_bar * fix pre-commit * fix empty sets * enable adapt info filtering for all adaptation algorithms * fix precommit /progressbar=True * change filter tuple to use tree_map * Update `run_inference_algorithm` to split `initial_position` and `initial_state` (blackjax-devs#672) * UPDATE DOCSTRING * ADD STREAMING VERSION * UPDATE TESTS * ADD DOCSTRING * ADD TEST * REFACTOR RUN_INFERENCE_ALGORITHM * UPDATE DOCSTRING * Precommit * CLEAN TESTS * ADD INITIAL_POSITION * FIX TEST * RENAME O * FIX DOCSTRING * PUT EXPECTATION AFTER TRANSFORM * Preconditioned mclmc (blackjax-devs#673) * TESTS * TESTS * UPDATE DOCSTRING * ADD STREAMING VERSION * ADD PRECONDITIONING TO MCLMC * ADD PRECONDITIONING TO TUNING FOR MCLMC * UPDATE GITIGNORE * UPDATE GITIGNORE * UPDATE TESTS * UPDATE TESTS * ADD DOCSTRING * ADD TEST * STREAMING AVERAGE * ADD TEST * REFACTOR RUN_INFERENCE_ALGORITHM * UPDATE DOCSTRING * Precommit * CLEAN TESTS * GITIGNORE * PRECOMMIT CLEAN UP * ADD INITIAL_POSITION * FIX TEST * ADD TEST * REMOVE BENCHMARKS * BUG FIX * CHANGE PRECISION * CHANGE PRECISION * RENAME O * UPDATE STREAMING AVG * UPDATE PR * RENAME STD_MAT * New integrator, and add some metadata to integrators.py (blackjax-devs#681) * TESTS * TESTS * UPDATE DOCSTRING * ADD STREAMING VERSION * ADD PRECONDITIONING TO MCLMC * ADD PRECONDITIONING TO TUNING FOR MCLMC * UPDATE GITIGNORE * UPDATE GITIGNORE * UPDATE TESTS * UPDATE TESTS * ADD DOCSTRING * ADD TEST * STREAMING AVERAGE * ADD TEST * REFACTOR RUN_INFERENCE_ALGORITHM * UPDATE DOCSTRING * Precommit * CLEAN TESTS * GITIGNORE * PRECOMMIT CLEAN UP * FIX SPELLING, ADD OMELYAN, EXPORT COEFFICIENTS * TEMPORARILY ADD BENCHMARKS * ADD INITIAL_POSITION * FIX TEST * CLEAN UP * REMOVE BENCHMARKS * ADD TEST * REMOVE BENCHMARKS * BUG FIX * CHANGE PRECISION * CHANGE PRECISION * ADD OMELYAN TEST * RENAME O * UPDATE STREAMING AVG * UPDATE PR * RENAME STD_MAT * MERGE MAIN * REMOVE COEFFICIENT EXPORTS * Minor formatting (blackjax-devs#685) * Minor formatting * formatting * fix test * formatting * MAKE WINDOW ADAPTATION TAKE INTEGRATOR AS ARGUMENT (blackjax-devs#687) * FIX KWARG BUG (blackjax-devs#686) * FIX KWARG BUG * FIX KWARG BUG * Change isokinetic_integrator generation API (blackjax-devs#689) * Apply function on pytree directly. (blackjax-devs#692) * Apply function on pytree directly. Avoiding unnecssary unpacking * Fix kwarg * Fix sampling test. (blackjax-devs#693) * Enable shared mcmc parameters with tempered smc (blackjax-devs#694) * add parameter filtering * fix parameter split + docstring * change extend_paramss * convert to bit twiddling (blackjax-devs#696) * Remove nightly release (blackjax-devs#699) * Fix doc mistakes (blackjax-devs#701) * Fix equation formatting * Clarify JAX gradient error * Fix punctuation + capitalization * Fix grammar Should not begin sentence with "i.e." in English. * Fix math formatting error * Fix typo Change parallel _ensample_ chain adaptation to parallel _ensemble_ chain adaptation. * Add SVGD citation to appear in doc Currently the SVGD paper is only cited in the `kernel` function, which is defined _within_ the `build_kernel` function. Because of this nested function format, the SVGD paper is _not_ cited in the documentation. To fix this, I added a citation to the SVGD paper in the `as_top_level_api` docstring. * Fix grammar + clarify doc * Fix typo --------- Co-authored-by: Junpeng Lao <[email protected]> * Update index.md (blackjax-devs#711) The jitted step remained unused, leading to the example running with an uncompiled nuts.step. Changing this reduces the execution time by a factor of 30 on my system and showcases blackjax' speed. * Enable progress bar under pmap (blackjax-devs#712) * enable pmap progbar * fix bar creation * add locking * fix formatting * switch to using chain state * remove labels (blackjax-devs#716) * Simplify `run_inference_algorithm` (blackjax-devs#714) * fix minor type errors * storing only expectation values * fixed memory efficient sampling * clean up * renaming vars * precommit fixes * fixing tests * fixing tests * fixing tests * fixing tests * fixing tests * merge main * burn in and fix tests * burn in and fix tests * minor fixes * minor fixes * minor fixes --------- Co-authored-by: [email protected] <[email protected]> * Harmonize Quickstart example (blackjax-devs#717) * Update README.md (blackjax-devs#719) --------- Co-authored-by: Junpeng Lao <[email protected]> Co-authored-by: Carlos Iguaran <[email protected]> Co-authored-by: ksnxr <[email protected]> Co-authored-by: Gaétan Lepage <[email protected]> Co-authored-by: Alberto Cabezas <[email protected]> Co-authored-by: andrewdipper <[email protected]> Co-authored-by: Reuben <[email protected]> Co-authored-by: Gilad Turok <[email protected]> Co-authored-by: johannahaffner <[email protected]> Co-authored-by: [email protected] <[email protected]>
Allows for
mcmc_parameters
to be passed to the mcmc kernel as shared parameters prior to applying vmap. Thus shared parameters will not need to be duplicated for each individual particle.This change filters
mcmc_parameters
by the length of the first dimension. Any parameters with length 1 are considered shared (note this is also acceptable in the case of just a single particle) and the rest are unshared. Shared parameters are then closed over before applying vmap so they don't need duplication. The behavior remains the same for shared parameters that are duplicated as they are just treated as unshared as before. This seems like the most reasonable way to handle shared parameters but let me know.Related to #690
cc @ciguaran
Some notes: