Skip to content

Commit

Permalink
Add vectorized_particles to ELBO (#1624)
Browse files Browse the repository at this point in the history
* add vectorized_particles to ELBO

* address comments
  • Loading branch information
fehiepsi authored Aug 15, 2023
1 parent 4e37df3 commit 56b88c3
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 13 deletions.
55 changes: 42 additions & 13 deletions numpyro/infer/elbo.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from operator import itemgetter
import warnings

import jax
from jax import eval_shape, random, vmap
from jax.lax import stop_gradient
import jax.numpy as jnp
Expand Down Expand Up @@ -33,6 +34,9 @@ class ELBO:
:param num_particles: The number of particles/samples used to form the ELBO
(gradient) estimators.
:param vectorize_particles: Whether to use `jax.vmap` to compute ELBOs over the
num_particles-many particles in parallel. If False use `jax.lax.map`.
Defaults to True.
"""

"""
Expand All @@ -42,8 +46,9 @@ class ELBO:
"""
can_infer_discrete = False

def __init__(self, num_particles=1):
def __init__(self, num_particles=1, vectorize_particles=True):
self.num_particles = num_particles
self.vectorize_particles = vectorize_particles

def loss(self, rng_key, param_map, model, guide, *args, **kwargs):
"""
Expand Down Expand Up @@ -108,11 +113,11 @@ class Trace_ELBO(ELBO):
:param num_particles: The number of particles/samples used to form the ELBO
(gradient) estimators.
:param vectorize_particles: Whether to use `jax.vmap` to compute ELBOs over the
num_particles-many particles in parallel. If False use `jax.lax.map`.
Defaults to True.
"""

def __init__(self, num_particles=1):
self.num_particles = num_particles

def loss_with_mutable_state(
self, rng_key, param_map, model, guide, *args, **kwargs
):
Expand Down Expand Up @@ -163,7 +168,10 @@ def single_particle_elbo(rng_key):
return {"loss": -elbo, "mutable_state": mutable_state}
else:
rng_keys = random.split(rng_key, self.num_particles)
elbos, mutable_state = vmap(single_particle_elbo)(rng_keys)
if self.vectorize_particles:
elbos, mutable_state = vmap(single_particle_elbo)(rng_keys)
else:
elbos, mutable_state = jax.lax.map(single_particle_elbo, rng_keys)
return {"loss": -jnp.mean(elbos), "mutable_state": mutable_state}


Expand Down Expand Up @@ -291,7 +299,10 @@ def single_particle_elbo(rng_key):
return {"loss": -elbo, "mutable_state": mutable_state}
else:
rng_keys = random.split(rng_key, self.num_particles)
elbos, mutable_state = vmap(single_particle_elbo)(rng_keys)
if self.vectorize_particles:
elbos, mutable_state = vmap(single_particle_elbo)(rng_keys)
else:
elbos, mutable_state = jax.lax.map(single_particle_elbo, rng_keys)
return {"loss": -jnp.mean(elbos), "mutable_state": mutable_state}


Expand All @@ -311,6 +322,9 @@ class RenyiELBO(ELBO):
Here :math:`\alpha \neq 1`. Default is 0.
:param num_particles: The number of particles/samples
used to form the objective (gradient) estimator. Default is 2.
:param vectorize_particles: Whether to use `jax.vmap` to compute ELBOs over the
num_particles-many particles in parallel. If False use `jax.lax.map`.
Defaults to True.
Example::
Expand Down Expand Up @@ -427,7 +441,10 @@ def loss(self, rng_key, param_map, model, guide, *args, **kwargs):
)

rng_keys = random.split(rng_key, self.num_particles)
elbos, common_plate_scale = vmap(single_particle_elbo)(rng_keys)
if self.vectorize_particles:
elbos, common_plate_scale = vmap(single_particle_elbo)(rng_keys)
else:
elbos, common_plate_scale = jax.lax.map(single_particle_elbo, rng_keys)
assert common_plate_scale.shape == (self.num_particles,)
assert elbos.shape[0] == self.num_particles
scaled_elbos = (1.0 - self.alpha) * elbos
Expand Down Expand Up @@ -695,8 +712,10 @@ class TraceGraph_ELBO(ELBO):

can_infer_discrete = True

def __init__(self, num_particles=1):
super().__init__(num_particles=num_particles)
def __init__(self, num_particles=1, vectorize_particles=True):
super().__init__(
num_particles=num_particles, vectorize_particles=vectorize_particles
)

def loss(self, rng_key, param_map, model, guide, *args, **kwargs):
"""
Expand Down Expand Up @@ -771,7 +790,10 @@ def single_particle_elbo(rng_key):
return -single_particle_elbo(rng_key)
else:
rng_keys = random.split(rng_key, self.num_particles)
return -jnp.mean(vmap(single_particle_elbo)(rng_keys))
if self.vectorize_particles:
return -jnp.mean(vmap(single_particle_elbo)(rng_keys))
else:
return -jnp.mean(jax.lax.map(single_particle_elbo, rng_keys))


def get_importance_trace_enum(
Expand Down Expand Up @@ -953,9 +975,13 @@ class TraceEnum_ELBO(ELBO):

can_infer_discrete = True

def __init__(self, num_particles=1, max_plate_nesting=float("inf")):
def __init__(
self, num_particles=1, max_plate_nesting=float("inf"), vectorize_particles=True
):
self.max_plate_nesting = max_plate_nesting
super().__init__(num_particles=num_particles)
super().__init__(
num_particles=num_particles, vectorize_particles=vectorize_particles
)

def loss(self, rng_key, param_map, model, guide, *args, **kwargs):
def single_particle_elbo(rng_key):
Expand Down Expand Up @@ -1128,4 +1154,7 @@ def single_particle_elbo(rng_key):
return -single_particle_elbo(rng_key)
else:
rng_keys = random.split(rng_key, self.num_particles)
return -jnp.mean(vmap(single_particle_elbo)(rng_keys))
if self.vectorize_particles:
return -jnp.mean(vmap(single_particle_elbo)(rng_keys))
else:
return -jnp.mean(jax.lax.map(single_particle_elbo, rng_keys))
22 changes: 22 additions & 0 deletions test/infer/test_svi.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,28 @@ def get_renyi(n=N, k=K, fix_indices=True):
assert_allclose(atol, 0.0, atol=1e-5)


def test_vectorized_particle():
data = jnp.array([1.0] * 8 + [0.0] * 2)

def model(data):
f = numpyro.sample("beta", dist.Beta(1.0, 1.0))
with numpyro.plate("N", len(data)):
numpyro.sample("obs", dist.Bernoulli(f), obs=data)

def guide(data):
alpha_q = numpyro.param("alpha_q", 1.0, constraint=constraints.positive)
beta_q = numpyro.param("beta_q", 1.0, constraint=constraints.positive)
numpyro.sample("beta", dist.Beta(alpha_q, beta_q))

vmap_results = SVI(
model, guide, optim.Adam(0.1), Trace_ELBO(vectorize_particles=True)
).run(random.PRNGKey(0), 100, data)
map_results = SVI(
model, guide, optim.Adam(0.1), Trace_ELBO(vectorize_particles=False)
).run(random.PRNGKey(0), 100, data)
assert_allclose(vmap_results.losses, map_results.losses, atol=1e-5)


@pytest.mark.parametrize("elbo", [Trace_ELBO(), RenyiELBO(num_particles=10)])
@pytest.mark.parametrize("optimizer", [optim.Adam(0.01), optimizers.adam(0.01)])
def test_beta_bernoulli(elbo, optimizer):
Expand Down

0 comments on commit 56b88c3

Please sign in to comment.