diff --git a/numpyro/handlers.py b/numpyro/handlers.py index b03b6c4f5..d97c96f95 100644 --- a/numpyro/handlers.py +++ b/numpyro/handlers.py @@ -447,6 +447,7 @@ def process_message(self, msg): msg["kwargs"] = {"rng_key": msg["kwargs"].get("rng_key", None), "sample_shape": msg["kwargs"].get("sample_shape", ())} msg["intermediates"] = [] + msg["infer"] = msg.get("infer", {}) else: # otherwise leave as is return diff --git a/test/test_mcmc.py b/test/test_mcmc.py index 1a667337c..ea955b705 100644 --- a/test/test_mcmc.py +++ b/test/test_mcmc.py @@ -681,3 +681,14 @@ def model(): # this fails in reverse mode mcmc = MCMC(NUTS(model, forward_mode_differentiation=True), 10, 10) mcmc.run(random.PRNGKey(0)) + + +def test_model_with_lift_handler(): + def model(data): + c = numpyro.param("c", jnp.array(1.), constraint=dist.constraints.positive) + x = numpyro.sample("x", dist.LogNormal(c, 1.), obs=data) + return x + + nuts_kernel = NUTS(numpyro.handlers.lift(model, prior={"c": dist.Gamma(0.01, 0.01)})) + mcmc = MCMC(nuts_kernel, num_warmup=10, num_samples=10) + mcmc.run(random.PRNGKey(1), jnp.exp(random.normal(random.PRNGKey(0), (1000,))))