Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Make `Normal` objects compatible with `jax.vmap` * Sort imports * Make `MultivariateNormal` compatible with `vmap` * remove unused variables * Pass `in_vmap` to base class constructor of mvnormal dists * fix issue with docs formatting * test vmapping multivariate normal twice * temporarily disable failing test * lint * don't use `__init__` to restore state post-vmapping * adapt numpyro control flows to new batch_shape handling * fixup! adapt numpyro control flows to new batch_shape handling * Do not mutate shapes of ExpandedDistribution for map-free ops * improve post-scan batch_shape updating * re-enable disable test * [WIP] vmap tests for arbitrary distributions * WIP * vmappable continuous distributions * Fix incorrect unflattenning of inverse transforms * mark CAR's adj_matrix as auxiliary if sparse * vmap support for discrete distributions * add missing license header * batch shape promotion for Bernoulli/Categorical Probs * More distribution-specific logic batch-shape promotion * linting * fixup! linting * fixup! linting * implement `vmap` support remaining distribution * reference sparse warning using correct namespace * uniformize tree_flatten/unflatten method across dists * fixup! uniformize tree_flatten/unflatten method across dists * remove normal-specific vmap tests * decentralize batch_shape promotion * fixup! uniformize tree_flatten/unflatten method across dists * fixup! uniformize tree_flatten/unflatten method across dists * fixup! uniformize tree_flatten/unflatten method across dists * move `vmap_over` in a dedicated util module * minor cosmetic changes in `vmap_util` * [WIP] clarify test_vmap_dist * Finish clarifying test_vmap_dist * remove spurious deleted/added newlines * minor improvements in tests * use arg_constraints for batch shape promotion * move batch shape promotion to vmap_util.py * vmap_util.py -> batch_util.py * have pytree_data_fields default to arg_constraint.keys() * revert some unrelated modifications
- Loading branch information