-
Notifications
You must be signed in to change notification settings - Fork 246
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add vectorized_particles to ELBO #1624
Conversation
numpyro/infer/elbo.py
Outdated
@@ -33,6 +34,8 @@ class ELBO: | |||
|
|||
:param num_particles: The number of particles/samples used to form the ELBO | |||
(gradient) estimators. | |||
:param vectorize_particles: Whether to use `jax.vmap` to obtain elbos over the |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is this only for eval or also for training?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this is for both, but typically used for eval, when we require a large number of particles.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit:
:param vectorize_particles: Whether to use
jax.vmap
to compute ELBOs over thenum_particles
-many particles in parallel. If False usejax.lax.map
. Defaults to True.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lgtm
numpyro/infer/elbo.py
Outdated
@@ -33,6 +34,8 @@ class ELBO: | |||
|
|||
:param num_particles: The number of particles/samples used to form the ELBO | |||
(gradient) estimators. | |||
:param vectorize_particles: Whether to use `jax.vmap` to obtain elbos over the |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit:
:param vectorize_particles: Whether to use
jax.vmap
to compute ELBOs over thenum_particles
-many particles in parallel. If False usejax.lax.map
. Defaults to True.
numpyro/infer/elbo.py
Outdated
@@ -108,11 +112,10 @@ class Trace_ELBO(ELBO): | |||
|
|||
:param num_particles: The number of particles/samples used to form the ELBO | |||
(gradient) estimators. | |||
:param vectorize_particles: Whether to use `jax.vmap` to obtain elbos over the |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ditto
numpyro/infer/elbo.py
Outdated
@@ -311,6 +320,8 @@ class RenyiELBO(ELBO): | |||
Here :math:`\alpha \neq 1`. Default is 0. | |||
:param num_particles: The number of particles/samples | |||
used to form the objective (gradient) estimator. Default is 2. | |||
:param vectorize_particles: Whether to use `jax.vmap` to obtain elbos over the | |||
particles. If False, we will use `jax.lax.map`. Defaults to True. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ditto
Allow to use lax.map instead of vmap in ELBO to reduce memory requirement