diff --git a/numpyro/infer/svi.py b/numpyro/infer/svi.py index f934b88f0..46887c9a9 100644 --- a/numpyro/infer/svi.py +++ b/numpyro/infer/svi.py @@ -21,7 +21,7 @@ from numpyro.distributions import constraints from numpyro.distributions.transforms import biject_to -from numpyro.handlers import replay, seed, trace +from numpyro.handlers import replay, seed, substitute, trace from numpyro.infer.util import helpful_support_errors, transform_fn from numpyro.optim import _NumPyroOptim, optax_to_numpyro @@ -184,9 +184,15 @@ def init(self, rng_key, *args, **kwargs): model_init = seed(self.model, model_seed) guide_init = seed(self.guide, guide_seed) guide_trace = trace(guide_init).get_trace(*args, **kwargs, **self.static_kwargs) - model_trace = trace(replay(model_init, guide_trace)).get_trace( - *args, **kwargs, **self.static_kwargs - ) + init_guide_params = { + name: site["value"] + for name, site in guide_trace.items() + if site["type"] == "param" + } + model_trace = trace( + substitute(replay(model_init, guide_trace), init_guide_params) + ).get_trace(*args, **kwargs, **self.static_kwargs) + params = {} inv_transforms = {} mutable_state = {} diff --git a/test/infer/test_svi.py b/test/infer/test_svi.py index 16d227722..2db18a3ca 100644 --- a/test/infer/test_svi.py +++ b/test/infer/test_svi.py @@ -190,6 +190,40 @@ def guide(): assert_allclose(actual_loss, expected_loss, rtol=1e-6) +def test_shared_param_init(): + shared_init = 1.0 + + def model(): + # should receive initial value from guide when used in SVI + shared = numpyro.param("shared") + assert_allclose(shared, shared_init) + + def guide(): + numpyro.param("shared", lambda _: shared_init) + + svi = SVI(model, guide, optim.Adam(0.01), Trace_ELBO()) + svi_state = svi.init(random.PRNGKey(0)) + params = svi.get_params(svi_state) + # make sure the correct init ended up in the SVI state + assert_allclose(params["shared"], shared_init) + + +def test_shared_param(): + target_value = 5.0 + + def model(): + shared = numpyro.param("shared") + # drive the shared parameter toward a target value + numpyro.factor("neg_loss", -((shared - target_value) ** 2)) + + def guide(): + numpyro.param("shared", 1.0) + + svi = SVI(model, guide, optim.Adam(0.01), Trace_ELBO()) + svi_result = svi.run(random.PRNGKey(0), 1000) + assert_allclose(svi_result.params["shared"], target_value, atol=0.1) + + def test_elbo_dynamic_support(): x_prior = dist.TransformedDistribution( dist.Normal(),