Initializing batch of parameters for SteinVI #1263
Conversation
Hi @OlaRonning, could you add some tests and let's see if we can come up with a solution that does not change how init strategy works? Currently, those strategies are intended to be applied for latent sites, not param sites. I think we can either:
I think this is a better solution for some reasons:
Hi @fehiepsi, I'll add some details to the PR and update with some tests tonight. It should be possible to keep the existing API, but I need to sketch out your ideas to get a feeling for which is superior.
…in `reinit_guide.py`
    for site in guide_trace.values():
        if site["type"] == "param" and not self._reinit_hide_fn(site):
            site_key, rng_key = jax.random.split(site_key)
            site["kwargs"]["rng_key"] = rng_key
Another way to avoid touching init strategies is to change the `type`, `fn`, and `is_observed` keys for this site. (You can use `ImproperUniform` for `fn`.)
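A rough sketch of that suggestion, with the site-dict keys taken from the comment above and `flat_fn` standing in for an `ImproperUniform` instance over the parameter's support:

```python
def lift_param_site(site, flat_fn):
    # Rewrite a traced param site as a latent sample site so that init
    # strategies (which target sample sites, not param sites) apply to it.
    lifted = dict(site)
    lifted["type"] = "sample"      # init strategies look for sample sites
    lifted["fn"] = flat_fn         # flat (improper) density over the support
    lifted["is_observed"] = False  # treat it as a latent variable
    return lifted

site = {"type": "param", "name": "w", "value": 0.0}
lifted = lift_param_site(site, flat_fn=object())
```

The original trace entry is left untouched, so the lifting can be applied selectively per site.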
SteinVI allows parameters without reinitialization to support neural network libraries like Stax.
Though the solution in my comment above leaves the current API of SteinVI untouched, I still think that letting users define their own init strategy for each parameter is better, so we don't need to introduce notions like `reinit_hide_fn`. But it depends on you; I think addressing that issue is a minor priority.
I don't particularly like `reinit_hide_fn`, as it requires the user to list parameters twice. A solution that eliminates it is superior, in my opinion. I have been thinking about your suggestion.
Hi @fehiepsi, I rewrote the test with batching using your suggestions. The init strategy is now unaffected. ELBO-within-Stein does not spread the variational distributions based on overlap (I think I have a fix for this coming), and SVGD is known to recover well from poor initializations, so there is no need (yet) to take distribution shape into account when placing the particles.
Found a bug:

    def _reinit(seed):
        with handlers.seed(rng_seed=seed):
            return site_fn(*site_args)

I understand the above to produce constrained samples when `site_fn` is not the identity. Is this consistent with current usage? For example,

    def model():
        numpyro.param("b", lambda rng_key: Normal(0, 1.).sample(rng_key), constraint=positive)

is invalid.
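A quick numerical check of why that example is invalid: a standard Normal initializer draws on the whole real line, so under `constraint=positive` roughly half of the proposed initial values fall outside the support. This is only an illustrative sketch, not NumPyro's actual validation logic:

```python
import numpy as np

# Draw initial values the way the lambda above would, then count how many
# violate the positivity constraint.
rng = np.random.default_rng(0)
draws = rng.normal(loc=0.0, scale=1.0, size=10_000)
frac_invalid = float(np.mean(draws <= 0.0))
```

With 10,000 draws, `frac_invalid` lands near 0.5, i.e. about half of the initializations are infeasible.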
LGTM, thanks @OlaRonning!
> Is this consistent with current usage?

Yup, `lambda key: Normal...` might produce invalid samples for a positive constraint; only `loc` params (for AutoGuides) and user-annotated random params are stochastic. I like this behavior! `init_loc_fn` will initialize values for the latent variables; for parameters, it depends on users.
Thanks for the review and suggestions @fehiepsi! C:
* Added stein interface. * Fixed style and removed from VI baseclass. * Added reinit_guide.py * Added license. * added examples * Added examples. * Fixed some linting and LDA example; need to refactor wrapped_guide. * Added param site also get rng_keys; this should be reworked! * Removed datasets and fixed lda to running. * Fixed dimensionality bug for simplex support. * Added code from refactor/einstein * Fixed notebooks; todo: comment notebook. * Factored initialization of `kernels.RandomFeatureKernel` into `Stein.init` and updated `test_kernels.test_kernel_forward` accordingly. * Started testing. * Removed assert from test_init_strategy. * Skeleton test_stein.py * Updated `test_stein/test_init` * Added test_params and likelihood computation to lda. * Fix init in MixtureKernel * Notebook fixes * debugging log likelihood * WIP, move benchmarks to datasets * trace guide to compute likelihood in lda. * Debugging LDA * Removed test_vi.py (will use test_stein.py), added `test_stein.test_update_evaluate` * Cleaned test covered by `test_get_params`. * Added skeleton and finished _param_size test. * Fix LR example * IRIS LR * Fix Toy examples * Added pinfo test. * moved stein/test_kernels.py into stein/test_stein.py; updated `test_stein.test_apply_kernel` * Ran black and removed lambdas from KERNEL_TEST_CASE. * Added `test_stein.test_sp_mcmc` and removed calls to jnp.random.shuffle (deprecated). * Added skelelton test for test_score_sp_mcmc. * Fixed overwriting kval in `test_stein.test_apply_kernel` * Fixed lint * Fixed lint. * Added stein_loss test. * Factored vi source and test_vi out of einstein. * updated with black. * Figured out likelihood for LDA (need to change to compute likelihood instead of ELBO) * Added perplexity to LDA. * Fixed log position for perplexity. * Refactored callbacks and added `test_checkpoint`. * Fixed imports * Reverted LDA to working version. * Added callback tests * Return loss history for `stein.run` * Added visual to LDA. 
* Fixed return for `run_lda` * Added missing topic num 20. * Added todo * Cleaned 1d_mm stein notebook. * Updated 2d gaussian notebook. * Add description to SVGD. * Updated `RBF_kernel` to work with one particle and added kernels notebook. * Fixed bug in bandwidth of RBF_kernel * SVI reproducing result from SVGD paper. * Better learning rate for SVI. * larger network * Updated predictive to allow for particle methods. * Removed TODO and fixed learning rate. * EinStein out performance SVGD * Latest working. * Fixed VI without progressbar. * Fixed mini batching for VI. * Added kernel visualization. * Init to sample for bayesian networks. * TODO predict shape. * Added scaling to plate primitive. * Fixed enumeration in Stein and added subsample_scale to funsor.plate. * Debugging LDA * Debugging lda * Debugging merge. * Updated jacobian computation in Stein. * Fixed issue with nested parameters for stein grad. * Fixed issue with nested parameters for stein grad. * Added NLL to DMM and predictions. * renaming and removing benchmark code * Cleaning branch from benchmarking. * Removed prediction from DMM. * Changed to syntax from older python * Fixed lint. * Fixed reinit warning for `init_to_uniform`. * updated to use black[jupyter] * Added licenses. * Added smoke a smoke test for SteinVI * Factored out Stein point MCMC * Factored out VI from EinStein. * Updated stein_kernels.ipynb and removed debugging pred_prey. * Removed `mixture.py` use `mixtures.py` instead. * Fixed lint. * Added examples to docs build. * Fixed stale import in `hmc.py` docstring. * Removed stein point test cases. * Changed Predictive to only check for guided models with particles. * Fixed lint. * Changed `reinit_guide` to add rng_keys for reinitialization. * Added boston pricing dataset. Commented stein bnn example. Added `stein_bnn.py` to `test_examples.py`. * Removed empty line from `test_examples.py`. * Fixed lint. * Changed `stein_mixture_dmm.py` to use new signature and run method. 
* Added some comments and fixed `stein_mixture_dmm.py` to use new signature. * Fixed `event_shape` and `support` for `Sine`, `DoubleBanana`, and `Star` distributions in `stein_2d_dists.py`. * Removed notebooks from initial PR and updated stein_2d_toy.py to new run signature. * Parameterize `gru_dim` in `stein_mixture_dmm.py`. * Fixed steinvi to use #1263; TODO: update examples. * Removed init_with_noise. * removed stein_bnn and changed `examples/datasets.py` to upstream * removed examples * Removed stein examples from docs. * renamed einstein.utils to einstein.util. * updated testing * Changed test to use auto_guide `init_loc_fn`. * removed `numpyro/util/ravel_pytree` * removed unused imports in numpyro/util * Added initialization to kernels in `test_einstein_kernels.py` * Changed kernel test to use np.arrays at global level. * change jnp arrays to np np array in tests. reverted subsample scale. * added docstring to `einstein/util/batch_ravel_pytree` Co-authored-by: Ahmad Salim Al-Sibahi <[email protected]> Co-authored-by: einsteinvi <[email protected]>
To support re-initializable guides, we suggest the `WrappedGuide` interface, which requires implementing a function `find_params` that accepts a list of RNG keys in addition to the arguments for the guide and returns a set of freshly initialized parameters for each RNG key. The `WrappedGuide` class takes a guide written as a function and makes it re-initializable. It works by running the provided guide multiple times and reinitializing the parameters using NumPyro's init strategy interface (e.g. `init_to_uniform`, `init_to_median`, or `init_to_noise`). To avoid reinitializing specific parameters, `WrappedGuide` implements `reinit_hide_fn` to filter the parameters. SteinVI allows parameters without reinitialization to support neural network libraries like Stax, which provide bespoke initializers. An alternative solution is to lift parameters to sample sites, as suggested in #655 (and again by @fehiepsi below). Related to #833.
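The `find_params` contract can be sketched as follows. `init_param_fn` is a hypothetical stand-in for tracing the guide once with a single RNG key and collecting its freshly initialized param sites; it is not part of the proposed interface:

```python
def find_params(init_param_fn, rng_keys, *guide_args, **guide_kwargs):
    # One freshly initialized parameter set per RNG key: run the guide's
    # initializer once per key, forwarding the guide's own arguments.
    return [init_param_fn(key, *guide_args, **guide_kwargs) for key in rng_keys]

# Toy initializer: the "parameters" depend deterministically on the key,
# standing in for an init strategy like init_to_uniform.
init_param_fn = lambda key, scale: {"w": key * scale}
particles = find_params(init_param_fn, [1, 2, 3], 0.5)
```

Each entry of `particles` then serves as the starting point for one SteinVI particle.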