From 9a0d66268ee09a8c0373f07c17c11d38188883fb Mon Sep 17 00:00:00 2001 From: freddyaboulton Date: Wed, 28 Jul 2021 15:17:12 -0400 Subject: [PATCH 1/5] Modify get_transform for Laplace guide --- numpyro/infer/autoguide.py | 10 +++------- test/infer/test_autoguide.py | 8 +++++--- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/numpyro/infer/autoguide.py b/numpyro/infer/autoguide.py index 89bcf8349..91b810ede 100644 --- a/numpyro/infer/autoguide.py +++ b/numpyro/infer/autoguide.py @@ -8,7 +8,7 @@ import numpy as np -from jax import hessian, lax, random, tree_map +from jax import hessian, jacfwd, lax, random, tree_map from jax.experimental import stax from jax.flatten_util import ravel_pytree import jax.numpy as jnp @@ -886,12 +886,8 @@ def loss_fn(z): scale_tril = cholesky_of_inverse(precision) if not_jax_tracer(scale_tril): if np.any(np.isnan(scale_tril)): - warnings.warn( - "Hessian of log posterior at the MAP point is singular. Posterior" - " samples from AutoLaplaceApproxmiation will be constant (equal to" - " the MAP point)." - ) - scale_tril = jnp.where(jnp.isnan(scale_tril), 0.0, scale_tril) + jacobian = jacfwd(loss_fn)(loc) + scale_tril = jnp.outer(jacobian.T, jacobian) return LowerCholeskyAffine(loc, scale_tril) def get_posterior(self, params): diff --git a/test/infer/test_autoguide.py b/test/infer/test_autoguide.py index 5f162d5e6..9cd62964b 100644 --- a/test/infer/test_autoguide.py +++ b/test/infer/test_autoguide.py @@ -3,6 +3,7 @@ from functools import partial +import numpy as np from numpy.testing import assert_allclose import pytest @@ -335,7 +336,7 @@ def expected_model(data): assert_allclose(actual_loss, expected_loss) -def test_laplace_approximation_warning(): +def test_laplace_approximation_uses_gauss_newton_hessian(): def model(x, y): a = numpyro.sample("a", dist.Normal(0, 10)) b = numpyro.sample("b", dist.Normal(0, 10), sample_shape=(3,)) @@ -349,8 +350,9 @@ def model(x, y): init_state = svi.init(random.PRNGKey(0)) svi_state = fori_loop(0, 10000, lambda i, val: svi.update(val)[0], init_state) params = svi.get_params(svi_state) - with pytest.warns(UserWarning, match="Hessian of log posterior"): - guide.sample_posterior(random.PRNGKey(1), params) + samples = guide.sample_posterior(random.PRNGKey(1), params) + # NaNs would be returned if we tried to invert the hessian + assert not np.any([np.isnan(v).any() for v in samples.values()]) def test_improper(): From 338deaa10c2ece9ef7e6e9b24990027231453d44 Mon Sep 17 00:00:00 2001 From: freddyaboulton Date: Fri, 20 Aug 2021 11:08:44 -0400 Subject: [PATCH 2/5] Recommend AutoNormal guide --- numpyro/infer/autoguide.py | 8 ++++++-- test/infer/test_autoguide.py | 5 ++--- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/numpyro/infer/autoguide.py b/numpyro/infer/autoguide.py index 91b810ede..d1f305ea6 100644 --- a/numpyro/infer/autoguide.py +++ b/numpyro/infer/autoguide.py @@ -886,8 +886,12 @@ def loss_fn(z): scale_tril = cholesky_of_inverse(precision) if not_jax_tracer(scale_tril): if np.any(np.isnan(scale_tril)): - jacobian = jacfwd(loss_fn)(loc) - scale_tril = jnp.outer(jacobian.T, jacobian) + warnings.warn( + "Hessian of log posterior at the MAP point is singular. Posterior" + " samples from AutoLaplaceApproxmiation will be constant (equal to" + " the MAP point). Please consider using an AutoNormal guide." + ) + scale_tril = jnp.where(jnp.isnan(scale_tril), 0.0, scale_tril) return LowerCholeskyAffine(loc, scale_tril) def get_posterior(self, params): diff --git a/test/infer/test_autoguide.py b/test/infer/test_autoguide.py index 9cd62964b..3b7dca181 100644 --- a/test/infer/test_autoguide.py +++ b/test/infer/test_autoguide.py @@ -350,9 +350,8 @@ def model(x, y): init_state = svi.init(random.PRNGKey(0)) svi_state = fori_loop(0, 10000, lambda i, val: svi.update(val)[0], init_state) params = svi.get_params(svi_state) - samples = guide.sample_posterior(random.PRNGKey(1), params) - # NaNs would be returned if we tried to invert the hessian - assert not np.any([np.isnan(v).any() for v in samples.values()]) + with pytest.warns(UserWarning, match="Hessian of log posterior"): + guide.sample_posterior(random.PRNGKey(1), params) def test_improper(): From 2e6d9684ec35becc5e6a0c05b9ea9a347c0baef2 Mon Sep 17 00:00:00 2001 From: freddyaboulton Date: Fri, 20 Aug 2021 11:12:55 -0400 Subject: [PATCH 3/5] Revert changes to imports --- numpyro/infer/autoguide.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/numpyro/infer/autoguide.py b/numpyro/infer/autoguide.py index d1f305ea6..24156b9fc 100644 --- a/numpyro/infer/autoguide.py +++ b/numpyro/infer/autoguide.py @@ -8,7 +8,7 @@ import numpy as np -from jax import hessian, jacfwd, lax, random, tree_map +from jax import hessian, lax, random, tree_map from jax.experimental import stax from jax.flatten_util import ravel_pytree import jax.numpy as jnp @@ -891,7 +891,7 @@ def loss_fn(z): " samples from AutoLaplaceApproxmiation will be constant (equal to" " the MAP point). Please consider using an AutoNormal guide." ) - scale_tril = jnp.where(jnp.isnan(scale_tril), 0.0, scale_tril) + scale_tril = jnp.where(jnp.isnan(scale_tril), 0.0, scale_tril) return LowerCholeskyAffine(loc, scale_tril) def get_posterior(self, params): From 5ef71f96ed3bbecdf8aa4415effafa317a3d1f2a Mon Sep 17 00:00:00 2001 From: freddyaboulton Date: Fri, 20 Aug 2021 11:14:52 -0400 Subject: [PATCH 4/5] Lint --- test/infer/test_autoguide.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/infer/test_autoguide.py b/test/infer/test_autoguide.py index 3b7dca181..e34effea2 100644 --- a/test/infer/test_autoguide.py +++ b/test/infer/test_autoguide.py @@ -3,7 +3,6 @@ from functools import partial -import numpy as np from numpy.testing import assert_allclose import pytest From 0b1c7e22f2d7bf0c8bbb4fe42c5bec02eee2a284 Mon Sep 17 00:00:00 2001 From: freddyaboulton Date: Fri, 20 Aug 2021 11:31:57 -0400 Subject: [PATCH 5/5] Revert change to test name --- test/infer/test_autoguide.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/infer/test_autoguide.py b/test/infer/test_autoguide.py index e34effea2..5f162d5e6 100644 --- a/test/infer/test_autoguide.py +++ b/test/infer/test_autoguide.py @@ -335,7 +335,7 @@ def expected_model(data): assert_allclose(actual_loss, expected_loss) -def test_laplace_approximation_uses_gauss_newton_hessian(): +def test_laplace_approximation_warning(): def model(x, y): a = numpyro.sample("a", dist.Normal(0, 10)) b = numpyro.sample("b", dist.Normal(0, 10), sample_shape=(3,))