From d018a19220cd3f699fbc8d68233a5e937ebad7bb Mon Sep 17 00:00:00 2001 From: Dominik Straub Date: Mon, 25 Jan 2021 15:27:59 +0100 Subject: [PATCH 1/6] Fix missing infer key --- numpyro/handlers.py | 1 + 1 file changed, 1 insertion(+) diff --git a/numpyro/handlers.py b/numpyro/handlers.py index b03b6c4f5..d97c96f95 100644 --- a/numpyro/handlers.py +++ b/numpyro/handlers.py @@ -447,6 +447,7 @@ def process_message(self, msg): msg["kwargs"] = {"rng_key": msg["kwargs"].get("rng_key", None), "sample_shape": msg["kwargs"].get("sample_shape", ())} msg["intermediates"] = [] + msg["infer"] = msg.get("infer", {}) else: # otherwise leave as is return From 66ba33aecb6ffae2b6768b9988546c6abb045c43 Mon Sep 17 00:00:00 2001 From: Dominik Straub Date: Mon, 25 Jan 2021 17:20:16 +0100 Subject: [PATCH 2/6] Add test for lifted model --- test/test_mcmc.py | 26 +++++++++++++++++++------- 1 file changed, 19 insertions(+), 7 deletions(-) diff --git a/test/test_mcmc.py b/test/test_mcmc.py index 1a667337c..201bd5d83 100644 --- a/test/test_mcmc.py +++ b/test/test_mcmc.py @@ -69,7 +69,7 @@ def potential_fn(z): mcmc.run(random.PRNGKey(0), init_params=init_params) samples = mcmc.get_samples() assert_allclose(jnp.mean(samples), true_mean, atol=0.02) - assert np.sum(np.abs(np.cov(samples.T) - true_cov)) / D**2 < 0.02 + assert np.sum(np.abs(np.cov(samples.T) - true_cov)) / D ** 2 < 0.02 @pytest.mark.parametrize('kernel_cls', [HMC, NUTS, SA]) @@ -213,12 +213,12 @@ def model(data): numpyro.sample('obs', dist.Poisson(lambda12), obs=data) count_data = jnp.array([ - 13, 24, 8, 24, 7, 35, 14, 11, 15, 11, 22, 22, 11, 57, - 11, 19, 29, 6, 19, 12, 22, 12, 18, 72, 32, 9, 7, 13, - 19, 23, 27, 20, 6, 17, 13, 10, 14, 6, 16, 15, 7, 2, - 15, 15, 19, 70, 49, 7, 53, 22, 21, 31, 19, 11, 18, 20, - 12, 35, 17, 23, 17, 4, 2, 31, 30, 13, 27, 0, 39, 37, - 5, 14, 13, 22, + 13, 24, 8, 24, 7, 35, 14, 11, 15, 11, 22, 22, 11, 57, + 11, 19, 29, 6, 19, 12, 22, 12, 18, 72, 32, 9, 7, 13, + 19, 23, 27, 20, 6, 17, 13, 10, 14, 6, 16, 15, 7, 2, + 15, 15, 19, 70, 49, 7, 53, 22, 21, 31, 19, 11, 18, 20, + 12, 35, 17, 23, 17, 4, 2, 31, 30, 13, 27, 0, 39, 37, + 5, 14, 13, 22, ]) kernel = NUTS(model=model) mcmc = MCMC(kernel, warmup_steps, num_samples) @@ -412,6 +412,7 @@ def test_chain_inside_jit(kernel_cls, chain_method): step_size = 1. target_accept_prob = 0.8 trajectory_length = 1. + # Not supported yet: # + adapt_step_size # + adapt_mass_matrix @@ -681,3 +682,14 @@ def model(): # this fails in reverse mode mcmc = MCMC(NUTS(model, forward_mode_differentiation=True), 10, 10) mcmc.run(random.PRNGKey(0)) + + +def test_model_with_lift_handler(): + def model(data): + c = numpyro.param("c", jnp.array(1.), constraint=dist.constraints.positive) + x = numpyro.sample("x", dist.LogNormal(c, 1.), obs=data) + return x + + nuts_kernel = NUTS(numpyro.handlers.lift(model, prior={"c": dist.Gamma(0.01, 0.01)})) + mcmc = MCMC(nuts_kernel, num_warmup=10, num_samples=10) + mcmc.run(random.PRNGKey(1), jnp.exp(random.normal(random.PRNGKey(0), (1000,)))) From 8b44e33f08ee720cc571ed95ff0f1ca62fe7c17d Mon Sep 17 00:00:00 2001 From: Dominik Straub Date: Mon, 25 Jan 2021 17:31:55 +0100 Subject: [PATCH 3/6] Add test for lifted model --- test/test_mcmc.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/test/test_mcmc.py b/test/test_mcmc.py index 1a667337c..02c5d8806 100644 --- a/test/test_mcmc.py +++ b/test/test_mcmc.py @@ -681,3 +681,14 @@ def model(): # this fails in reverse mode mcmc = MCMC(NUTS(model, forward_mode_differentiation=True), 10, 10) mcmc.run(random.PRNGKey(0)) + + +def test_model_with_lift_handler(): + def model(data): + c = numpyro.param("c", jnp.array(1.), constraint=dist.constraints.positive) + x = numpyro.sample("x", dist.LogNormal(c, 1.), obs=data) + return x + + nuts_kernel = NUTS(numpyro.handlers.lift(model, prior={"c": dist.Gamma(0.01, 0.01)})) + mcmc = MCMC(nuts_kernel, num_warmup=10, num_samples=10) + mcmc.run(random.PRNGKey(1), jnp.exp(random.normal(random.PRNGKey(0), (1000,)))) \ No newline at end of file From d114628328831eb2531249623f8e3d83908e1204 Mon Sep 17 00:00:00 2001 From: Dominik Straub Date: Mon, 25 Jan 2021 17:38:05 +0100 Subject: [PATCH 4/6] Add test for lifted model --- test/test_mcmc.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/test_mcmc.py b/test/test_mcmc.py index 02c5d8806..144fcd29b 100644 --- a/test/test_mcmc.py +++ b/test/test_mcmc.py @@ -691,4 +691,5 @@ def model(data): nuts_kernel = NUTS(numpyro.handlers.lift(model, prior={"c": dist.Gamma(0.01, 0.01)})) mcmc = MCMC(nuts_kernel, num_warmup=10, num_samples=10) - mcmc.run(random.PRNGKey(1), jnp.exp(random.normal(random.PRNGKey(0), (1000,)))) \ No newline at end of file + mcmc.run(random.PRNGKey(1), jnp.exp(random.normal(random.PRNGKey(0), (1000,)))) + \ No newline at end of file From a5363b47c214c654a03777fdbd9fbf13a698291e Mon Sep 17 00:00:00 2001 From: Dominik Straub Date: Mon, 25 Jan 2021 17:51:35 +0100 Subject: [PATCH 5/6] Add test for lifted model --- test/test_mcmc.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/test_mcmc.py b/test/test_mcmc.py index 144fcd29b..ea955b705 100644 --- a/test/test_mcmc.py +++ b/test/test_mcmc.py @@ -692,4 +692,3 @@ def model(data): nuts_kernel = NUTS(numpyro.handlers.lift(model, prior={"c": dist.Gamma(0.01, 0.01)})) mcmc = MCMC(nuts_kernel, num_warmup=10, num_samples=10) mcmc.run(random.PRNGKey(1), jnp.exp(random.normal(random.PRNGKey(0), (1000,)))) - \ No newline at end of file From 944f2019d0bc8f614db6a9693d79d0eae297b201 Mon Sep 17 00:00:00 2001 From: Dominik Straub Date: Mon, 25 Jan 2021 17:56:53 +0100 Subject: [PATCH 6/6] Fix commits --- test/test_mcmc.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/test/test_mcmc.py b/test/test_mcmc.py index 1a667337c..ea955b705 100644 --- a/test/test_mcmc.py +++ b/test/test_mcmc.py @@ -681,3 +681,14 @@ def model(): # this fails in reverse mode mcmc = MCMC(NUTS(model, forward_mode_differentiation=True), 10, 10) mcmc.run(random.PRNGKey(0)) + + +def test_model_with_lift_handler(): + def model(data): + c = numpyro.param("c", jnp.array(1.), constraint=dist.constraints.positive) + x = numpyro.sample("x", dist.LogNormal(c, 1.), obs=data) + return x + + nuts_kernel = NUTS(numpyro.handlers.lift(model, prior={"c": dist.Gamma(0.01, 0.01)})) + mcmc = MCMC(nuts_kernel, num_warmup=10, num_samples=10) + mcmc.run(random.PRNGKey(1), jnp.exp(random.normal(random.PRNGKey(0), (1000,))))