From 201df2911eb4521a66e196e30c6529f5f8e6ef3f Mon Sep 17 00:00:00 2001 From: danielward27 Date: Sun, 4 Feb 2024 20:02:58 +0000 Subject: [PATCH] fix elbo normalization with multi_sample_guide=True --- numpyro/infer/elbo.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/numpyro/infer/elbo.py b/numpyro/infer/elbo.py index 872ac6622..eb4e11b68 100644 --- a/numpyro/infer/elbo.py +++ b/numpyro/infer/elbo.py @@ -174,7 +174,7 @@ def get_model_density(key, latent): return model_log_density num_guide_samples = None - for name, site in guide_trace.items(): + for site in guide_trace.values(): if site["type"] == "sample": num_guide_samples = site["value"].shape[0] break @@ -210,8 +210,6 @@ def get_model_density(key, latent): # log p(z) - log q(z) elbo_particle = model_log_density - guide_log_density - # log p(z) - log q(z) - elbo_particle = model_log_density - guide_log_density if mutable_params: if self.num_particles == 1: return elbo_particle, mutable_params