From 403192bfd3549185d4963012002e679b20a8a0be Mon Sep 17 00:00:00 2001
From: Du Phan
Date: Fri, 20 Oct 2023 12:52:26 -0400
Subject: [PATCH] support multi_sample_guide in Trace_ELBO

---
 numpyro/infer/elbo.py  | 13 ++++++++++---
 numpyro/infer/svi.py   | 22 +++++-----------------
 test/infer/test_svi.py | 13 ++++++++-----
 3 files changed, 23 insertions(+), 25 deletions(-)

diff --git a/numpyro/infer/elbo.py b/numpyro/infer/elbo.py
index c6ff973ea..150003507 100644
--- a/numpyro/infer/elbo.py
+++ b/numpyro/infer/elbo.py
@@ -45,6 +45,7 @@ class ELBO:
     Subclasses that are capable of inferring discrete latent variables should override to `True`
     """
     can_infer_discrete = False
+    multi_sample_guide = False
 
     def __init__(self, num_particles=1, vectorize_particles=True):
         self.num_particles = num_particles
@@ -57,7 +58,6 @@ def loss(
         model,
         guide,
         *args,
-        multi_sample_guide=False,
         **kwargs,
     ):
         """
@@ -127,6 +127,14 @@ class Trace_ELBO(ELBO):
         Defaults to True.
     """
 
+    def __init__(
+        self, num_particles=1, vectorize_particles=True, multi_sample_guide=False
+    ):
+        self.multi_sample_guide = multi_sample_guide
+        super().__init__(
+            num_particles=num_particles, vectorize_particles=vectorize_particles
+        )
+
     def loss_with_mutable_state(
         self,
         rng_key,
         params,
         model,
         guide,
         *args,
-        multi_sample_guide=False,
         **kwargs,
     ):
         def single_particle_elbo(rng_key):
@@ -150,7 +157,7 @@ def single_particle_elbo(rng_key):
                 if site["type"] == "mutable"
             }
             params.update(mutable_params)
-            if multi_sample_guide:
+            if self.multi_sample_guide:
                 plates = {
                     name: site["value"]
                     for name, site in guide_trace.items()
diff --git a/numpyro/infer/svi.py b/numpyro/infer/svi.py
index f47efce7a..0f6a7db95 100644
--- a/numpyro/infer/svi.py
+++ b/numpyro/infer/svi.py
@@ -132,15 +132,12 @@ class SVI(object):
     :return: tuple of `(init_fn, update_fn, evaluate)`.
     """
 
-    def __init__(
-        self, model, guide, optim, loss, multi_sample_guide=False, **static_kwargs
-    ):
+    def __init__(self, model, guide, optim, loss, **static_kwargs):
         self.model = model
         self.guide = guide
         self.loss = loss
         self.static_kwargs = static_kwargs
         self.constrain_fn = None
-        self.multi_sample_guide = multi_sample_guide
 
         if isinstance(optim, _NumPyroOptim):
             self.optim = optim
@@ -193,7 +190,7 @@ def init(self, rng_key, *args, init_params=None, **kwargs):
         }
         if init_params is not None:
             init_guide_params.update(init_params)
-        if self.multi_sample_guide:
+        if self.loss.multi_sample_guide:
             latents = {
                 name: site["value"][0]
                 for name, site in guide_trace.items()
@@ -272,9 +269,6 @@ def update(self, svi_state, *args, **kwargs):
         :return: tuple of `(svi_state, loss)`.
         """
         rng_key, rng_key_step = random.split(svi_state.rng_key)
-        static_kwargs = self.static_kwargs.copy()
-        if self.multi_sample_guide:
-            static_kwargs["multi_sample_guide"] = True
         loss_fn = _make_loss_fn(
             self.loss,
             rng_key_step,
@@ -283,7 +277,7 @@ def update(self, svi_state, *args, **kwargs):
             self.guide,
             args,
             kwargs,
-            static_kwargs,
+            self.static_kwargs,
             mutable_state=svi_state.mutable_state,
         )
         (loss_val, mutable_state), optim_state = self.optim.eval_and_update(
@@ -304,9 +298,6 @@ def stable_update(self, svi_state, *args, **kwargs):
         :return: tuple of `(svi_state, loss)`.
         """
         rng_key, rng_key_step = random.split(svi_state.rng_key)
-        static_kwargs = self.static_kwargs.copy()
-        if self.multi_sample_guide:
-            static_kwargs["multi_sample_guide"] = True
         loss_fn = _make_loss_fn(
             self.loss,
             rng_key_step,
@@ -315,7 +306,7 @@ def stable_update(self, svi_state, *args, **kwargs):
             self.guide,
             args,
             kwargs,
-            static_kwargs,
+            self.static_kwargs,
             mutable_state=svi_state.mutable_state,
         )
         (loss_val, mutable_state), optim_state = self.optim.eval_and_stable_update(
@@ -428,9 +419,6 @@ def evaluate(self, svi_state, *args, **kwargs):
         # we split to have the same seed as `update_fn` given an svi_state
         _, rng_key_eval = random.split(svi_state.rng_key)
         params = self.get_params(svi_state)
-        static_kwargs = self.static_kwargs.copy()
-        if self.multi_sample_guide:
-            static_kwargs["multi_sample_guide"] = True
         return self.loss.loss(
             rng_key_eval,
             params,
@@ -438,5 +426,5 @@ def evaluate(self, svi_state, *args, **kwargs):
             self.guide,
             *args,
             **kwargs,
-            **static_kwargs,
+            **self.static_kwargs,
         )
diff --git a/test/infer/test_svi.py b/test/infer/test_svi.py
index c48bca566..59e21cadd 100644
--- a/test/infer/test_svi.py
+++ b/test/infer/test_svi.py
@@ -741,16 +741,19 @@ def guide(difficulty=0.0):
 
 
 def test_multi_sample_guide():
+    actual_loc = 3.0
+    actual_scale = 2.0
+
     def model():
-        numpyro.sample("x", dist.Normal(2, 3))
+        numpyro.sample("x", dist.Normal(actual_loc, actual_scale))
 
     def guide():
         loc = numpyro.param("loc", 0.0)
         scale = numpyro.param("scale", 1.0, constraint=constraints.positive)
         numpyro.sample("x", dist.Normal(loc, scale).expand([10]))
 
-    svi = SVI(model, guide, optim.Adam(0.1), Trace_ELBO(), multi_sample_guide=True)
-    svi_results = svi.run(random.PRNGKey(0), 1000)
+    svi = SVI(model, guide, optim.Adam(0.1), Trace_ELBO(multi_sample_guide=True))
+    svi_results = svi.run(random.PRNGKey(0), 2000)
     params = svi_results.params
-    assert_allclose(params["loc"], 2.0, rtol=0.1)
-    assert_allclose(params["scale"], 3.0, rtol=0.1)
+    assert_allclose(params["loc"], actual_loc, rtol=0.1)
+    assert_allclose(params["scale"], actual_scale, rtol=0.1)