From 3323cb79aa977c1c166f7fce0ecf8eb501e54568 Mon Sep 17 00:00:00 2001
From: Du Phan
Date: Mon, 14 Aug 2023 20:24:42 -0400
Subject: [PATCH 1/2] add vectorized_particles to ELBO

---
 numpyro/infer/elbo.py  | 52 +++++++++++++++++++++++++++++++-----------
 test/infer/test_svi.py | 22 ++++++++++++++++++
 2 files changed, 61 insertions(+), 13 deletions(-)

diff --git a/numpyro/infer/elbo.py b/numpyro/infer/elbo.py
index b0a39a4bc..6ca8d008d 100644
--- a/numpyro/infer/elbo.py
+++ b/numpyro/infer/elbo.py
@@ -6,6 +6,7 @@
 from operator import itemgetter
 import warnings

+import jax
 from jax import eval_shape, random, vmap
 from jax.lax import stop_gradient
 import jax.numpy as jnp
@@ -33,6 +34,8 @@ class ELBO:

     :param num_particles: The number of particles/samples used to form the ELBO
         (gradient) estimators.
+    :param vectorize_particles: Whether to use `jax.vmap` to obtain elbos over the
+        particles. If False, we will use `jax.lax.map`. Defaults to True.
     """

     """
     """
     can_infer_discrete = False

-    def __init__(self, num_particles=1):
+    def __init__(self, num_particles=1, vectorize_particles=True):
         self.num_particles = num_particles
+        self.vectorize_particles = vectorize_particles

     def loss(self, rng_key, param_map, model, guide, *args, **kwargs):
         """
@@ -108,11 +112,10 @@ class Trace_ELBO(ELBO):

     :param num_particles: The number of particles/samples used to form the ELBO
         (gradient) estimators.
+    :param vectorize_particles: Whether to use `jax.vmap` to obtain elbos over the
+        particles. If False, we will use `jax.lax.map`. Defaults to True.
     """

-    def __init__(self, num_particles=1):
-        self.num_particles = num_particles
-
     def loss_with_mutable_state(
         self, rng_key, param_map, model, guide, *args, **kwargs
     ):
@@ -163,7 +166,10 @@ def single_particle_elbo(rng_key):
             return {"loss": -elbo, "mutable_state": mutable_state}
         else:
             rng_keys = random.split(rng_key, self.num_particles)
-            elbos, mutable_state = vmap(single_particle_elbo)(rng_keys)
+            if self.vectorize_particles:
+                elbos, mutable_state = vmap(single_particle_elbo)(rng_keys)
+            else:
+                elbos, mutable_state = jax.lax.map(single_particle_elbo, rng_keys)
             return {"loss": -jnp.mean(elbos), "mutable_state": mutable_state}


@@ -291,7 +297,10 @@ def single_particle_elbo(rng_key):
             return {"loss": -elbo, "mutable_state": mutable_state}
         else:
             rng_keys = random.split(rng_key, self.num_particles)
-            elbos, mutable_state = vmap(single_particle_elbo)(rng_keys)
+            if self.vectorize_particles:
+                elbos, mutable_state = vmap(single_particle_elbo)(rng_keys)
+            else:
+                elbos, mutable_state = jax.lax.map(single_particle_elbo, rng_keys)
             return {"loss": -jnp.mean(elbos), "mutable_state": mutable_state}


@@ -311,6 +320,8 @@ class RenyiELBO(ELBO):
         Here :math:`\alpha \neq 1`. Default is 0.
     :param num_particles: The number of particles/samples used to form the objective
         (gradient) estimator. Default is 2.
+    :param vectorize_particles: Whether to use `jax.vmap` to obtain elbos over the
+        particles. If False, we will use `jax.lax.map`. Defaults to True.

     Example::

@@ -427,7 +438,10 @@ def loss(self, rng_key, param_map, model, guide, *args, **kwargs):
             )

         rng_keys = random.split(rng_key, self.num_particles)
-        elbos, common_plate_scale = vmap(single_particle_elbo)(rng_keys)
+        if self.vectorize_particles:
+            elbos, common_plate_scale = vmap(single_particle_elbo)(rng_keys)
+        else:
+            elbos, common_plate_scale = jax.lax.map(single_particle_elbo, rng_keys)
         assert common_plate_scale.shape == (self.num_particles,)
         assert elbos.shape[0] == self.num_particles
         scaled_elbos = (1.0 - self.alpha) * elbos
@@ -695,8 +709,10 @@ class TraceGraph_ELBO(ELBO):

     can_infer_discrete = True

-    def __init__(self, num_particles=1):
-        super().__init__(num_particles=num_particles)
+    def __init__(self, num_particles=1, vectorize_particles=True):
+        super().__init__(
+            num_particles=num_particles, vectorize_particles=vectorize_particles
+        )

     def loss(self, rng_key, param_map, model, guide, *args, **kwargs):
         """
@@ -771,7 +787,10 @@ def single_particle_elbo(rng_key):
             return -single_particle_elbo(rng_key)
         else:
             rng_keys = random.split(rng_key, self.num_particles)
-            return -jnp.mean(vmap(single_particle_elbo)(rng_keys))
+            if self.vectorize_particles:
+                return -jnp.mean(vmap(single_particle_elbo)(rng_keys))
+            else:
+                return -jnp.mean(jax.lax.map(single_particle_elbo, rng_keys))


 def get_importance_trace_enum(
@@ -953,9 +972,13 @@ class TraceEnum_ELBO(ELBO):

     can_infer_discrete = True

-    def __init__(self, num_particles=1, max_plate_nesting=float("inf")):
+    def __init__(
+        self, num_particles=1, max_plate_nesting=float("inf"), vectorize_particles=True
+    ):
         self.max_plate_nesting = max_plate_nesting
-        super().__init__(num_particles=num_particles)
+        super().__init__(
+            num_particles=num_particles, vectorize_particles=vectorize_particles
+        )

     def loss(self, rng_key, param_map, model, guide, *args, **kwargs):
         def single_particle_elbo(rng_key):
@@ -1128,4 +1151,7 @@ def single_particle_elbo(rng_key):
             return -single_particle_elbo(rng_key)
         else:
             rng_keys = random.split(rng_key, self.num_particles)
-            return -jnp.mean(vmap(single_particle_elbo)(rng_keys))
+            if self.vectorize_particles:
+                return -jnp.mean(vmap(single_particle_elbo)(rng_keys))
+            else:
+                return -jnp.mean(jax.lax.map(single_particle_elbo, rng_keys))
diff --git a/test/infer/test_svi.py b/test/infer/test_svi.py
index fd6f8ab75..e834a9758 100644
--- a/test/infer/test_svi.py
+++ b/test/infer/test_svi.py
@@ -170,6 +170,28 @@ def get_renyi(n=N, k=K, fix_indices=True):
     assert_allclose(atol, 0.0, atol=1e-5)


+def test_vectorized_particle():
+    data = jnp.array([1.0] * 8 + [0.0] * 2)
+
+    def model(data):
+        f = numpyro.sample("beta", dist.Beta(1.0, 1.0))
+        with numpyro.plate("N", len(data)):
+            numpyro.sample("obs", dist.Bernoulli(f), obs=data)
+
+    def guide(data):
+        alpha_q = numpyro.param("alpha_q", 1.0, constraint=constraints.positive)
+        beta_q = numpyro.param("beta_q", 1.0, constraint=constraints.positive)
+        numpyro.sample("beta", dist.Beta(alpha_q, beta_q))
+
+    vmap_results = SVI(
+        model, guide, optim.Adam(0.1), Trace_ELBO(vectorize_particles=True)
+    ).run(random.PRNGKey(0), 100, data)
+    map_results = SVI(
+        model, guide, optim.Adam(0.1), Trace_ELBO(vectorize_particles=False)
+    ).run(random.PRNGKey(0), 100, data)
+    assert_allclose(vmap_results.losses, map_results.losses, atol=1e-5)
+
+
 @pytest.mark.parametrize("elbo", [Trace_ELBO(), RenyiELBO(num_particles=10)])
 @pytest.mark.parametrize("optimizer", [optim.Adam(0.01), optimizers.adam(0.01)])
 def test_beta_bernoulli(elbo, optimizer):

From e4a89c2966b4d1196882807941212e93f03943cb Mon Sep 17 00:00:00 2001
From: Du Phan
Date: Tue, 15 Aug 2023 11:45:05 -0400
Subject: [PATCH 2/2] address comments

---
 numpyro/infer/elbo.py | 15 +++++++++------
 1 file changed, 9 insertions(+), 6 deletions(-)

diff --git a/numpyro/infer/elbo.py b/numpyro/infer/elbo.py
index 6ca8d008d..127a616fe 100644
--- a/numpyro/infer/elbo.py
+++ b/numpyro/infer/elbo.py
@@ -34,8 +34,9 @@ class ELBO:

     :param num_particles: The number of particles/samples used to form the ELBO
         (gradient) estimators.
-    :param vectorize_particles: Whether to use `jax.vmap` to obtain elbos over the
-        particles. If False, we will use `jax.lax.map`. Defaults to True.
+    :param vectorize_particles: Whether to use `jax.vmap` to compute ELBOs over the
+        num_particles-many particles in parallel. If False use `jax.lax.map`.
+        Defaults to True.
     """

     """
@@ -112,8 +113,9 @@ class Trace_ELBO(ELBO):
     :param num_particles: The number of particles/samples used to form the ELBO
         (gradient) estimators.
-    :param vectorize_particles: Whether to use `jax.vmap` to obtain elbos over the
-        particles. If False, we will use `jax.lax.map`. Defaults to True.
+    :param vectorize_particles: Whether to use `jax.vmap` to compute ELBOs over the
+        num_particles-many particles in parallel. If False use `jax.lax.map`.
+        Defaults to True.
     """

     def loss_with_mutable_state(
         self, rng_key, param_map, model, guide, *args, **kwargs
     ):
@@ -320,8 +322,9 @@ class RenyiELBO(ELBO):
         Here :math:`\alpha \neq 1`. Default is 0.
     :param num_particles: The number of particles/samples used to form the objective
         (gradient) estimator. Default is 2.
-    :param vectorize_particles: Whether to use `jax.vmap` to obtain elbos over the
-        particles. If False, we will use `jax.lax.map`. Defaults to True.
+    :param vectorize_particles: Whether to use `jax.vmap` to compute ELBOs over the
+        num_particles-many particles in parallel. If False use `jax.lax.map`.
+        Defaults to True.

     Example::
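
For reference, a minimal usage sketch of the flag these patches introduce, mirroring the new `test_vectorized_particle` test above. With `vectorize_particles=False` the per-particle ELBO terms are evaluated sequentially via `jax.lax.map`, which can lower peak memory at the cost of speed; the default `True` vectorizes them with `jax.vmap`. The model, guide, optimizer settings, and step count below are illustrative only and are not part of the patch. Example::

    import jax.numpy as jnp
    from jax import random

    import numpyro
    import numpyro.distributions as dist
    from numpyro import optim
    from numpyro.distributions import constraints
    from numpyro.infer import SVI, Trace_ELBO

    # Toy Beta-Bernoulli model: eight successes and two failures.
    data = jnp.array([1.0] * 8 + [0.0] * 2)

    def model(data):
        f = numpyro.sample("beta", dist.Beta(1.0, 1.0))
        with numpyro.plate("N", len(data)):
            numpyro.sample("obs", dist.Bernoulli(f), obs=data)

    def guide(data):
        alpha_q = numpyro.param("alpha_q", 1.0, constraint=constraints.positive)
        beta_q = numpyro.param("beta_q", 1.0, constraint=constraints.positive)
        numpyro.sample("beta", dist.Beta(alpha_q, beta_q))

    # Evaluate the 10 particles sequentially with jax.lax.map; pass
    # vectorize_particles=True (the default) to vectorize them with jax.vmap.
    elbo = Trace_ELBO(num_particles=10, vectorize_particles=False)
    svi = SVI(model, guide, optim.Adam(0.1), elbo)
    svi_result = svi.run(random.PRNGKey(0), 100, data)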