
Initializing batch of parameters for SteinVI #1263

Merged
merged 17 commits into from
Jan 7, 2022

Conversation

OlaRonning
Member

@OlaRonning OlaRonning commented Dec 27, 2021

To support re-initializable guides, we suggest the ReinitGuide interface, which requires implementing a function find_params that accepts a list of RNG keys in addition to the guide's own arguments and returns a set of freshly initialized parameters for each RNG key.

from abc import ABC, abstractmethod


class ReinitGuide(ABC):
    @abstractmethod
    def init_params(self):
        raise NotImplementedError

    @abstractmethod
    def find_params(self, rng_keys, *args, **kwargs):
        raise NotImplementedError
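To make the contract concrete, a toy implementation might look like the following (self-contained, so the ABC is repeated; DummyGuide and the integer seeds standing in for JAX PRNG keys are illustrative, not part of the PR):

```python
import random
from abc import ABC, abstractmethod


class ReinitGuide(ABC):
    @abstractmethod
    def init_params(self):
        raise NotImplementedError

    @abstractmethod
    def find_params(self, rng_keys, *args, **kwargs):
        raise NotImplementedError


class DummyGuide(ReinitGuide):
    """Toy guide with a single scalar parameter 'loc'."""

    def __init__(self):
        self._params = None

    def init_params(self):
        return self._params

    def find_params(self, rng_keys, *args, **kwargs):
        # One freshly initialized value per RNG key (one per particle).
        self._params = {
            "loc": [random.Random(key).uniform(-2.0, 2.0) for key in rng_keys]
        }
        return self._params


guide = DummyGuide()
params = guide.find_params([0, 1, 2])
```

Each RNG key yields an independent initialization, which SteinVI can then stack into its particle batch.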

The WrappedGuide class wraps a guide written as a plain function and makes it re-initializable. It works by running the provided guide multiple times and reinitializing the parameters through NumPyro's interface, as follows:

  1. WrappedGuide runs the guide transforming each parameter to unconstrained space.
  2. For each parameter in the guide, WrappedGuide sets an RNG key.
  3. It replaces the values of the parameters with values provided by a NumPyro initialization strategy (e.g., init_to_uniform, init_to_median, or init_to_noise).
  4. It saves the parameter values for each particle and the required inverse transformations to constrained space to run the model correctly.
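The four steps can be sketched in miniature with plain Python (the log/exp pair stands in for NumPyro's constraint transforms, and the uniform draw for an init_to_uniform-style strategy; none of this is SteinVI's actual code):

```python
import math
import random

# A toy "guide trace": one positively constrained parameter.
trace = {"scale": {"value": 1.0, "transform": math.log, "inv_transform": math.exp}}


def reinit(trace, rng_keys):
    particles = []
    for key in rng_keys:  # step 2: one RNG key per particle
        rng = random.Random(key)
        params = {}
        for name, site in trace.items():
            _ = site["transform"](site["value"])        # step 1: map to unconstrained space
            draw = rng.uniform(-2.0, 2.0)               # step 3: init-strategy draw (unconstrained)
            params[name] = site["inv_transform"](draw)  # step 4: back to constrained space
        particles.append(params)
    return particles


particles = reinit(trace, rng_keys=[0, 1, 2])
```

Because the inverse transform is applied last (step 4), every particle's value lands back in the constrained support regardless of what the init strategy drew.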

To avoid reinitializing specific parameters, WrappedGuide accepts a reinit_hide_fn that filters them out. SteinVI allows parameters to skip reinitialization in order to support neural network libraries like Stax, which provide their own bespoke initializers.

An alternative solution is to lift parameters to sample sites, as suggested in #655 (and again by @fehiepsi below). Related to #833.

@OlaRonning OlaRonning added the WIP label Dec 27, 2021
@fehiepsi
Member

Hi @OlaRonning, could you add some tests, and let's see if we can come up with a solution that does not change how init strategies work? Currently, those strategies are intended to be applied to latent sites, not param sites. I think we can either:

  • lift params in your model/guide to sample sites. Like ReinitGuide, this solution is kind of a "hack" because it modifies user programs (e.g., even though users already define init_value=1. for those params, Stein VI will still provide random initial values for them); or
  • explicitly require users to provide random statements in param for init_value:
numpyro.param("a", lambda rng: ...)

I think the second is the better solution, for several reasons:

  • The user program can be interpreted as-is, i.e., constant init values stay constant and random init values stay random.
  • We don't need to introduce machinery (like lift or ReinitGuide) that modifies user programs.
  • We can leverage the many init strategies in other JAX NN frameworks.
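The proposed numpyro.param("a", lambda rng: ...) semantics can be emulated with a toy param store (plain Python; param, the global store, and the integer seed are illustrative stand-ins for NumPyro's primitives and a JAX PRNG key):

```python
import random

_store = {}  # toy global parameter store


def param(name, init, *, rng_seed=None):
    """Toy stand-in for numpyro.param: a constant init stays constant,
    while a callable init draws a fresh value from the provided RNG."""
    if name not in _store:
        _store[name] = init(random.Random(rng_seed)) if callable(init) else init
    return _store[name]


a = param("a", lambda rng: rng.gauss(0.0, 1.0), rng_seed=0)  # random init value
b = param("b", 1.0)                                          # constant init value
```

Under these semantics the user program is interpreted as-is: repeated calls return the stored value, and only sites with callable inits are stochastic.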

@OlaRonning
Member Author

Hi @fehiepsi,

I'll add some details to the PR and update with some tests tonight. It should be possible to keep the existing API, but I need to sketch out both ideas to get a feeling for which is superior.

@OlaRonning OlaRonning changed the title added reinit guide Initializing batch of parameters for SteinVI Dec 28, 2021
@OlaRonning OlaRonning removed the WIP label Dec 28, 2021
for site in guide_trace.values():
    if site["type"] == "param" and not self._reinit_hide_fn(site):
        site_key, rng_key = jax.random.split(site_key)
        site["kwargs"]["rng_key"] = rng_key
Member


Another way to avoid touching init strategies is to change the site's type, fn, and is_observed keys for this site. (You can use ImproperUniform for fn.)

Member


SteinVI allows parameters without reinitialization to support neural network libraries like Stax

Though the solution in my comment above leaves the current SteinVI API untouched, I still think that letting users define their own init strategy for each parameter is better, so we don't need to introduce notions like reinit_hide_fn. But it's up to you; I think addressing that issue is a minor priority.

Member Author


I don't particularly like reinit_hide_fn, as it requires the user to list parameters twice. A solution that eliminates it is superior, in my opinion.

@OlaRonning
Member Author

OlaRonning commented Dec 29, 2021

I have been thinking about your suggestion of numpyro.param("a", lambda rng: ...), and I agree it's the better (more flexible) solution. SteinVI.init will instantiate the batch (particle) dimension for all parameters.

For AutoGuides we would still use init_loc_fn, with a different init for custom guides. I'll sketch the solution in the tests and we can discuss extensions/alternatives.
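Instantiating the batch (particle) dimension amounts to stacking one fresh initialization per particle along a new leading axis. A sketch of the shapes with plain lists (init_particles is hypothetical, not SteinVI.init itself):

```python
import random


def init_particles(num_particles, dim, seed=0):
    """Stack one fresh initialization per particle along a new leading
    (particle) axis; plain nested lists stand in for jnp arrays."""
    rng = random.Random(seed)
    return [[rng.gauss(0.0, 1.0) for _ in range(dim)] for _ in range(num_particles)]


batched = init_particles(num_particles=5, dim=3)  # shape (5, 3): (particles, param_dim)
```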

@OlaRonning
Member Author

Hi @fehiepsi, I rewrote the test with batching using your suggestions; the init strategy is now unaffected.
I changed the semantics slightly: now only loc params (for AutoGuides) and user-annotated random params are stochastic.

ELBO-within-Stein does not spread the variational distributions based on overlap (I think I have a fix for this coming), and SVGD is known to recover well from poor initializations. So there is no need to take distribution shape into account (yet) when placing the particles with an AutoGuide.

@OlaRonning OlaRonning requested a review from fehiepsi January 6, 2022 08:21
@OlaRonning
Member Author

Found a bug: callable(site['value']) is always false, so all particles were initialized to the same value for custom guides. Instead, I now vmap

def _reinit(seed):
    with handlers.seed(rng_seed=seed):
        return site_fn(*site_args)

I understand the above to produce samples that are assumed to already lie in the constrained space when site_fn is not the identity. Is this consistent with current usage? So, for example,

def model():
    numpyro.param("b", lambda rng_key: Normal(0, 1.0).sample(rng_key), constraint=positive)

is invalid.
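The pitfall can be checked with plain numbers: a raw Normal draw may be negative and thus violate constraint=positive, whereas pushing the draw through a positivity transform (here math.exp, a stand-in for NumPyro's biject_to(positive)) always lands in the support:

```python
import math
import random

rng = random.Random(42)
raw_draws = [rng.gauss(0.0, 1.0) for _ in range(100)]

# A raw Normal(0, 1) init can violate constraint=positive ...
assert any(x <= 0 for x in raw_draws)

# ... whereas mapping the draw through exp guarantees positivity.
safe_draws = [math.exp(x) for x in raw_draws]
assert all(x > 0 for x in safe_draws)
```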

Member

@fehiepsi fehiepsi left a comment


LGTM, thanks @OlaRonning!

Is this consistent with current usage?

Yup. lambda key: Normal... might produce invalid samples for a positive constraint.

only loc params (for AutoGuides) and user annotated random params are stochastic.

I like this behavior! init_loc_fn will initialize values for the latent variables; for parameters, it depends on the users.

@fehiepsi fehiepsi merged commit cbc371e into pyro-ppl:master Jan 7, 2022
@OlaRonning OlaRonning deleted the feature/reinit_guide branch January 7, 2022 15:11
@OlaRonning
Member Author

Thanks for the review and suggestions @fehiepsi! C:

OlaRonning added a commit to aleatory-science/numpyro that referenced this pull request Jan 7, 2022
@OlaRonning OlaRonning mentioned this pull request Jan 8, 2022
8 tasks
fehiepsi pushed a commit that referenced this pull request Jan 23, 2022
* Added stein interface.

* Fixed style and removed from VI baseclass.

* Added reinit_guide.py

* Added license.

* added examples

* Added examples.

* Fixed some linting and LDA example; need to refactor wrapped_guide.

* Added param site also get rng_keys; this should be reworked!

* Removed datasets and fixed lda to running.

* Fixed dimensionality bug for simplex support.

* Added code from refactor/einstein

* Fixed notebooks; todo: comment notebook.

* Factored initialization of `kernels.RandomFeatureKernel` into `Stein.init` and updated `test_kernels.test_kernel_forward` accordingly.

* Started testing.

* Removed assert from test_init_strategy.

* Skeleton test_stein.py

* Updated `test_stein/test_init`

* Added test_params and likelihood computation to lda.

* Fix init in MixtureKernel

* Notebook fixes

* debugging log likelihood

* WIP, move benchmarks to datasets

* trace guide to compute likelihood in lda.

* Debugging LDA

* Removed test_vi.py (will use test_stein.py), added `test_stein.test_update_evaluate`

* Cleaned test covered by `test_get_params`.

* Added skeleton and finished _param_size test.

* Fix LR example

* IRIS LR

* Fix Toy examples

* Added pinfo test.

* moved stein/test_kernels.py into stein/test_stein.py; updated `test_stein.test_apply_kernel`

* Ran black and removed lambdas from KERNEL_TEST_CASE.

* Added `test_stein.test_sp_mcmc` and removed calls to jnp.random.shuffle (deprecated).

* Added skelelton test for test_score_sp_mcmc.

* Fixed overwriting kval in `test_stein.test_apply_kernel`

* Fixed lint

* Fixed lint.

* Added stein_loss test.

* Factored vi source and test_vi out of einstein.

* updated with black.

* Figured out likelihood for LDA (need to change to compute likelihood instead of ELBO)

* Added perplexity to LDA.

* Fixed log position for perplexity.

* Refactored callbacks and added `test_checkpoint`.

* Fixed imports

* Reverted LDA to working version.

* Added callback tests

* Return loss history for `stein.run`

* Added visual to LDA.

* Fixed return for `run_lda`

* Added missing topic num 20.

* Added todo

* Cleaned 1d_mm stein notebook.

* Updated 2d gaussian notebook.

* Add description to SVGD.

* Updated `RBF_kernel` to work with one particle and added kernels notebook.

* Fixed bug in bandwidth of RBF_kernel

* SVI reproducing result from SVGD paper.

* Better learning rate for SVI.

* larger network

* Updated predictive to allow for particle methods.

* Removed TODO and fixed learning rate.

* EinStein out performance SVGD

* Latest working.

* Fixed VI without progressbar.

* Fixed mini batching for VI.

* Added kernel visualization.

* Init to sample for bayesian networks.

* TODO predict shape.

* Added scaling to plate primitive.

* Fixed enumeration in Stein and added subsample_scale to funsor.plate.

* Debugging LDA

* Debugging lda

* Debugging merge.

* Updated jacobian computation in Stein.

* Fixed issue with nested parameters for stein grad.

* Fixed issue with nested parameters for stein grad.

* Added NLL to DMM and predictions.

* renaming and removing benchmark code

* Cleaning branch from benchmarking.

* Removed prediction from DMM.

* Changed to syntax from older python

* Fixed lint.

* Fixed reinit warning for `init_to_uniform`.

* updated to use black[jupyter]

* Added licenses.

* Added smoke a smoke test for SteinVI

* Factored out Stein point MCMC

* Factored out VI from EinStein.

* Updated stein_kernels.ipynb and removed debugging pred_prey.

* Removed `mixture.py` use `mixtures.py` instead.

* Fixed lint.

* Added examples to docs build.

* Fixed stale import in `hmc.py` docstring.

* Removed stein point test cases.

* Changed Predictive to only check for guided models with particles.

* Fixed lint.

* Changed `reinit_guide` to add rng_keys for reinitialization.

* Added boston pricing dataset. Commented stein bnn example. Added `stein_bnn.py` to `test_examples.py`.

* Removed empty line from `test_examples.py`.

* Fixed lint.

* Changed `stein_mixture_dmm.py` to use new signature and run method.

* Added some comments and fixed `stein_mixture_dmm.py` to use new signature.

* Fixed `event_shape` and `support` for `Sine`, `DoubleBanana`, and `Star` distributions in `stein_2d_dists.py`.

* Removed notebooks from initial PR and updated stein_2d_toy.py to new run signature.

* Parameterize `gru_dim` in `stein_mixture_dmm.py`.

* Fixed steinvi to use #1263; TODO: update examples.

* Removed init_with_noise.

* removed stein_bnn and changed `examples/datasets.py` to upstream

* removed examples

* Removed stein examples from docs.

* renamed einstein.utils to einstein.util.

* updated testing

* Changed test to use auto_guide `init_loc_fn`.

* removed `numpyro/util/ravel_pytree`

* removed unused imports in numpyro/util

* Added initialization to kernels in `test_einstein_kernels.py`

* Changed kernel test to use np.arrays at global level.

* change jnp arrays to np np array in tests. reverted subsample scale.

* added docstring to `einstein/util/batch_ravel_pytree`

Co-authored-by: Ahmad Salim Al-Sibahi <[email protected]>
Co-authored-by: einsteinvi <[email protected]>
OlaRonning added a commit to aleatory-science/numpyro that referenced this pull request Jan 31, 2022
fehiepsi added a commit that referenced this pull request Jan 31, 2022
* added stein example

* added test case

* added stein bnn to docs.

* moveed stein_bnn to other inf algs in docs

* Added correct plating for model in `stein_bnn.py`. Works with latest #833.

* Add some doctests to transforms  (#1300)

* add some doctest to transforms

* make format

* Tutorial for truncated distributions (#1272)

* WIP Do not merge. Tutorial for truncated distributions

* WIP: Completed a few todos and fixed a few typos

* WIP: Completed main sections. References and part 5 still pending

* Added section on built in distributions and folded distributions

* Draft ready

* Remove M1-related warning from cell output

* Truncated distributions tutorial added to index

* Wrap latex equations in double dollar sign

* Fix broken markdown equations

* Added more details on folded distribs. Re-arranged sections.

* Test: Change title level.

* Links now point to the docs instead of the source code.
Fixed some broken formatting of the titles.
Use different seeds for Prior/Inference/Prediction.
Changed models for inferring the truncation.
Fixed minor typos.

* Install numpyro and upgrade jax, jaxlib and matplotlib
Copy jax arrays before passing to matplotlib functions

* Clarified statement about the log_prob method in the
TruncatedDistribution class.

* Changed intro sentence to include folded distributions.

* Remove command for installing jax.
Use np.unique instead of jnp.unique

* Cast rate parameter to float (#1301)

* Make potential_fn_gen and postprocess_fn_gen picklable (#1302)

* add wrapper

* Make potential_fn_gen postprocess_fn_gen pickable

* Stein based inference (#833)


* Improve subsample warning keys (#1303)

* Add ProvenanceArray to infer relational structure in a model (#1248)

* Add provenance array

* Add tests for provenance

* run make format

* Workaround not be able to eval_shape a distribution

* Make license

* add a clearer guide for render a model with scan

* fix failing bugs in recent jax release

* Fix further failing tests

* Make sure to be able to render ImproperUniform and random initialized params

* port get_dependencies to numpyro

* tighten test_improper_normal bound (#1307)

* Fix HMCECS multiple plates (#1305)

* Add Kumaraswamy and relaxed Bernoulli distributions (#1283)

* Add kumaraswamy and relaxed bernoulli distributions

* clean up the flag

* Require logits to be keyword argument

* make relaxed bernoulli have the same signature as Pyro

* fix docs build

* Fix rsample bug

* add more simple test for Kumaraswamy

* Add various KL divergences for Gamma/Beta families (#1284)

* Add new distributions and kl

* Add kumaraswamy and relaxed bernoulli distributions

* clean up the flag

* Require logits to be keyword argument

* make relaxed bernoulli have the same signature as Pyro

* fix docs build

* Fix rsample bug

* move the flag to Kumaraswamy class for convenient

* Add loose strategy for missing plates in MCMC (#1304)

* Add loose strategy for MCMC

* merge svi and mcmc plate warning strategies

* fix failing tests

* validate model accross ELBOs

* update vae example

* fix typos

* Fix failing tests

* skip prodlda test on CI

* Bump to 0.9.0 (#1310)

* Add loose strategy for MCMC

* merge svi and mcmc plate warning strategies

* fix failing tests

* validate model accross ELBOs

* update vae example

* fix typos

* Bump to version 0.9.0

* Fix failing tests

* Fix warnings in tests/examples

* relax funsor requirement

* Move optax_to_numpyro to optim

* skip prodlda test on CI

* added dimensions to plate and sqrt precision.

* fixed/added comments in stein_bnn.py and removed lr datasets.

* added comment to stein_bnn.py

* formatted files to black==22.1.0

Co-authored-by: Wataru Hashimoto <[email protected]>
Co-authored-by: Omar Sosa Rodríguez <[email protected]>
Co-authored-by: Vedran Hadziosmanovic <[email protected]>
Co-authored-by: Du Phan <[email protected]>
Co-authored-by: Ahmad Salim Al-Sibahi <[email protected]>
Co-authored-by: einsteinvi <[email protected]>
Co-authored-by: austereantelope <[email protected]>