From 0e50bacfb3b8026861e1f49cc1034d03fe8c6283 Mon Sep 17 00:00:00 2001 From: Du Phan Date: Mon, 19 Jun 2023 22:20:43 +0700 Subject: [PATCH] Support model without global variables in AutoSemiDAIS (#1610) * support model without global variables in AutoSemiDAIS * add test for autosemidais local only * black * fix typo in docs --- numpyro/infer/autoguide.py | 60 +++++++++++++++++++++--------------- test/infer/test_autoguide.py | 17 ++++++++++ 2 files changed, 53 insertions(+), 24 deletions(-) diff --git a/numpyro/infer/autoguide.py b/numpyro/infer/autoguide.py index c0e218575..b8eec09ae 100644 --- a/numpyro/infer/autoguide.py +++ b/numpyro/infer/autoguide.py @@ -1119,8 +1119,8 @@ def local_model(theta): numpyro.sample("obs", dist.Normal(0.0, tau), obs=jnp.ones(2)) model = lambda: local_model(global_model()) - base_guide = AutoNormal(global_model) - guide = AutoSemiDAIS(model, local_model, base_guide, K=4) + global_guide = AutoNormal(global_model) + guide = AutoSemiDAIS(model, local_model, global_guide, K=4) svi = SVI(model, guide, ...) # sample posterior for particular data subset {3, 7} @@ -1131,8 +1131,9 @@ def local_model(theta): :param callable local_model: The portion of `model` that includes the local latent variables only. The signature of `local_model` should be the return type of the global model with global latent variables only. - :param callable base_guide: A guide for the global latent variables, e.g. an autoguide. + :param callable global_guide: A guide for the global latent variables, e.g. an autoguide. The return type should be a dictionary of latent sample sites names and corresponding samples. + If there is no global variable in the model, we can set this to None. :param str prefix: A prefix that will be prefixed to all internal sites. :param int K: A positive integer that controls the number of HMC steps used. Defaults to 4. @@ -1150,7 +1151,7 @@ def __init__( self, model, local_model, - base_guide, + global_guide, *, prefix="auto", K=4, @@ -1175,7 +1176,7 @@ def __init__( raise ValueError("init_scale must be positive.") self.local_model = local_model - self.base_guide = base_guide + self.global_guide = global_guide self.eta_init = eta_init self.eta_max = eta_max self.gamma_init = gamma_init @@ -1237,25 +1238,30 @@ def _setup_prototype(self, *args, **kwargs): self._local_latent_dim = jnp.size(local_init_latent) // plate_subsample_size self._local_plate = (plate_name, plate_full_size, plate_subsample_size) - rng_key = numpyro.prng_key() - with handlers.block(), handlers.seed(rng_seed=rng_key): - global_output = self.base_guide.model(*args, **kwargs) + if self.global_guide is not None: + with handlers.block(), handlers.seed(rng_seed=0): + local_args = (self.global_guide.model(*args, **kwargs),) + local_kwargs = {} + else: + local_args = args + local_kwargs = kwargs.copy() + + with handlers.block(): + local_kwargs["_subsample_idx"] = { + plate_name: subsample_plates[plate_name]["value"] + } ( _, self._local_potential_fn_gen, self._local_postprecess_fn, _, ) = initialize_model( - numpyro.prng_key(), + random.PRNGKey(0), partial(_subsample_model, self.local_model), init_strategy=self.init_loc_fn, dynamic_args=True, - model_args=(global_output,), - model_kwargs={ - "_subsample_idx": { - plate_name: subsample_plates[plate_name]["value"] - } - }, + model_args=local_args, + model_kwargs=local_kwargs, ) def __call__(self, *args, **kwargs): @@ -1309,12 +1315,19 @@ def fn(x): return fn - global_latents = self.base_guide(*args, **kwargs) - rng_key = numpyro.prng_key() - with handlers.block(), handlers.seed(rng_seed=rng_key), handlers.substitute( - data=global_latents - ): - global_output = self.base_guide.model(*args, **kwargs) + if self.global_guide is not None: + global_latents = self.global_guide(*args, **kwargs) + rng_key = numpyro.prng_key() + with handlers.block(), handlers.seed(rng_seed=rng_key), handlers.substitute( + data=global_latents + ): + global_outputs = self.global_guide.model(*args, **kwargs) + local_args = (global_ouputs,) + local_kwargs = {} + else: + global_latents = {} + local_args = args + local_kwargs = kwargs.copy() plate_name, N, subsample_size = self._local_plate D, K = self._local_latent_dim, self.K @@ -1383,9 +1396,8 @@ def base_z_dist_log_prob(x): infer={"is_auxiliary": True}, ) - local_log_density = make_local_log_density( - global_output, _subsample_idx={plate_name: idx} - ) + local_kwargs["_subsample_idx"] = {plate_name: idx} + local_log_density = make_local_log_density(*local_args, **local_kwargs) def scan_body(carry, eps_beta): eps, beta = eps_beta diff --git a/test/infer/test_autoguide.py b/test/infer/test_autoguide.py index a901ea07a..7f46f4d87 100644 --- a/test/infer/test_autoguide.py +++ b/test/infer/test_autoguide.py @@ -829,6 +829,23 @@ def model(): assert samples["sigma"].shape == (5,) and samples["log_sigma"].shape == (5, 2) +def test_autosemidais_local_only(): + data = jnp.linspace(0, 1, 10) + + def model(): + with numpyro.plate("N", 10, subsample_size=5, dim=-1): + batch = numpyro.subsample(data, event_dim=0) + numpyro.sample("x", dist.Normal(batch, 1)) + + guide = AutoSemiDAIS(model, model, None) + svi = SVI(model, guide, optim.Adam(0.01), Trace_ELBO()) + svi_result = svi.run(random.PRNGKey(0), 10) + samples = guide.sample_posterior( + random.PRNGKey(1), svi_result.params, sample_shape=(100,) + ) + assert samples["x"].shape == (100, 5) + + def test_autosemidais_inadmissible_smoke(): def global_model(): return numpyro.sample("theta", dist.Normal(0, 1))