From bc82606109a1c352717c8708d8160b9554441a97 Mon Sep 17 00:00:00 2001
From: Fritz Obermeyer
Date: Mon, 8 Nov 2021 21:48:04 -0500
Subject: [PATCH 1/6] Change coordinatization of AutoMultivariateNormal

---
 pyro/infer/autoguide/guides.py     | 19 +++++++++++------
 tests/infer/reparam/test_neutra.py | 34 ++++++++++++------------------
 2 files changed, 26 insertions(+), 27 deletions(-)

diff --git a/pyro/infer/autoguide/guides.py b/pyro/infer/autoguide/guides.py
index 27d6ac55af..e9c571f886 100644
--- a/pyro/infer/autoguide/guides.py
+++ b/pyro/infer/autoguide/guides.py
@@ -862,6 +862,7 @@ class AutoMultivariateNormal(AutoContinuous):
         (unconstrained transformed) latent variable.
     """
 
+    scale_constraint = constraints.softplus_positive
     scale_tril_constraint = constraints.softplus_lower_cholesky
 
     def __init__(self, model, init_loc_fn=init_to_median, init_scale=0.1):
@@ -874,27 +875,31 @@ def _setup_prototype(self, *args, **kwargs):
         super()._setup_prototype(*args, **kwargs)
         # Initialize guide params
         self.loc = nn.Parameter(self._init_loc())
+        self.scale = PyroParam(
+            torch.full_like(self.loc, self._init_scale), self.scale_constraint
+        )
         self.scale_tril = PyroParam(
-            eye_like(self.loc, self.latent_dim) * self._init_scale,
-            self.scale_tril_constraint,
+            eye_like(self.loc, self.latent_dim), self.scale_tril_constraint
         )
 
     def get_base_dist(self):
         return dist.Normal(
-            torch.zeros_like(self.loc), torch.zeros_like(self.loc)
+            torch.zeros_like(self.loc), torch.ones_like(self.loc)
         ).to_event(1)
 
     def get_transform(self, *args, **kwargs):
-        return dist.transforms.LowerCholeskyAffine(self.loc, scale_tril=self.scale_tril)
+        scale_tril = self.scale[..., None] * self.scale_tril
+        return dist.transforms.LowerCholeskyAffine(self.loc, scale_tril=scale_tril)
 
     def get_posterior(self, *args, **kwargs):
        """
        Returns a MultivariateNormal posterior distribution.
        """
-        return dist.MultivariateNormal(self.loc, scale_tril=self.scale_tril)
+        scale_tril = self.scale[..., None] * self.scale_tril
+        return dist.MultivariateNormal(self.loc, scale_tril=scale_tril)
 
     def _loc_scale(self, *args, **kwargs):
-        return self.loc, self.scale_tril.diag()
+        return self.loc, self.scale * self.scale_tril.diag()
 
 
 class AutoDiagonalNormal(AutoContinuous):
@@ -937,7 +942,7 @@ def _setup_prototype(self, *args, **kwargs):
 
     def get_base_dist(self):
         return dist.Normal(
-            torch.zeros_like(self.loc), torch.zeros_like(self.loc)
+            torch.zeros_like(self.loc), torch.ones_like(self.loc)
         ).to_event(1)
 
     def get_transform(self, *args, **kwargs):
diff --git a/tests/infer/reparam/test_neutra.py b/tests/infer/reparam/test_neutra.py
index ce9491ef7f..c3eda35b8e 100644
--- a/tests/infer/reparam/test_neutra.py
+++ b/tests/infer/reparam/test_neutra.py
@@ -9,10 +9,14 @@
 from pyro import optim
 from pyro.distributions.transforms import ComposeTransform
 from pyro.infer import MCMC, NUTS, SVI, Trace_ELBO
-from pyro.infer.autoguide import AutoIAFNormal
+from pyro.infer.autoguide import (
+    AutoDiagonalNormal,
+    AutoIAFNormal,
+    AutoMultivariateNormal,
+)
 from pyro.infer.mcmc.util import initialize_model
 from pyro.infer.reparam import NeuTraReparam
-from tests.common import assert_close, xfail_param
+from tests.common import assert_close
 
 from .util import check_init_reparam
 
@@ -31,25 +35,22 @@ def dirichlet_categorical(data):
     return p_latent
 
 
+@pytest.mark.parametrize("jit", [False, True])
 @pytest.mark.parametrize(
-    "jit",
-    [
-        False,
-        xfail_param(True, reason="https://github.com/pyro-ppl/pyro/issues/2292"),
-    ],
+    "Guide", [AutoDiagonalNormal, AutoMultivariateNormal, AutoIAFNormal],
 )
-def test_neals_funnel_smoke(jit):
+def test_neals_funnel_smoke(Guide, jit):
     dim = 10
 
-    guide = AutoIAFNormal(neals_funnel)
+    guide = Guide(neals_funnel)
     svi = SVI(neals_funnel, guide, optim.Adam({"lr": 1e-10}), Trace_ELBO())
-    for _ in range(1000):
+    for _ in range(10):
         svi.step(dim)
 
     neutra = NeuTraReparam(guide.requires_grad_(False))
     model = neutra.reparam(neals_funnel)
-    nuts = NUTS(model, jit_compile=jit)
-    mcmc = MCMC(nuts, num_samples=50, warmup_steps=50)
+    nuts = NUTS(model, jit_compile=jit, ignore_jit_warnings=True)
+    mcmc = MCMC(nuts, num_samples=10, warmup_steps=10)
     mcmc.run(dim)
     samples = mcmc.get_samples()
     # XXX: `MCMC.get_samples` adds a leftmost batch dim to all sites, not uniformly at -max_plate_nesting-1;
@@ -65,14 +66,7 @@
     "model, kwargs",
     [
         (neals_funnel, {"dim": 10}),
-        (
-            dirichlet_categorical,
-            {
-                "data": torch.ones(
-                    10,
-                )
-            },
-        ),
+        (dirichlet_categorical, {"data": torch.ones(10)}),
     ],
 )
 def test_reparam_log_joint(model, kwargs):

From 844b7158d03569845ddebe6a39ba4d29c3b0be0a Mon Sep 17 00:00:00 2001
From: Fritz Obermeyer
Date: Mon, 8 Nov 2021 21:53:23 -0500
Subject: [PATCH 2/6] lint

---
 tests/infer/reparam/test_neutra.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/tests/infer/reparam/test_neutra.py b/tests/infer/reparam/test_neutra.py
index c3eda35b8e..397ecc860f 100644
--- a/tests/infer/reparam/test_neutra.py
+++ b/tests/infer/reparam/test_neutra.py
@@ -37,7 +37,8 @@ def dirichlet_categorical(data):
 
 @pytest.mark.parametrize("jit", [False, True])
 @pytest.mark.parametrize(
-    "Guide", [AutoDiagonalNormal, AutoMultivariateNormal, AutoIAFNormal],
+    "Guide",
+    [AutoDiagonalNormal, AutoMultivariateNormal, AutoIAFNormal],
 )
 def test_neals_funnel_smoke(Guide, jit):
     dim = 10
From 3b2f61b1dda94e975a7ba685a9798e66a7d9d6d3 Mon Sep 17 00:00:00 2001
From: Fritz Obermeyer
Date: Tue, 9 Nov 2021 17:45:15 -0500
Subject: [PATCH 3/6] Fix AutoLaplaceApproximation

---
 pyro/infer/autoguide/guides.py | 15 +++++++++++----
 1 file changed, 11 insertions(+), 4 deletions(-)

diff --git a/pyro/infer/autoguide/guides.py b/pyro/infer/autoguide/guides.py
index e9c571f886..7857deb215 100644
--- a/pyro/infer/autoguide/guides.py
+++ b/pyro/infer/autoguide/guides.py
@@ -1172,14 +1172,21 @@ def laplace_approximation(self, *args, **kwargs):
         loss = guide_trace.log_prob_sum() - model_trace.log_prob_sum()
 
         H = hessian(loss, self.loc)
-        cov = H.inverse()
-        loc = self.loc
-        scale_tril = torch.linalg.cholesky(cov)
+        with torch.no_grad():
+            loc = self.loc.detach()
+            cov = H.inverse().detach()
+            scale_squared = cov.diagonal()
+            scale = scale_squared.sqrt()
+            scale_tril = torch.linalg.cholesky(cov / scale_squared[:, None])
 
         gaussian_guide = AutoMultivariateNormal(self.model)
         gaussian_guide._setup_prototype(*args, **kwargs)
-        # Set loc, scale_tril parameters as computed above.
+        # Set loc, scale, scale_tril parameters as computed above.
+        del gaussian_guide.loc
+        del gaussian_guide.scale
+        del gaussian_guide.scale_tril
         gaussian_guide.loc = loc
+        gaussian_guide.scale = scale
         gaussian_guide.scale_tril = scale_tril
         return gaussian_guide
 

From c9d3e687c7feffe9c91dbcedb616e8d9e0ba7b41 Mon Sep 17 00:00:00 2001
From: Fritz Obermeyer
Date: Tue, 9 Nov 2021 17:49:38 -0500
Subject: [PATCH 4/6] Fix AutoLaplaceApproximation

---
 pyro/infer/autoguide/guides.py | 17 +++++++++--------
 1 file changed, 9 insertions(+), 8 deletions(-)

diff --git a/pyro/infer/autoguide/guides.py b/pyro/infer/autoguide/guides.py
index 7857deb215..467e9eb003 100644
--- a/pyro/infer/autoguide/guides.py
+++ b/pyro/infer/autoguide/guides.py
@@ -1174,20 +1174,21 @@ def laplace_approximation(self, *args, **kwargs):
         H = hessian(loss, self.loc)
         with torch.no_grad():
             loc = self.loc.detach()
-            cov = H.inverse().detach()
-            scale_squared = cov.diagonal()
-            scale = scale_squared.sqrt()
-            scale_tril = torch.linalg.cholesky(cov / scale_squared[:, None])
+            cov = H.inverse()
+            scale = cov.diagonal().sqrt()
+            cov /= scale[:, None]
+            cov /= scale[None, :]
+            scale_tril = torch.linalg.cholesky(cov)
 
         gaussian_guide = AutoMultivariateNormal(self.model)
         gaussian_guide._setup_prototype(*args, **kwargs)
-        # Set loc, scale, scale_tril parameters as computed above.
+        # Set detached loc, scale, scale_tril parameters as computed above.
         del gaussian_guide.loc
         del gaussian_guide.scale
         del gaussian_guide.scale_tril
-        gaussian_guide.loc = loc
-        gaussian_guide.scale = scale
-        gaussian_guide.scale_tril = scale_tril
+        gaussian_guide.register_buffer("loc", loc)
+        gaussian_guide.register_buffer("scale", scale)
+        gaussian_guide.register_buffer("scale_tril", scale_tril)
         return gaussian_guide
 
 
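Patches 3 and 4 factor the inverse Hessian into a per-coordinate scale and the Cholesky factor of the corresponding correlation matrix, matching the new AutoMultivariateNormal coordinatization. A small numeric sketch, on a made-up covariance, of why recombining the two pieces recovers the original covariance:

# Sketch only: toy covariance, not code from this patch series.
import torch

cov = torch.tensor([[4.0, 1.0], [1.0, 9.0]])  # made-up stand-in for H.inverse()
scale = cov.diagonal().sqrt()                 # per-coordinate standard deviations
corr = cov / scale[:, None] / scale[None, :]  # correlation matrix, unit diagonal
scale_tril = torch.linalg.cholesky(corr)

# Recombining as AutoMultivariateNormal does: diag(scale) @ cholesky(corr)
# is a lower-triangular factor of the original covariance.
L = scale[..., None] * scale_tril
assert torch.allclose(L @ L.T, cov, atol=1e-6)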
From bca2755a17e3cd2cdc20c4d163486defc5edd768 Mon Sep 17 00:00:00 2001
From: Fritz Obermeyer
Date: Tue, 9 Nov 2021 19:42:51 -0500
Subject: [PATCH 5/6] Fix test

---
 tests/contrib/autoguide/test_inference.py | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/tests/contrib/autoguide/test_inference.py b/tests/contrib/autoguide/test_inference.py
index 1a722f6cf6..228c8443ad 100644
--- a/tests/contrib/autoguide/test_inference.py
+++ b/tests/contrib/autoguide/test_inference.py
@@ -49,7 +49,7 @@ def compute_target(self, N):
         ) * self.target_auto_diag_cov[n + 1]
 
     def test_multivariatate_normal_auto(self):
-        self.do_test_auto(3, reparameterized=True, n_steps=10001)
+        self.do_test_auto(3, reparameterized=True, n_steps=1001)
 
     def do_test_auto(self, N, reparameterized, n_steps):
         logger.debug("\nGoing to do AutoGaussianChain test...")
@@ -70,20 +70,21 @@ def do_test_auto(self, N, reparameterized, n_steps):
         )
 
         # TODO speed up with parallel num_particles > 1
-        adam = optim.Adam({"lr": 0.001, "betas": (0.95, 0.999)})
-        svi = SVI(self.model, self.guide, adam, loss=Trace_ELBO())
+        adam = optim.Adam({"lr": 0.01, "betas": (0.95, 0.999)})
+        elbo = Trace_ELBO(num_particles=100, vectorize_particles=True)
+        svi = SVI(self.model, self.guide, adam, elbo)
 
         for k in range(n_steps):
             loss = svi.step(reparameterized)
             assert np.isfinite(loss), loss
 
-            if k % 1000 == 0 and k > 0 or k == n_steps - 1:
+            if k % 100 == 0 and k > 0 or k == n_steps - 1:
                 logger.debug(
                     "[step {}] guide mean parameter: {}".format(
                         k, self.guide.loc.detach().cpu().numpy()
                     )
                 )
-                L = self.guide.scale_tril
+                L = self.guide.scale_tril * self.guide.scale[:, None]
                 diag_cov = torch.mm(L, L.t()).diag()
                 logger.debug(
                     "[step {}] auto_diag_cov: {}".format(

From 06baa1c03279c65a966d38174a653e988bf6f779 Mon Sep 17 00:00:00 2001
From: Fritz Obermeyer
Date: Tue, 9 Nov 2021 19:43:59 -0500
Subject: [PATCH 6/6] Move autoguide tests

---
 tests/{contrib => infer}/autoguide/test_inference.py          | 0
 tests/{contrib => infer}/autoguide/test_mean_field_entropy.py | 0
 2 files changed, 0 insertions(+), 0 deletions(-)
 rename tests/{contrib => infer}/autoguide/test_inference.py (100%)
 rename tests/{contrib => infer}/autoguide/test_mean_field_entropy.py (100%)

diff --git a/tests/contrib/autoguide/test_inference.py b/tests/infer/autoguide/test_inference.py
similarity index 100%
rename from tests/contrib/autoguide/test_inference.py
rename to tests/infer/autoguide/test_inference.py
diff --git a/tests/contrib/autoguide/test_mean_field_entropy.py b/tests/infer/autoguide/test_mean_field_entropy.py
similarity index 100%
rename from tests/contrib/autoguide/test_mean_field_entropy.py
rename to tests/infer/autoguide/test_mean_field_entropy.py
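For context, a rough end-to-end sketch of the code path these patches touch: fit an AutoLaplaceApproximation point estimate with SVI, then extract the Gaussian guide via laplace_approximation(). The toy model, data, and hyperparameters below are made up and are not part of the PR.

# Sketch only: toy model and settings, not code from this patch series.
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoLaplaceApproximation
from pyro.optim import Adam

def model(data):
    loc = pyro.sample("loc", dist.Normal(0.0, 10.0))
    with pyro.plate("data", len(data)):
        pyro.sample("obs", dist.Normal(loc, 1.0), obs=data)

pyro.clear_param_store()
data = torch.randn(20) + 3.0
guide = AutoLaplaceApproximation(model)
svi = SVI(model, guide, Adam({"lr": 0.01}), Trace_ELBO())
for _ in range(100):
    svi.step(data)

# laplace_approximation() returns an AutoMultivariateNormal whose loc, scale,
# and scale_tril buffers come from the Hessian at the MAP point, per the fixes above.
laplace_guide = guide.laplace_approximation(data)
posterior = laplace_guide.get_posterior(data)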