Skip to content

Commit

Permalink
add AutoSemiDAIS (#1426)
Browse files Browse the repository at this point in the history
* autosemidais skeleton

* sketch implementation of semidais

* cleanup; more docstring

* add broken test

* Enhance the prototype

* Undo keep_plate in log_density

* fatten test

* move local computations entirely into plate; add broken smoke test

* use create plates for autonormal

* add test_autosemidais_inadmissible_smoke

* expand test_autosemidais_admissible_smoke

* add missing self.prefix

* Fix unravel dimensions

* simplify test; use optax

* tweaktest

* comments

* Not rescaling grad

* Fix lint

* add pyyaml dependency for flax

* looks correct but dais_elbo8 != dais_elbo16 and mf_elbo is better than dais_elbo

* Fix bug of unconstrain reparam

* cleanup test

* improve docstring

* remove comments; revert block change; improve docstring

* support sample_shape in sample_posterior

* add to docs

* fix rst

* attempt to fix docs

* Fix the remaining TODO

Co-authored-by: Du Phan <[email protected]>
Co-authored-by: Du Phan <[email protected]>
  • Loading branch information
3 people authored Jun 17, 2022
1 parent 392c1fd commit acb2cd8
Show file tree
Hide file tree
Showing 7 changed files with 608 additions and 5 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,7 @@ Like HMC/NUTS, all remaining MCMC algorithms support enumeration over discrete l
- [AutoDelta](https://num.pyro.ai/en/latest/autoguide.html#autodelta) is used for computing point estimates via MAP (maximum a posteriori estimation). See [here](https://github.com/pyro-ppl/numpyro/blob/bbe1f879eede79eebfdd16dfc49c77c4d1fc727c/examples/zero_inflated_poisson.py#L101) for example usage.
- [AutoBNAFNormal](https://num.pyro.ai/en/latest/autoguide.html#numpyro.infer.autoguide.AutoBNAFNormal) and [AutoIAFNormal](https://num.pyro.ai/en/latest/autoguide.html#autoiafnormal) offer flexible variational distributions parameterized by normalizing flows.
- [AutoDAIS](https://num.pyro.ai/en/latest/autoguide.html#autodais) is a powerful variational inference algorithm that leverages HMC. It can be a good choice for dealing with highly correlated posteriors but may be computationally expensive depending on the nature of the model.
- [AutoSemiDAIS](https://num.pyro.ai/en/latest/autoguide.html#autosemidais) constructs a posterior approximation like [AutoDAIS](https://num.pyro.ai/en/latest/autoguide.html#autodais) for local latent variables but provides support for data subsampling during ELBO training by utilizing a parametric guide for global latent variables.
- [AutoLaplaceApproximation](https://num.pyro.ai/en/latest/autoguide.html#numpyro.infer.autoguide.AutoLaplaceApproximation) can be used to compute a Laplace approximation.

### Stein Variational Inference
Expand Down
9 changes: 9 additions & 0 deletions docs/source/autoguide.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ We provide a brief overview of the automatically generated guides available in N
* `AutoDelta <https://num.pyro.ai/en/latest/autoguide.html#autodelta>`_ is used for computing point estimates via MAP (maximum a posteriori estimation). See `here <https://github.com/pyro-ppl/numpyro/blob/bbe1f879eede79eebfdd16dfc49c77c4d1fc727c/examples/zero_inflated_poisson.py#L101>`_ for example usage.
* `AutoBNAFNormal <https://num.pyro.ai/en/latest/autoguide.html#numpyro.infer.autoguide.AutoBNAFNormal>`_ and `AutoIAFNormal <https://num.pyro.ai/en/latest/autoguide.html#autoiafnormal>`_ offer flexible variational distributions parameterized by normalizing flows.
* `AutoDAIS <https://num.pyro.ai/en/latest/autoguide.html#autodais>`_ is a powerful variational inference algorithm that leverages HMC. It can be a good choice for dealing with highly correlated posteriors but may be computationally expensive depending on the nature of the model.
* `AutoSemiDAIS <https://num.pyro.ai/en/latest/autoguide.html#autosemidais>`_ constructs a posterior approximation like `AutoDAIS <https://num.pyro.ai/en/latest/autoguide.html#autodais>`_ for local latent variables but provides support for data subsampling during ELBO training by utilizing a parametric guide for global latent variables.
* `AutoLaplaceApproximation <https://num.pyro.ai/en/latest/autoguide.html#numpyro.infer.autoguide.AutoLaplaceApproximation>`_ can be used to compute a Laplace approximation.

.. automodule:: numpyro.infer.autoguide
Expand Down Expand Up @@ -99,3 +100,11 @@ AutoDAIS
:undoc-members:
:show-inheritance:
:member-order: bysource

AutoSemiDAIS
------------
.. autoclass:: numpyro.infer.autoguide.AutoSemiDAIS
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource
3 changes: 2 additions & 1 deletion numpyro/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -579,7 +579,8 @@ def sample_with_intermediates(self, key, sample_shape=()):
def sample(self, key, sample_shape=()):
return self.sample_with_intermediates(key, sample_shape)[0]

def log_prob(self, value):
def log_prob(self, value, intermediates=None):
# TODO: utilize `intermediates`
shape = lax.broadcast_shapes(
self.batch_shape,
jnp.shape(value)[: max(jnp.ndim(value) - self.event_dim, 0)],
Expand Down
Loading

0 comments on commit acb2cd8

Please sign in to comment.