-
Notifications
You must be signed in to change notification settings - Fork 246
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add vectorized_particles to ELBO #1624
Merged
Merged
Changes from 1 commit
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,6 +6,7 @@ | |
from operator import itemgetter | ||
import warnings | ||
|
||
import jax | ||
from jax import eval_shape, random, vmap | ||
from jax.lax import stop_gradient | ||
import jax.numpy as jnp | ||
|
@@ -33,6 +34,8 @@ class ELBO: | |
|
||
:param num_particles: The number of particles/samples used to form the ELBO | ||
(gradient) estimators. | ||
:param vectorize_particles: Whether to use `jax.vmap` to obtain elbos over the | ||
particles. If False, we will use `jax.lax.map`. Defaults to True. | ||
""" | ||
|
||
""" | ||
|
@@ -42,8 +45,9 @@ class ELBO: | |
""" | ||
can_infer_discrete = False | ||
|
||
def __init__(self, num_particles=1): | ||
def __init__(self, num_particles=1, vectorize_particles=True): | ||
self.num_particles = num_particles | ||
self.vectorize_particles = vectorize_particles | ||
|
||
def loss(self, rng_key, param_map, model, guide, *args, **kwargs): | ||
""" | ||
|
@@ -108,11 +112,10 @@ class Trace_ELBO(ELBO): | |
|
||
:param num_particles: The number of particles/samples used to form the ELBO | ||
(gradient) estimators. | ||
:param vectorize_particles: Whether to use `jax.vmap` to obtain elbos over the | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ditto |
||
particles. If False, we will use `jax.lax.map`. Defaults to True. | ||
""" | ||
|
||
def __init__(self, num_particles=1): | ||
self.num_particles = num_particles | ||
|
||
def loss_with_mutable_state( | ||
self, rng_key, param_map, model, guide, *args, **kwargs | ||
): | ||
|
@@ -163,7 +166,10 @@ def single_particle_elbo(rng_key): | |
return {"loss": -elbo, "mutable_state": mutable_state} | ||
else: | ||
rng_keys = random.split(rng_key, self.num_particles) | ||
elbos, mutable_state = vmap(single_particle_elbo)(rng_keys) | ||
if self.vectorize_particles: | ||
elbos, mutable_state = vmap(single_particle_elbo)(rng_keys) | ||
else: | ||
elbos, mutable_state = jax.lax.map(single_particle_elbo, rng_keys) | ||
return {"loss": -jnp.mean(elbos), "mutable_state": mutable_state} | ||
|
||
|
||
|
@@ -291,7 +297,10 @@ def single_particle_elbo(rng_key): | |
return {"loss": -elbo, "mutable_state": mutable_state} | ||
else: | ||
rng_keys = random.split(rng_key, self.num_particles) | ||
elbos, mutable_state = vmap(single_particle_elbo)(rng_keys) | ||
if self.vectorize_particles: | ||
elbos, mutable_state = vmap(single_particle_elbo)(rng_keys) | ||
else: | ||
elbos, mutable_state = jax.lax.map(single_particle_elbo, rng_keys) | ||
return {"loss": -jnp.mean(elbos), "mutable_state": mutable_state} | ||
|
||
|
||
|
@@ -311,6 +320,8 @@ class RenyiELBO(ELBO): | |
Here :math:`\alpha \neq 1`. Default is 0. | ||
:param num_particles: The number of particles/samples | ||
used to form the objective (gradient) estimator. Default is 2. | ||
:param vectorize_particles: Whether to use `jax.vmap` to obtain elbos over the | ||
particles. If False, we will use `jax.lax.map`. Defaults to True. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ditto |
||
|
||
Example:: | ||
|
||
|
@@ -427,7 +438,10 @@ def loss(self, rng_key, param_map, model, guide, *args, **kwargs): | |
) | ||
|
||
rng_keys = random.split(rng_key, self.num_particles) | ||
elbos, common_plate_scale = vmap(single_particle_elbo)(rng_keys) | ||
if self.vectorize_particles: | ||
elbos, common_plate_scale = vmap(single_particle_elbo)(rng_keys) | ||
else: | ||
elbos, common_plate_scale = jax.lax.map(single_particle_elbo, rng_keys) | ||
assert common_plate_scale.shape == (self.num_particles,) | ||
assert elbos.shape[0] == self.num_particles | ||
scaled_elbos = (1.0 - self.alpha) * elbos | ||
|
@@ -695,8 +709,10 @@ class TraceGraph_ELBO(ELBO): | |
|
||
can_infer_discrete = True | ||
|
||
def __init__(self, num_particles=1): | ||
super().__init__(num_particles=num_particles) | ||
def __init__(self, num_particles=1, vectorize_particles=True): | ||
super().__init__( | ||
num_particles=num_particles, vectorize_particles=vectorize_particles | ||
) | ||
|
||
def loss(self, rng_key, param_map, model, guide, *args, **kwargs): | ||
""" | ||
|
@@ -771,7 +787,10 @@ def single_particle_elbo(rng_key): | |
return -single_particle_elbo(rng_key) | ||
else: | ||
rng_keys = random.split(rng_key, self.num_particles) | ||
return -jnp.mean(vmap(single_particle_elbo)(rng_keys)) | ||
if self.vectorize_particles: | ||
return -jnp.mean(vmap(single_particle_elbo)(rng_keys)) | ||
else: | ||
return -jnp.mean(jax.lax.map(single_particle_elbo, rng_keys)) | ||
|
||
|
||
def get_importance_trace_enum( | ||
|
@@ -953,9 +972,13 @@ class TraceEnum_ELBO(ELBO): | |
|
||
can_infer_discrete = True | ||
|
||
def __init__(self, num_particles=1, max_plate_nesting=float("inf")): | ||
def __init__( | ||
self, num_particles=1, max_plate_nesting=float("inf"), vectorize_particles=True | ||
): | ||
self.max_plate_nesting = max_plate_nesting | ||
super().__init__(num_particles=num_particles) | ||
super().__init__( | ||
num_particles=num_particles, vectorize_particles=vectorize_particles | ||
) | ||
|
||
def loss(self, rng_key, param_map, model, guide, *args, **kwargs): | ||
def single_particle_elbo(rng_key): | ||
|
@@ -1128,4 +1151,7 @@ def single_particle_elbo(rng_key): | |
return -single_particle_elbo(rng_key) | ||
else: | ||
rng_keys = random.split(rng_key, self.num_particles) | ||
return -jnp.mean(vmap(single_particle_elbo)(rng_keys)) | ||
if self.vectorize_particles: | ||
return -jnp.mean(vmap(single_particle_elbo)(rng_keys)) | ||
else: | ||
return -jnp.mean(jax.lax.map(single_particle_elbo, rng_keys)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is this only for eval or also for training?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this is for both, but typically used for eval, when we require a large number of particles.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: