Skip to content

Commit

Permalink
address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
fehiepsi committed Jul 10, 2023
1 parent 22069fc commit 59081d7
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 6 deletions.
12 changes: 11 additions & 1 deletion numpyro/infer/elbo.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,16 @@ def _single_particle_elbo(self, model, guide, param_map, args, kwargs, rng_key):
indep_plates = set.intersection(*site_plates.values())
else:
indep_plates = set()
for frame in set.union(*site_plates.values()):
if frame not in indep_plates:
subsample_size = frame.size
size = model_trace[frame.name]["args"][0]
if size > subsample_size:
raise ValueError(
f"Subsample plate `{frame.name}` should cover"
" all random variables."
)

indep_plate_scale = 1.0
for frame in indep_plates:
subsample_size = frame.size
Expand All @@ -380,7 +390,7 @@ def _single_particle_elbo(self, model, guide, param_map, args, kwargs, rng_key):

log_densities = {}
for trace_type, tr in {"guide": guide_trace, "model": model_trace}.items():
log_densities[trace_type] = 1.0
log_densities[trace_type] = 0.0
for site in tr.values():
if site["type"] != "sample":
continue
Expand Down
10 changes: 5 additions & 5 deletions test/infer/test_svi.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def get_renyi(n=N, k=K, fix_indices=True):


@pytest.mark.parametrize("elbo", [Trace_ELBO(), RenyiELBO(num_particles=10)])
@pytest.mark.parametrize("optimizer", [optim.Adam(0.05), optimizers.adam(0.05)])
@pytest.mark.parametrize("optimizer", [optim.Adam(0.01), optimizers.adam(0.01)])
def test_beta_bernoulli(elbo, optimizer):
data = jnp.array([1.0] * 8 + [0.0] * 2)

Expand All @@ -193,14 +193,14 @@ def body_fn(i, val):
svi_state, _ = svi.update(val, data)
return svi_state

svi_state = fori_loop(0, 4000, body_fn, svi_state)
svi_state = fori_loop(0, 10000, body_fn, svi_state)
params = svi.get_params(svi_state)
actual_posterior_mean = 0.75 # (8 + 1) / (8 + 1 + 2 + 1)
actual_posterior_mean = (data.sum() + 1) / (data.shape[0] + 2)
assert_allclose(
params["alpha_q"] / (params["alpha_q"] + params["beta_q"]),
actual_posterior_mean,
atol=0.05,
rtol=0.05,
atol=0.03,
rtol=0.03,
)


Expand Down

0 comments on commit 59081d7

Please sign in to comment.