Add Mean Field Variational Inference implementation #433
Conversation
blackjax/vi/mfvi.py (Outdated)

    return meanfield_logprob


    def sample(rng_key, meanfield_param, num_samples: int):
Maybe you can call `generate_gaussian_noise` (util.py, line 56 in d6801f8):

    def generate_gaussian_noise(
Sounds good. But at the moment there's no internal function we can call to sample from a multivariate Gaussian, right (i.e. the full-rank VI case)?
You can use the same function for the multivariate Gaussian because the `linear_map` util will dispatch it correctly.
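For illustration, a minimal sketch of the dispatch being described, not blackjax's actual source: one helper covers both the mean-field (1-D scale) and full-rank (2-D scale) cases based on the array's rank.

    import jax.numpy as jnp

    def linear_map(scale, noise):
        # Diagonal covariance: elementwise product sigma * eps.
        if scale.ndim == 1:
            return scale * noise
        # Full covariance: matrix-vector product L @ eps, L a Cholesky factor.
        return scale @ noise

    eps = jnp.ones(3)
    print(linear_map(jnp.full(3, 2.0), eps))  # mean-field case -> [2. 2. 2.]
    print(linear_map(2.0 * jnp.eye(3), eps))  # full-rank case  -> [2. 2. 2.]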
@junpenglao But how do you generate `num_samples` particles using `generate_gaussian_noise`?
You are right, the util function does not accept that yet; let's keep your version here, but could you add a TODO?
# TODO: switch to using `generate_gaussian_noise` in util.py
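For context, a minimal sketch of what such a `sample` could look like in the meantime; the `(mu, rho)` unpacking and log-scale parameterization are assumptions about the parameter container, not necessarily the PR's.

    import jax
    import jax.numpy as jnp

    def sample(rng_key, meanfield_param, num_samples: int):
        mu, rho = meanfield_param        # assumed (mean, log-scale) pair
        sigma = jnp.exp(rho)             # positivity via the log scale
        # Draw all particles at once, then reparameterize: z = mu + sigma * eps.
        eps = jax.random.normal(rng_key, (num_samples,) + mu.shape)
        return mu + sigma * eps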
Thanks for contributing! That's a very good start, we'll just need to reorganise a few things.
blackjax/vi/mfvi.py (Outdated)

    def approximate(
        rng, init_params, log_prob_fn, optimizer, sample_size=5, num_steps=200
We need a default choice for the optimiser; does this assume we're using Optax?
The usual practice is to use Adam, which is available in both Optax and JAXopt. I will hold off here until you and @junpenglao decide which optimization library to use for VI.
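A hedged sketch of what an Adam-by-default `approximate` might look like if the Optax route is chosen; `neg_elbo_fn(params, rng_key)` is an assumed callable returning a Monte Carlo estimate of the negative ELBO, and the loop body paraphrases only the signature above, not the PR's actual implementation.

    import jax
    import optax

    def approximate(rng, init_params, neg_elbo_fn, optimizer=None, num_steps=200):
        # Default to Adam via Optax when no optimizer is supplied.
        optimizer = optax.adam(1e-2) if optimizer is None else optimizer
        opt_state = optimizer.init(init_params)
        params = init_params
        for _ in range(num_steps):
            rng, step_key = jax.random.split(rng)
            grads = jax.grad(neg_elbo_fn)(params, step_key)
            updates, opt_state = optimizer.update(grads, opt_state)
            params = optax.apply_updates(params, updates)
        return params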
tests/test_mfvi.py (Outdated)

    import jax.scipy.stats as stats
    import numpy as np
    import blackjax
    import optax
I don't think `optax` is currently listed as a dependency, but `jaxopt` is. We need to discuss the pros and cons of adding a new dependency.
I thought `optax` was already used in SGMCMC? I think it is OK to add it as a dependency.
I think that's enough comments for now :) Don't be discouraged, this is completely normal as we are also figuring out what the interface should be; there is no template yet, as there is for MCMC algorithms.
@rlouf @junpenglao Would you mind taking another look at the API? I think the next thing we need to figure out is how to avoid potential boilerplate code, since the ELBO computation and optimization process are identical for most of the VI variants.
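To make the shared structure concrete, a sketch of a generic Monte Carlo ELBO that only touches the variational family through its sampling and log-density functions, so mean-field and full-rank VI could reuse it; all names here are illustrative, not blackjax's API.

    import jax
    import jax.numpy as jnp

    def elbo(rng_key, params, logdensity_fn, sample_fn, logq_fn, num_samples):
        # Draw from the variational family, then average log p(z) - log q(z).
        z = sample_fn(rng_key, params, num_samples)       # (num_samples, dim)
        logp = jax.vmap(logdensity_fn)(z)                 # target log density
        logq = jax.vmap(lambda x: logq_fn(params, x))(z)  # variational log density
        return jnp.mean(logp - logq)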
I prefer to avoid generalising on the first implementation, even though it indeed seems like we would need the equivalent of

I'm taking some time off and won't be able to take a good look at the PR until Jan 3, but will have plenty of time that day to go over it in depth.
@rlouf Well, apparently I need better documentation and more detailed test cases. Other than that, do you have any further comments on the API design? Also, I am thinking about moving forward with implementing full-rank VI; do you think I should open a new PR, or should I just implement it here?
I am making a pass on your implementation at the moment; you'll need to pull the changes once I am finished. I now think your original proposal of implementing VI as a "kernel" was the right way to go, and we will probably need to "kernelize" (we need a better name) Pathfinder as well. I think it's nice that the user has fine-grained control over the optimization.
I rearranged things a little; it was mostly cosmetic. I think the next steps are to give it a kernel-like API, tests, and docs. If the full-rank version is going to share a lot of code, it's probably best to do it now rather than go through the hassle of opening a new PR.
Codecov Report

    @@            Coverage Diff             @@
    ##             main     #433      +/-   ##
    ==========================================
    - Coverage   99.25%   99.17%   -0.08%
    ==========================================
      Files          47       48       +1
      Lines        1872     1938      +66
    ==========================================
    + Hits         1858     1922      +64
    - Misses         14       16       +2
(force-pushed from 068a0b5 to dd6438a)
I turned the algorithm into a "kernel" like you originally suggested; I think it is much more in the general spirit of Blackjax and gives the user more freedom. This is almost ready to merge: we need to expand the docstring of the step function a little more, and add references.
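For reference, the kernel-like usage might end up looking something like the following; the `blackjax.meanfield_vi` constructor name, its arguments, and the `init`/`step` split are assumptions about the final merged API.

    import jax
    import jax.numpy as jnp
    import optax
    import blackjax

    def logdensity_fn(x):
        # Toy target: standard normal log density (up to a constant).
        return -0.5 * jnp.sum(x ** 2)

    # Constructor name and keyword are assumptions about the merged API.
    mfvi = blackjax.meanfield_vi(logdensity_fn, optax.adam(1e-2), num_samples=10)
    state = mfvi.init(jnp.zeros(2))

    rng_key = jax.random.PRNGKey(0)
    for _ in range(200):
        rng_key, subkey = jax.random.split(rng_key)
        state, info = mfvi.step(subkey, state)  # one stochastic ELBO gradient step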
@rlouf Thanks! I can take care of some of the documentation! See you tomorrow in the meeting!
This looks great! I think it is ready to merge after we replace
@rlouf I got it. One second
@rlouf I just made some changes, is that what you suggest?
I meant
Commits:
- add some doc
- remove extra tuple in base.py
- change log prob to logdensity
- remove fn from log density
@rlouf Done
LGTM, thank you for contributing!
Some initial attempts at integrating MFVI into blackjax (#397).
TODO:
- There is going to be a lot of boilerplate code shared between MFVI and full-rank VI.
- It is also worth considering how to add the stick-the-landing gradient estimator [1] and importance-weighted ELBO optimization [2] (see the sketch after this list).
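A hedged sketch of the stick-the-landing estimator from [1], reusing the illustrative `sample_fn`/`logq_fn` signatures from the ELBO sketch above; this is one possible implementation, not blackjax's API.

    import jax
    import jax.numpy as jnp

    def stl_elbo(rng_key, params, logdensity_fn, sample_fn, logq_fn, num_samples):
        # Reparameterized draws still depend on `params`, so the path
        # derivative survives; stopping gradients through the parameters
        # of log q removes the higher-variance score-function term.
        z = sample_fn(rng_key, params, num_samples)
        frozen = jax.lax.stop_gradient(params)
        logp = jax.vmap(logdensity_fn)(z)
        logq = jax.vmap(lambda x: logq_fn(frozen, x))(z)
        return jnp.mean(logp - logq)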
[1] Roeder, Geoffrey, Yuhuai Wu, and David K. Duvenaud. "Sticking the Landing: Simple, Lower-Variance Gradient Estimators for Variational Inference." Advances in Neural Information Processing Systems 30 (2017).
[2] Domke, Justin, and Daniel R. Sheldon. "Importance Weighting and Variational Inference." Advances in Neural Information Processing Systems 31 (2018).