From 428dee934612ed263f736637ad928207e60862d5 Mon Sep 17 00:00:00 2001 From: Du Phan Date: Mon, 10 Jul 2023 11:36:30 -0400 Subject: [PATCH] Support for VAE in AutoSemiDAIS (#1619) * support model without global variables in AutoSemiDAIS * add test for autosemidais local only * black * fix typo in docs * support for vae in semidais * fix bug using wrong sign of potential energy * no need to store prototype local model trace * add docs for local_guide in semidais * allow params in local model * fix wrong scale at z0 * add comment for why we divide by subsample_size at z_0 log prob * address comment --- numpyro/infer/autoguide.py | 161 ++++++++++++++++++++++++----------- numpyro/infer/util.py | 2 + test/infer/test_autoguide.py | 6 +- 3 files changed, 119 insertions(+), 50 deletions(-) diff --git a/numpyro/infer/autoguide.py b/numpyro/infer/autoguide.py index 7991f5417..2d3f0b796 100644 --- a/numpyro/infer/autoguide.py +++ b/numpyro/infer/autoguide.py @@ -47,7 +47,11 @@ from numpyro.infer import Predictive from numpyro.infer.elbo import Trace_ELBO from numpyro.infer.initialization import init_to_median, init_to_uniform -from numpyro.infer.util import helpful_support_errors, initialize_model +from numpyro.infer.util import ( + helpful_support_errors, + initialize_model, + potential_energy, +) from numpyro.nn.auto_reg_nn import AutoregressiveNN from numpyro.nn.block_neural_arn import BlockNeuralAutoregressiveNN from numpyro.util import not_jax_tracer @@ -1134,6 +1138,8 @@ def local_model(theta): :param callable global_guide: A guide for the global latent variables, e.g. an autoguide. The return type should be a dictionary of latent sample sites names and corresponding samples. If there is no global variable in the model, we can set this to None. + :param callable local_guide: An optional guide for specifying the DAIS base distribution for + local latent variables. :param str prefix: A prefix that will be prefixed to all internal sites. :param int K: A positive integer that controls the number of HMC steps used. Defaults to 4. @@ -1152,6 +1158,7 @@ def __init__( model, local_model, global_guide, + local_guide=None, *, prefix="auto", K=4, @@ -1177,6 +1184,7 @@ def __init__( self.local_model = local_model self.global_guide = global_guide + self.local_guide = local_guide self.eta_init = eta_init self.eta_max = eta_max self.gamma_init = gamma_init @@ -1186,6 +1194,7 @@ def __init__( def _setup_prototype(self, *args, **kwargs): super()._setup_prototype(*args, **kwargs) # extract global/local/local_dim/plates + assert self.prototype_trace is not None subsample_plates = { name: site for name, site in self.prototype_trace.items() @@ -1225,9 +1234,10 @@ def _setup_prototype(self, *args, **kwargs): for k, v in local_init_locs.items() } _, shape_dict = _ravel_dict(one_sample) - local_init_latent = jax.vmap( + self._pack_local_latent = jax.vmap( lambda x: _ravel_dict(x)[0], in_axes=(subsample_axes,) - )(local_init_locs) + ) + local_init_latent = self._pack_local_latent(local_init_locs) unpack_latent = partial(_unravel_dict, shape_dict=shape_dict) # this is to match the behavior of Pyro, where we can apply # unpack_latent for a batch of samples @@ -1246,23 +1256,14 @@ def _setup_prototype(self, *args, **kwargs): local_args = args local_kwargs = kwargs.copy() - with handlers.block(): - local_kwargs["_subsample_idx"] = { - plate_name: subsample_plates[plate_name]["value"] - } - ( - _, - self._local_potential_fn_gen, - self._local_postprecess_fn, - _, - ) = initialize_model( - random.PRNGKey(0), - partial(_subsample_model, self.local_model), - init_strategy=self.init_loc_fn, - dynamic_args=True, - model_args=local_args, - model_kwargs=local_kwargs, - ) + if self.local_guide is not None: + with handlers.block(), handlers.trace() as tr, handlers.seed(rng_seed=0): + self.local_guide(*local_args, **local_kwargs) + self.prototype_local_guide_trace = tr + + with handlers.block(), handlers.trace() as tr, handlers.seed(rng_seed=0): + self.local_model(*local_args, **local_kwargs) + self.prototype_local_model_trace = tr def __call__(self, *args, **kwargs): if self.prototype_trace is None: @@ -1305,16 +1306,6 @@ def _get_posterior(self): def _sample_latent(self, *args, **kwargs): kwargs.pop("sample_shape", ()) - def make_local_log_density(*local_args, **local_kwargs): - def fn(x): - x_unpack = self._unpack_local_latent(x) - with numpyro.handlers.block(): - return -self._local_potential_fn_gen(*local_args, **local_kwargs)( - x_unpack - ) - - return fn - if self.global_guide is not None: global_latents = self.global_guide(*args, **kwargs) rng_key = numpyro.prng_key() @@ -1329,6 +1320,34 @@ def fn(x): local_args = args local_kwargs = kwargs.copy() + local_guide_params = {} + if self.local_guide is not None: + for name, site in self.prototype_local_guide_trace.items(): + if site["type"] == "param": + local_guide_params[name] = numpyro.param( + name, site["value"], **site["kwargs"] + ) + + local_model_params = {} + for name, site in self.prototype_local_model_trace.items(): + if site["type"] == "param": + local_model_params[name] = numpyro.param( + name, site["value"], **site["kwargs"] + ) + + def make_local_log_density(*local_args, **local_kwargs): + def fn(x): + x_unpack = self._unpack_local_latent(x) + with numpyro.handlers.block(): + return -potential_energy( + partial(_subsample_model, self.local_model), + local_args, + local_kwargs, + {**x_unpack, **local_model_params}, + ) + + return fn + plate_name, N, subsample_size = self._local_plate D, K = self._local_latent_dim, self.K @@ -1366,25 +1385,70 @@ def fn(x): ) inv_mass_matrix = 0.5 / mass_matrix assert inv_mass_matrix.shape == (subsample_size, D) - z_0_loc_init = jnp.zeros((N, D)) - z_0_loc = numpyro.param( - "{}_z_0_loc".format(self.prefix), z_0_loc_init, event_dim=1 - ) - z_0_scale_init = jnp.ones((N, D)) * self.init_scale - z_0_scale = numpyro.param( - "{}_z_0_scale".format(self.prefix), - z_0_scale_init, - constraint=constraints.positive, - event_dim=1, - ) - base_z_dist = dist.Normal(z_0_loc, z_0_scale).to_event(1) - assert base_z_dist.shape() == (subsample_size, D) - z_0 = numpyro.sample( - "{}_z_0".format(self.prefix), base_z_dist, infer={"is_auxiliary": True} - ) - def base_z_dist_log_prob(x): - return base_z_dist.log_prob(x).sum() + local_kwargs["_subsample_idx"] = {plate_name: idx} + if self.local_guide is not None: + key = numpyro.prng_key() + subsample_guide = partial(_subsample_model, self.local_guide) + with handlers.block(), handlers.trace() as tr, handlers.seed( + rng_seed=key + ), handlers.substitute(data=local_guide_params): + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + subsample_guide(*local_args, **local_kwargs) + latent = { + name: biject_to(site["fn"].support).inv(site["value"]) + for name, site in tr.items() + if site["type"] == "sample" + and not site.get("is_observed", False) + } + z_0 = self._pack_local_latent(latent) + + def base_z_dist_log_prob(z): + latent = self._unpack_local_latent(z) + assert isinstance(latent, dict) + with handlers.block(): + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + scale = N / subsample_size + return ( + -potential_energy( + subsample_guide, + local_args, + local_kwargs, + {**local_guide_params, **latent}, + ) + / scale + ) + + # The log_prob of z_0 will be broadcasted to `subsample_size` because this statement + # is run under the subsample plate. Hence we divide the log_prob by `subsample_size`. + numpyro.factor( + "{}_z_0_factor".format(self.prefix), + base_z_dist_log_prob(z_0) / subsample_size, + ) + else: + z_0_loc_init = jnp.zeros((N, D)) + z_0_loc = numpyro.param( + "{}_z_0_loc".format(self.prefix), z_0_loc_init, event_dim=1 + ) + z_0_scale_init = jnp.ones((N, D)) * self.init_scale + z_0_scale = numpyro.param( + "{}_z_0_scale".format(self.prefix), + z_0_scale_init, + constraint=constraints.positive, + event_dim=1, + ) + base_z_dist = dist.Normal(z_0_loc, z_0_scale).to_event(1) + assert base_z_dist.shape() == (subsample_size, D) + z_0 = numpyro.sample( + "{}_z_0".format(self.prefix), + base_z_dist, + infer={"is_auxiliary": True}, + ) + + def base_z_dist_log_prob(x): + return base_z_dist.log_prob(x).sum() momentum_dist = dist.Normal(0, mass_matrix).to_event(1) eps = numpyro.sample( @@ -1396,7 +1460,6 @@ def base_z_dist_log_prob(x): infer={"is_auxiliary": True}, ) - local_kwargs["_subsample_idx"] = {plate_name: idx} local_log_density = make_local_log_density(*local_args, **local_kwargs) def scan_body(carry, eps_beta): diff --git a/numpyro/infer/util.py b/numpyro/infer/util.py index 4a343510b..92c19034b 100644 --- a/numpyro/infer/util.py +++ b/numpyro/infer/util.py @@ -236,6 +236,8 @@ def unconstrain_fn(model, model_args, model_kwargs, params): def _unconstrain_reparam(params, site): name = site["name"] if name in params: + if site["type"] != "sample": + return params[name] p = params[name] support = site["fn"].support with helpful_support_errors(site): diff --git a/test/infer/test_autoguide.py b/test/infer/test_autoguide.py index 7f46f4d87..51374fedb 100644 --- a/test/infer/test_autoguide.py +++ b/test/infer/test_autoguide.py @@ -837,7 +837,11 @@ def model(): batch = numpyro.subsample(data, event_dim=0) numpyro.sample("x", dist.Normal(batch, 1)) - guide = AutoSemiDAIS(model, model, None) + def create_plates(): + return numpyro.plate("N", 10, subsample_size=5, dim=-1) + + local_guide = AutoNormal(model, create_plates=create_plates) + guide = AutoSemiDAIS(model, model, None, local_guide=local_guide) svi = SVI(model, guide, optim.Adam(0.01), Trace_ELBO()) svi_result = svi.run(random.PRNGKey(0), 10) samples = guide.sample_posterior(