Skip to content

Commit

Permalink
vmap-able Distributions (#1529)
Browse files Browse the repository at this point in the history
* Make `Normal` objects compatible with `jax.vmap`

* Sort imports

* Make `MultivariateNormal` compatible with `vmap`

* remove unused variables

* Pass `in_vmap` to base class constructor of mvnormal dists

* fix issue with docs formatting

* test vmapping multivariate normal twice

* temporarily disable failing test

* lint

* don't use `__init__` to restore state post-vmapping

* adapt numpyro control flows to new batch_shape handling

* fixup! adapt numpyro control flows to new batch_shape handling

* Do not mutate shapes of ExpandedDistribution for map-free ops

* improve post-scan batch_shape updating

* re-enable disable test

* [WIP] vmap tests for arbitrary distributions

* WIP

* vmappable continuous distributions

* Fix incorrect unflattenning of inverse transforms

* mark CAR's adj_matrix as auxiliary if sparse

* vmap support for discrete distributions

* add missing license header

* batch shape promotion for Bernoulli/Categorical Probs

* More distribution-specific logic batch-shape promotion

* linting

* fixup! linting

* fixup! linting

* implement `vmap` support remaining distribution

* reference sparse warning using correct namespace

* uniformize tree_flatten/unflatten method across dists

* fixup! uniformize tree_flatten/unflatten method across dists

* remove normal-specific vmap tests

* decentralize batch_shape promotion

* fixup! uniformize tree_flatten/unflatten method across dists

* fixup! uniformize tree_flatten/unflatten method across dists

* fixup! uniformize tree_flatten/unflatten method across dists

* move `vmap_over` in a dedicated util module

* minor cosmetic changes in `vmap_util`

* [WIP] clarify test_vmap_dist

* Finish clarifying test_vmap_dist

* remove spurious deleted/added newlines

* minor improvements in tests

* use arg_constraints for batch shape promotion

* move batch shape promotion to vmap_util.py

* vmap_util.py -> batch_util.py

* have pytree_data_fields default to arg_constraint.keys()

* revert some unrelated modifications
  • Loading branch information
pierreglaser authored Jul 30, 2023
1 parent 4799b84 commit 902623c
Show file tree
Hide file tree
Showing 11 changed files with 1,026 additions and 389 deletions.
3 changes: 3 additions & 0 deletions numpyro/contrib/control_flow/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from jax.tree_util import tree_flatten, tree_map, tree_unflatten

from numpyro import handlers
from numpyro.distributions.batch_util import promote_batch_shape
from numpyro.ops.pytree import PytreeTrace
from numpyro.primitives import _PYRO_STACK, Messenger, apply_stack
from numpyro.util import not_jax_tracer
Expand Down Expand Up @@ -220,6 +221,7 @@ def body_fn(wrapped_carry, x, prefix=None):
if first_var is None:
first_var = name

site["fn"] = promote_batch_shape(site["fn"])
# we haven't promote shapes of values yet during `lax.scan`, so we do it here
site["value"] = _promote_scanned_value_shapes(site["value"], site["fn"])

Expand Down Expand Up @@ -308,6 +310,7 @@ def body_fn(wrapped_carry, x):
for name, site in pytree_trace.trace.items():
if site["type"] != "sample":
continue
site["fn"] = promote_batch_shape(site["fn"])
# we haven't promote shapes of values yet during `lax.scan`, so we do it here
site["value"] = _promote_scanned_value_shapes(site["value"], site["fn"])
return last_carry, (pytree_trace, ys)
Expand Down
Loading

0 comments on commit 902623c

Please sign in to comment.