Allow guides with re-initializable parameters #655

Closed

ahmadsalim
Contributor

@ahmadsalim ahmadsalim commented Jun 30, 2020

The first split of #649.

It introduces re-initializable guides, which can re-initialize their parameters.
This is useful for Stein inference, where the parameters are treated as particles: we want the particles to be initialized independently rather than all clumped together, which would hurt inference.

An alternative would be to have similar strategies that can be called explicitly with a shape and used directly in numpyro.param. I do not know which you prefer, so I tried to reuse the existing interface 😄
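To illustrate the difference between clumped and independently initialized particles, here is a minimal JAX sketch. The helper `init_particles` and its `noise_scale` argument are hypothetical names for illustration only; this is not the code from the PR.

```python
import jax
import jax.numpy as jnp


def init_particles(rng_key, init_value, num_particles, noise_scale=1.0):
    """Return `num_particles` noisy copies of `init_value` (hypothetical helper)."""
    # One PRNG key per particle, so every particle gets its own draw.
    keys = jax.random.split(rng_key, num_particles)

    def init_one(key):
        noise = jax.random.normal(key, jnp.shape(init_value))
        return init_value + noise_scale * noise

    return jax.vmap(init_one)(keys)


init_value = jnp.zeros(2)
# Clumped start: every particle is the same point.
clumped = jnp.tile(init_value, (5, 1))
# Independent start: each particle is perturbed with its own key.
spread = init_particles(jax.random.PRNGKey(0), init_value, num_particles=5)
```

Both arrays have shape `(5, 2)`, but only `spread` has any variation across the particle axis.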

@ahmadsalim ahmadsalim force-pushed the feature/reinitializable-guides branch from 459a430 to 563dd7b Compare June 30, 2020 10:30
@ahmadsalim ahmadsalim force-pushed the feature/reinitializable-guides branch from 563dd7b to 022ca6e Compare June 30, 2020 10:32
@fehiepsi
Member

Thanks for splitting out your PR, @ahmadsalim! IIUC you want two things:

  • initializing param randomly, instead of fixing an initial value
  • using MCMC for param statements?

I think we can start with a contrib.stein or contrib.svgd module and move ReinitGuide and AutoDelta (which depends on ReinitGuide) there. It would be good to add some smoke tests to verify that the implementation works as you intend.

FYI, we just recently moved autoguide to the main infer module after many refactoring PRs. In addition, we have cleaned up the param handling in potential_energy and find_valid_initial_params. I am not sure if we need to introduce it again.

How about the following proposals:

  • Implement numpyro.lift to lift param sites to sample sites (this deserves a separate PR; I can implement it if you want). Then you can use the current find_valid_initial_params to get valid initial parameters in the Stein implementation. I guess with that, we also don't have to introduce reinit_param in the current initialization strategy.
  • Have a separate PR for init_with_noise; maybe we can start it in the contrib module together with ReinitGuide and AutoDelta? I am also a bit curious about how useful it is. In addition, its interface is a bit different from the other initialization strategy interfaces. Maybe something like init_with_noise(site=None, noise_scale=1.0, init_strategy=init_to_uniform) would make things more consistent.
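As a rough sketch of what the suggested `init_with_noise(site=None, noise_scale=1.0, init_strategy=...)` signature could look like, the example below uses plain JAX with toy stand-ins. The base strategy `init_to_zeros` and the dictionary layout of `site` are assumptions made for illustration; they are not NumPyro's actual strategies or message format.

```python
from functools import partial

import jax
import jax.numpy as jnp


def init_to_zeros(site=None):
    """Toy base strategy (hypothetical): start every value at zero."""
    if site is None:
        return init_to_zeros
    return jnp.zeros(site["shape"])


def init_with_noise(site=None, noise_scale=1.0, init_strategy=init_to_zeros):
    """Perturb the value chosen by `init_strategy` with Gaussian noise."""
    if site is None:
        # Called without a site: return a configured strategy for later use.
        return partial(
            init_with_noise, noise_scale=noise_scale, init_strategy=init_strategy
        )
    base = init_strategy(site)
    noise = jax.random.normal(site["rng_key"], jnp.shape(base))
    return base + noise_scale * noise


# A toy site record (hypothetical layout, not NumPyro's message format).
site = {"name": "loc", "shape": (3,), "rng_key": jax.random.PRNGKey(1)}
strategy = init_with_noise(noise_scale=0.1)
value = strategy(site)
```

Calling the function without a site returns a configured strategy, so the noise scale and the wrapped base strategy can be chosen up front and applied to each site later.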

@ahmadsalim
Contributor Author

Thanks for taking a look!

I do not need MCMC for param statements.
However, the first point is correct, I would like param statements to be randomly initialized.

I would still like them to be categorized as parameters, so that Stein knows that these are the particles to optimize. This is because Stein treats the specified parameters as particles, kind of like a generalization of SVI where we infer multiple values for each parameter using Stein instead of a single one (based on the Stein mixture models paper).
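For readers unfamiliar with the particle view, the sketch below shows one step of plain Stein variational gradient descent (SVGD) with a fixed-bandwidth RBF kernel in JAX. This is the textbook SVGD update, assumed here as a stand-in, and not the Stein mixture implementation discussed in this PR; the function names and the `bandwidth` and `step_size` defaults are illustrative.

```python
import jax
import jax.numpy as jnp


def rbf_kernel(x, y, bandwidth=1.0):
    return jnp.exp(-jnp.sum((x - y) ** 2) / (2.0 * bandwidth**2))


def svgd_step(particles, log_prob, step_size=0.1):
    """One SVGD update for particles of shape (num_particles, dim)."""
    scores = jax.vmap(jax.grad(log_prob))(particles)

    def direction(x_i):
        # Attractive term: kernel-weighted scores pull x_i toward high density.
        k = jax.vmap(lambda x_j: rbf_kernel(x_j, x_i))(particles)
        attract = k @ scores
        # Repulsive term: kernel gradients w.r.t. x_j push particles apart.
        grad_k = jax.vmap(jax.grad(rbf_kernel), in_axes=(0, None))(particles, x_i)
        repulse = jnp.sum(grad_k, axis=0)
        return (attract + repulse) / particles.shape[0]

    return particles + step_size * jax.vmap(direction)(particles)


def log_prob(x):
    # Unnormalized log-density of a standard normal target.
    return -0.5 * jnp.sum(x**2)


# All particles start at the same point: the repulsive term vanishes.
collapsed = jnp.ones((4, 2))
moved = svgd_step(collapsed, log_prob)

# Two distinct particles: the repulsive term pushes them apart.
pair = jnp.array([[-0.1], [0.1]])
pair_moved = svgd_step(pair, log_prob)
```

When every particle starts at the same point, the kernel gradients are zero and all particles receive the same update, so they stay collapsed. This is one way to see why independent initialization matters for this family of methods.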

@ahmadsalim
Contributor Author

I do not have anything against moving things to contrib and will do that ASAP.

Could you kindly explain how numpyro.lift would work?

@fehiepsi
Member

I think it will just be a port of http://docs.pyro.ai/en/stable/poutine.html#pyro.poutine.handlers.lift

Do you use those initialization steps in inference? If not, then you can lift to get initial params and then do inference with your original model as usual.

@ahmadsalim
Contributor Author

Oh, that would be great to have! Then I can drop my other changes to the initialization functions 😄

@fehiepsi
Member

fehiepsi commented Jul 6, 2020

Sorry for the delay! I'll port pyro.lift today. >>>o<<<

@ahmadsalim
Contributor Author

Awesome, thanks a lot!

I was actually not sure what the conclusion was and started looking into implementing it myself 😄 . But I appreciate you looking into it!

@fehiepsi
Member

fehiepsi commented Jul 7, 2020

Haha, me too. I thought that you were waiting for it, so I just went ahead and implemented it. :D

@ahmadsalim
Contributor Author

Wow, this is embarrassing. I missed that the pyro.lift function was integrated.
I was debugging a weird NaN gradient/very slow performance issue with JAX for the last few months and forgot to continue on this path :)

I will figure out how to re-open the PRs as originally planned, but will close this one.

@ahmadsalim ahmadsalim closed this Sep 14, 2020
@ahmadsalim ahmadsalim deleted the feature/reinitializable-guides branch September 14, 2020 08:44
@ahmadsalim
Contributor Author

@fehiepsi Thanks a lot for pyro.lift 🥇

@fehiepsi
Member

Sure, please let us know if you need any other features. I hope that NumPyro's design does not block much of your work, so hopefully you only need to play with the algorithm. :)

@ahmadsalim
Contributor Author

Cool! Thanks a lot :)


3 participants