-
Notifications
You must be signed in to change notification settings - Fork 246
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add CVAE in Flax #1429
Add CVAE in Flax #1429
Conversation
I think the modelling tests fail, because flax seems to require installation of |
That's right. You can merge master to make tests passed. Is this ready to review? The overall code looks great to me. Maybe it is helpful to add some README to help other people navigate the folder, point to the original Pyro tutorial to illustrate the data and the model. If you want me to look at some specific points, please let me know. |
Yes, it is ready to review. I'd like to wrap from flax import linen as nn
class CVAE(nn.Module):
def setup(self):
self.baseline = BaselineNet()
self.prior_net = Encoder()
...
def model(self, x, y):
...
def guide(self, x, y):
... Do you know if this works somehow? So far I am getting errors with various attempts. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is tricky to use such pattern. I feel that the current approach is cleaner. No need to mix neural network definition with random variable in the same class.
examples/cvae-flax/train_baseline.py
Outdated
return state | ||
|
||
|
||
@jax.jit |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we can remove jit here
@fehiepsi could you run the workflows again please? I don't think it fails bcs of my changes.. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for your contribution, @dirmeier! Could you do some extra steps to make this great example more visible for other folks?
- Add a
cvae.py
file with only docstrings like in yourREADME.md
file (you can use this example as a reference). Add a link to github folder of the example (I guess it ishttps://github.com/pyro-ppl/numpyro/tree/master/examples/cvae-flax
) - Move
cvae_predictions.png
to this folder, rename it tocvae.png
- Add
examples/cvae
to the application section of [sphinx]((https://github.com/pyro-ppl/numpyro/blob/master/docs/source/index.rst)
All done, let me know if this works for you or you 'd like to see some other changes.. |
Thanks @dirmeier! |
Hello,
This PR is a port of @fritzo 's CVAE example. It's not particularly fancy, but I believe helpful nonetheless, since (a) it shows how to use
optax.multi_transform
to freeze a set of parameters while training (i.e., not apply gradient transforms) and use that withSVI
, and (b) it gives another Flax example which I think would be good to have more of.Hope it's useful.
Cheers,
Simon