-
Notifications
You must be signed in to change notification settings - Fork 246
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add loose strategy for missing plates in MCMC #1304
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lgtm. this has no effect on the current svi behavior, right?
Yes, but we can merge two implementations. The svi version is better because it also handles singleton dimensions. Let's me merge them. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lgtm. @fehiepsi is this still wip?
Thanks for reviewing, @martinjankowiak! I just fixed some tests due to new jaxlib/scikit-learn release:
Actually, it seems that prodlda is failing due to high memory requirement in CI (not so sure). It passed locally so I just added a mark to skip it in CI. |
@@ -569,6 +574,8 @@ def single_particle_elbo(rng_key): | |||
model_trace, guide_trace = get_importance_trace( | |||
seeded_model, seeded_guide, args, kwargs, param_map | |||
) | |||
check_model_guide_match(model_trace, guide_trace) | |||
_validate_model(model_trace, plate_warning="strict") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why strict here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm seeing TraceGraphELBO leverages plate stacks etc to do inference. I'm not sure if strict is required. If you think using loose is enough, I can make the change.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let me merge this first, then would make the change if needed in #1310
* Add loose strategy for MCMC * merge svi and mcmc plate warning strategies * fix failing tests * validate model accross ELBOs * update vae example * fix typos * Fix failing tests * skip prodlda test on CI
* added stein example * added test case * added stein bnn to docs. * moveed stein_bnn to other inf algs in docs * Added correct plating for model in `stein_bnn.py`. Works with latest #833. * Add some doctests to transforms (#1300) * add some doctest to transforms * make format * Tutorial for truncated distributions (#1272) * WIP Do not merge. Tutorial for truncated distributions * WIP: Completed a few todos and fixed a few typos * WIP: Completed main sections. References and part 5 still pending * Added section on built in distributions and folded distributions * Draft ready * Remove M1-related warning from cell output * Truncated distributions tutorial added to index * Wrap latex equations in double dollar sign * Fix broken markdown equations * Added more details on folded distribs. Re-arranged sections. * Test: Change title level. * Links now point to the docs instead of the source code. Fixed some broken formatting of the titles. Use different seeds for Prior/Inference/Prediction. Changed models for inferring the truncation. Fixed minor typos. * Install numpyro and upgrade jax, jaxlib and matplotlib Copy jax arrays before passing to matplotlib functions * Clarified statement about the log_prob method in the TruncatedDistribution class. * Changed intro sentence to include folded distributions. * Remove command for installing jax. Use np.unique instead of jnp.unique * Cast rate parameter to float (#1301) * Make potential_fn_gen and postprocess_fn_gen picklable (#1302) * add wrapper * Make potential_fn_gen postprocess_fn_gen pickable * Stein based inference (#833) * Added stein interface. * Fixed style and removed from VI baseclass. * Added reinit_guide.py * Added license. * added examples * Added examples. * Fixed some linting and LDA example; need to refactor wrapped_guide. * Added param site also get rng_keys; this should be reworked! * Removed datasets and fixed lda to running. * Fixed dimensionality bug for simplex support. * Added code from refactor/einstein * Fixed notebooks; todo: comment notebook. * Factored initialization of `kernels.RandomFeatureKernel` into `Stein.init` and updated `test_kernels.test_kernel_forward` accordingly. * Started testing. * Removed assert from test_init_strategy. * Skeleton test_stein.py * Updated `test_stein/test_init` * Added test_params and likelihood computation to lda. * Fix init in MixtureKernel * Notebook fixes * debugging log likelihood * WIP, move benchmarks to datasets * trace guide to compute likelihood in lda. * Debugging LDA * Removed test_vi.py (will use test_stein.py), added `test_stein.test_update_evaluate` * Cleaned test covered by `test_get_params`. * Added skeleton and finished _param_size test. * Fix LR example * IRIS LR * Fix Toy examples * Added pinfo test. * moved stein/test_kernels.py into stein/test_stein.py; updated `test_stein.test_apply_kernel` * Ran black and removed lambdas from KERNEL_TEST_CASE. * Added `test_stein.test_sp_mcmc` and removed calls to jnp.random.shuffle (deprecated). * Added skelelton test for test_score_sp_mcmc. * Fixed overwriting kval in `test_stein.test_apply_kernel` * Fixed lint * Fixed lint. * Added stein_loss test. * Factored vi source and test_vi out of einstein. * updated with black. * Figured out likelihood for LDA (need to change to compute likelihood instead of ELBO) * Added perplexity to LDA. * Fixed log position for perplexity. * Refactored callbacks and added `test_checkpoint`. * Fixed imports * Reverted LDA to working version. * Added callback tests * Return loss history for `stein.run` * Added visual to LDA. * Fixed return for `run_lda` * Added missing topic num 20. * Added todo * Cleaned 1d_mm stein notebook. * Updated 2d gaussian notebook. * Add description to SVGD. * Updated `RBF_kernel` to work with one particle and added kernels notebook. * Fixed bug in bandwidth of RBF_kernel * SVI reproducing result from SVGD paper. * Better learning rate for SVI. * larger network * Updated predictive to allow for particle methods. * Removed TODO and fixed learning rate. * EinStein out performance SVGD * Latest working. * Fixed VI without progressbar. * Fixed mini batching for VI. * Added kernel visualization. * Init to sample for bayesian networks. * TODO predict shape. * Added scaling to plate primitive. * Fixed enumeration in Stein and added subsample_scale to funsor.plate. * Debugging LDA * Debugging lda * Debugging merge. * Updated jacobian computation in Stein. * Fixed issue with nested parameters for stein grad. * Fixed issue with nested parameters for stein grad. * Added NLL to DMM and predictions. * renaming and removing benchmark code * Cleaning branch from benchmarking. * Removed prediction from DMM. * Changed to syntax from older python * Fixed lint. * Fixed reinit warning for `init_to_uniform`. * updated to use black[jupyter] * Added licenses. * Added smoke a smoke test for SteinVI * Factored out Stein point MCMC * Factored out VI from EinStein. * Updated stein_kernels.ipynb and removed debugging pred_prey. * Removed `mixture.py` use `mixtures.py` instead. * Fixed lint. * Added examples to docs build. * Fixed stale import in `hmc.py` docstring. * Removed stein point test cases. * Changed Predictive to only check for guided models with particles. * Fixed lint. * Changed `reinit_guide` to add rng_keys for reinitialization. * Added boston pricing dataset. Commented stein bnn example. Added `stein_bnn.py` to `test_examples.py`. * Removed empty line from `test_examples.py`. * Fixed lint. * Changed `stein_mixture_dmm.py` to use new signature and run method. * Added some comments and fixed `stein_mixture_dmm.py` to use new signature. * Fixed `event_shape` and `support` for `Sine`, `DoubleBanana`, and `Star` distributions in `stein_2d_dists.py`. * Removed notebooks from initial PR and updated stein_2d_toy.py to new run signature. * Parameterize `gru_dim` in `stein_mixture_dmm.py`. * Fixed steinvi to use #1263; TODO: update examples. * Removed init_with_noise. * removed stein_bnn and changed `examples/datasets.py` to upstream * removed examples * Removed stein examples from docs. * renamed einstein.utils to einstein.util. * updated testing * Changed test to use auto_guide `init_loc_fn`. * removed `numpyro/util/ravel_pytree` * removed unused imports in numpyro/util * Added initialization to kernels in `test_einstein_kernels.py` * Changed kernel test to use np.arrays at global level. * change jnp arrays to np np array in tests. reverted subsample scale. * added docstring to `einstein/util/batch_ravel_pytree` Co-authored-by: Ahmad Salim Al-Sibahi <[email protected]> Co-authored-by: einsteinvi <[email protected]> * Improve subsample warning keys (#1303) * Add ProvenanceArray to infer relational structure in a model (#1248) * Add provenance array * Add tests for provenance * run make format * Workaround not be able to eval_shape a distribution * Make license * add a clearer guide for render a model with scan * fix failing bugs in recent jax release * Fix further failing tests * Make sure to be able to render ImproperUniform and random initialized params * port get_dependencies to numpyro * tighten test_improper_normal bound (#1307) * Fix HMCECS multiple plates (#1305) * Add Kumaraswamy and relaxed Bernoulli distributions (#1283) * Add kumaraswamy and relaxed bernoulli distributions * clean up the flag * Require logits to be keyword argument * make relaxed bernoulli have the same signature as Pyro * fix docs build * Fix rsample bug * add more simple test for Kumaraswamy * Add various KL divergences for Gamma/Beta families (#1284) * Add new distributions and kl * Add kumaraswamy and relaxed bernoulli distributions * clean up the flag * Require logits to be keyword argument * make relaxed bernoulli have the same signature as Pyro * fix docs build * Fix rsample bug * move the flag to Kumaraswamy class for convenient * Add loose strategy for missing plates in MCMC (#1304) * Add loose strategy for MCMC * merge svi and mcmc plate warning strategies * fix failing tests * validate model accross ELBOs * update vae example * fix typos * Fix failing tests * skip prodlda test on CI * Bump to 0.9.0 (#1310) * Add loose strategy for MCMC * merge svi and mcmc plate warning strategies * fix failing tests * validate model accross ELBOs * update vae example * fix typos * Bump to version 0.9.0 * Fix failing tests * Fix warnings in tests/examples * relax funsor requirement * Move optax_to_numpyro to optim * skip prodlda test on CI * added dimensions to plate and sqrt precision. * fixed/added comments in stein_bnn.py and removed lr datasets. * added comment to stein_bnn.py * formatted files to black==22.1.0 Co-authored-by: Wataru Hashimoto <[email protected]> Co-authored-by: Omar Sosa Rodríguez <[email protected]> Co-authored-by: Vedran Hadziosmanovic <[email protected]> Co-authored-by: Du Phan <[email protected]> Co-authored-by: Ahmad Salim Al-Sibahi <[email protected]> Co-authored-by: einsteinvi <[email protected]> Co-authored-by: austereantelope <[email protected]>
Fixes #1276