
Add CVAE in Flax #1429

Merged
merged 6 commits into from
Jun 21, 2022

Conversation

dirmeier
Contributor

Hello,

This PR is a port of @fritzo's CVAE example to Flax. It's not particularly fancy, but I believe it is helpful nonetheless: (a) it shows how to use optax.multi_transform to freeze a set of parameters during training (i.e., apply no gradient updates to them) and how to combine that with SVI, and (b) it adds another Flax example, which I think it would be good to have more of.

Hope it's useful.

Cheers,
Simon

@dirmeier
Contributor Author

I think the modelling tests fail because Flax now seems to require pyyaml to be installed?

@fehiepsi
Member

That's right. You can merge master to make the tests pass.

Is this ready to review? The overall code looks great to me. It might be helpful to add a README that helps other people navigate the folder and points to the original Pyro tutorial to illustrate the data and the model. If you want me to look at some specific points, please let me know.

@dirmeier
Contributor Author

> Is this ready to review? The overall code looks great to me. Maybe it is helpful to add some README to help other people navigate the folder, point to the original Pyro tutorial to illustrate the data and the model. If you want me to look at some specific points, please let me know.

Yes, it is ready to review.

I'd like to wrap the model and guide in a Flax module, as in the original example, e.g., something like

from flax import linen as nn

class CVAE(nn.Module):

  def setup(self):
    self.baseline = BaselineNet()
    self.prior_net = Encoder()
    ...

  def model(self, x, y):
    ...

  def guide(self, x, y):
    ...

Do you know whether this can work somehow? So far, all of my attempts have produced errors.

Member

@fehiepsi fehiepsi left a comment


It is tricky to use such a pattern. I feel that the current approach is cleaner: there is no need to mix the neural network definitions with the random variables in the same class.

return state


@jax.jit
Member


I think we can remove jit here

@dirmeier
Contributor Author

@fehiepsi could you run the workflows again, please? I don't think they fail because of my changes.

fehiepsi
fehiepsi previously approved these changes Jun 16, 2022
Member

@fehiepsi fehiepsi left a comment


Thanks for your contribution, @dirmeier! Could you take some extra steps to make this great example more visible to other folks?

@dirmeier
Contributor Author

All done. Let me know if this works for you or if you'd like to see any other changes.
Cheers,
S

@fehiepsi
Member

Thanks @dirmeier!

@fehiepsi fehiepsi merged commit a832572 into pyro-ppl:master Jun 21, 2022
@dirmeier dirmeier deleted the cvae-flax branch June 21, 2022 06:29