Skip to content

Commit

Permalink
support multi_sample_guide in Trace_ELBO
Browse files Browse the repository at this point in the history
  • Loading branch information
fehiepsi committed Oct 20, 2023
1 parent f4592b6 commit 403192b
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 25 deletions.
13 changes: 10 additions & 3 deletions numpyro/infer/elbo.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ class ELBO:
Subclasses that are capable of inferring discrete latent variables should override to `True`
"""
can_infer_discrete = False
multi_sample_guide = False

def __init__(self, num_particles=1, vectorize_particles=True):
self.num_particles = num_particles
Expand All @@ -57,7 +58,6 @@ def loss(
model,
guide,
*args,
multi_sample_guide=False,
**kwargs,
):
"""
Expand Down Expand Up @@ -127,14 +127,21 @@ class Trace_ELBO(ELBO):
Defaults to True.
"""

def __init__(
self, num_particles=1, vectorize_particles=True, multi_sample_guide=False
):
self.multi_sample_guide = multi_sample_guide
super().__init__(
num_particles=num_particles, vectorize_particles=vectorize_particles
)

def loss_with_mutable_state(
self,
rng_key,
param_map,
model,
guide,
*args,
multi_sample_guide=False,
**kwargs,
):
def single_particle_elbo(rng_key):
Expand All @@ -150,7 +157,7 @@ def single_particle_elbo(rng_key):
if site["type"] == "mutable"
}
params.update(mutable_params)
if multi_sample_guide:
if self.multi_sample_guide:
plates = {
name: site["value"]
for name, site in guide_trace.items()
Expand Down
22 changes: 5 additions & 17 deletions numpyro/infer/svi.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,15 +132,12 @@ class SVI(object):
:return: tuple of `(init_fn, update_fn, evaluate)`.
"""

def __init__(
self, model, guide, optim, loss, multi_sample_guide=False, **static_kwargs
):
def __init__(self, model, guide, optim, loss, **static_kwargs):
self.model = model
self.guide = guide
self.loss = loss
self.static_kwargs = static_kwargs
self.constrain_fn = None
self.multi_sample_guide = multi_sample_guide

if isinstance(optim, _NumPyroOptim):
self.optim = optim
Expand Down Expand Up @@ -193,7 +190,7 @@ def init(self, rng_key, *args, init_params=None, **kwargs):
}
if init_params is not None:
init_guide_params.update(init_params)
if self.multi_sample_guide:
if self.loss.multi_sample_guide:
latents = {
name: site["value"][0]
for name, site in guide_trace.items()
Expand Down Expand Up @@ -272,9 +269,6 @@ def update(self, svi_state, *args, **kwargs):
:return: tuple of `(svi_state, loss)`.
"""
rng_key, rng_key_step = random.split(svi_state.rng_key)
static_kwargs = self.static_kwargs.copy()
if self.multi_sample_guide:
static_kwargs["multi_sample_guide"] = True
loss_fn = _make_loss_fn(
self.loss,
rng_key_step,
Expand All @@ -283,7 +277,7 @@ def update(self, svi_state, *args, **kwargs):
self.guide,
args,
kwargs,
static_kwargs,
self.static_kwargs,
mutable_state=svi_state.mutable_state,
)
(loss_val, mutable_state), optim_state = self.optim.eval_and_update(
Expand All @@ -304,9 +298,6 @@ def stable_update(self, svi_state, *args, **kwargs):
:return: tuple of `(svi_state, loss)`.
"""
rng_key, rng_key_step = random.split(svi_state.rng_key)
static_kwargs = self.static_kwargs.copy()
if self.multi_sample_guide:
static_kwargs["multi_sample_guide"] = True
loss_fn = _make_loss_fn(
self.loss,
rng_key_step,
Expand All @@ -315,7 +306,7 @@ def stable_update(self, svi_state, *args, **kwargs):
self.guide,
args,
kwargs,
static_kwargs,
self.static_kwargs,
mutable_state=svi_state.mutable_state,
)
(loss_val, mutable_state), optim_state = self.optim.eval_and_stable_update(
Expand Down Expand Up @@ -428,15 +419,12 @@ def evaluate(self, svi_state, *args, **kwargs):
# we split to have the same seed as `update_fn` given an svi_state
_, rng_key_eval = random.split(svi_state.rng_key)
params = self.get_params(svi_state)
static_kwargs = self.static_kwargs.copy()
if self.multi_sample_guide:
static_kwargs["multi_sample_guide"] = True
return self.loss.loss(
rng_key_eval,
params,
self.model,
self.guide,
*args,
**kwargs,
**static_kwargs,
**self.static_kwargs,
)
13 changes: 8 additions & 5 deletions test/infer/test_svi.py
Original file line number Diff line number Diff line change
Expand Up @@ -741,16 +741,19 @@ def guide(difficulty=0.0):


def test_multi_sample_guide():
actual_loc = 3.0
actual_scale = 2.0

def model():
numpyro.sample("x", dist.Normal(2, 3))
numpyro.sample("x", dist.Normal(actual_loc, actual_scale))

def guide():
loc = numpyro.param("loc", 0.0)
scale = numpyro.param("scale", 1.0, constraint=constraints.positive)
numpyro.sample("x", dist.Normal(loc, scale).expand([10]))

svi = SVI(model, guide, optim.Adam(0.1), Trace_ELBO(), multi_sample_guide=True)
svi_results = svi.run(random.PRNGKey(0), 1000)
svi = SVI(model, guide, optim.Adam(0.1), Trace_ELBO(multi_sample_guide=True))
svi_results = svi.run(random.PRNGKey(0), 2000)
params = svi_results.params
assert_allclose(params["loc"], 2.0, rtol=0.1)
assert_allclose(params["scale"], 3.0, rtol=0.1)
assert_allclose(params["loc"], actual_loc, rtol=0.1)
assert_allclose(params["scale"], actual_scale, rtol=0.1)

0 comments on commit 403192b

Please sign in to comment.