Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for VAE in AutoSemiDAIS #1619

Merged
merged 14 commits into from
Jul 10, 2023
161 changes: 112 additions & 49 deletions numpyro/infer/autoguide.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,11 @@
from numpyro.infer import Predictive
from numpyro.infer.elbo import Trace_ELBO
from numpyro.infer.initialization import init_to_median, init_to_uniform
from numpyro.infer.util import helpful_support_errors, initialize_model
from numpyro.infer.util import (
helpful_support_errors,
initialize_model,
potential_energy,
)
from numpyro.nn.auto_reg_nn import AutoregressiveNN
from numpyro.nn.block_neural_arn import BlockNeuralAutoregressiveNN
from numpyro.util import not_jax_tracer
Expand Down Expand Up @@ -1134,6 +1138,8 @@ def local_model(theta):
:param callable global_guide: A guide for the global latent variables, e.g. an autoguide.
The return type should be a dictionary of latent sample sites names and corresponding samples.
If there is no global variable in the model, we can set this to None.
:param callable local_guide: An optional guide for specifying the DAIS base distribution for
local latent variables.
:param str prefix: A prefix that will be prefixed to all internal sites.
:param int K: A positive integer that controls the number of HMC steps used.
Defaults to 4.
Expand All @@ -1152,6 +1158,7 @@ def __init__(
model,
local_model,
global_guide,
local_guide=None,
*,
prefix="auto",
K=4,
Expand All @@ -1177,6 +1184,7 @@ def __init__(

self.local_model = local_model
self.global_guide = global_guide
self.local_guide = local_guide
self.eta_init = eta_init
self.eta_max = eta_max
self.gamma_init = gamma_init
Expand All @@ -1186,6 +1194,7 @@ def __init__(
def _setup_prototype(self, *args, **kwargs):
super()._setup_prototype(*args, **kwargs)
# extract global/local/local_dim/plates
assert self.prototype_trace is not None
subsample_plates = {
name: site
for name, site in self.prototype_trace.items()
Expand Down Expand Up @@ -1225,9 +1234,10 @@ def _setup_prototype(self, *args, **kwargs):
for k, v in local_init_locs.items()
}
_, shape_dict = _ravel_dict(one_sample)
local_init_latent = jax.vmap(
self._pack_local_latent = jax.vmap(
lambda x: _ravel_dict(x)[0], in_axes=(subsample_axes,)
)(local_init_locs)
)
local_init_latent = self._pack_local_latent(local_init_locs)
unpack_latent = partial(_unravel_dict, shape_dict=shape_dict)
# this is to match the behavior of Pyro, where we can apply
# unpack_latent for a batch of samples
Expand All @@ -1246,23 +1256,14 @@ def _setup_prototype(self, *args, **kwargs):
local_args = args
local_kwargs = kwargs.copy()

with handlers.block():
local_kwargs["_subsample_idx"] = {
plate_name: subsample_plates[plate_name]["value"]
}
(
_,
self._local_potential_fn_gen,
self._local_postprecess_fn,
_,
) = initialize_model(
random.PRNGKey(0),
partial(_subsample_model, self.local_model),
init_strategy=self.init_loc_fn,
dynamic_args=True,
model_args=local_args,
model_kwargs=local_kwargs,
)
if self.local_guide is not None:
with handlers.block(), handlers.trace() as tr, handlers.seed(rng_seed=0):
self.local_guide(*local_args, **local_kwargs)
self.prototype_local_guide_trace = tr

with handlers.block(), handlers.trace() as tr, handlers.seed(rng_seed=0):
self.local_model(*local_args, **local_kwargs)
self.prototype_local_model_trace = tr

def __call__(self, *args, **kwargs):
if self.prototype_trace is None:
Expand Down Expand Up @@ -1305,16 +1306,6 @@ def _get_posterior(self):
def _sample_latent(self, *args, **kwargs):
kwargs.pop("sample_shape", ())

def make_local_log_density(*local_args, **local_kwargs):
def fn(x):
x_unpack = self._unpack_local_latent(x)
with numpyro.handlers.block():
return -self._local_potential_fn_gen(*local_args, **local_kwargs)(
x_unpack
)

return fn

if self.global_guide is not None:
global_latents = self.global_guide(*args, **kwargs)
rng_key = numpyro.prng_key()
Expand All @@ -1329,6 +1320,34 @@ def fn(x):
local_args = args
local_kwargs = kwargs.copy()

local_guide_params = {}
if self.local_guide is not None:
for name, site in self.prototype_local_guide_trace.items():
if site["type"] == "param":
local_guide_params[name] = numpyro.param(
name, site["value"], **site["kwargs"]
)

local_model_params = {}
for name, site in self.prototype_local_model_trace.items():
if site["type"] == "param":
local_model_params[name] = numpyro.param(
name, site["value"], **site["kwargs"]
)

def make_local_log_density(*local_args, **local_kwargs):
def fn(x):
x_unpack = self._unpack_local_latent(x)
with numpyro.handlers.block():
return -potential_energy(
partial(_subsample_model, self.local_model),
local_args,
local_kwargs,
{**x_unpack, **local_model_params},
)

return fn

plate_name, N, subsample_size = self._local_plate
D, K = self._local_latent_dim, self.K

Expand Down Expand Up @@ -1366,25 +1385,70 @@ def fn(x):
)
inv_mass_matrix = 0.5 / mass_matrix
assert inv_mass_matrix.shape == (subsample_size, D)
z_0_loc_init = jnp.zeros((N, D))
z_0_loc = numpyro.param(
"{}_z_0_loc".format(self.prefix), z_0_loc_init, event_dim=1
)
z_0_scale_init = jnp.ones((N, D)) * self.init_scale
z_0_scale = numpyro.param(
"{}_z_0_scale".format(self.prefix),
z_0_scale_init,
constraint=constraints.positive,
event_dim=1,
)
base_z_dist = dist.Normal(z_0_loc, z_0_scale).to_event(1)
assert base_z_dist.shape() == (subsample_size, D)
z_0 = numpyro.sample(
"{}_z_0".format(self.prefix), base_z_dist, infer={"is_auxiliary": True}
)

def base_z_dist_log_prob(x):
return base_z_dist.log_prob(x).sum()
local_kwargs["_subsample_idx"] = {plate_name: idx}
if self.local_guide is not None:
key = numpyro.prng_key()
subsample_guide = partial(_subsample_model, self.local_guide)
with handlers.block(), handlers.trace() as tr, handlers.seed(
rng_seed=key
), handlers.substitute(data=local_guide_params):
with warnings.catch_warnings():
warnings.simplefilter("ignore")
subsample_guide(*local_args, **local_kwargs)
latent = {
name: biject_to(site["fn"].support).inv(site["value"])
for name, site in tr.items()
if site["type"] == "sample"
and not site.get("is_observed", False)
}
z_0 = self._pack_local_latent(latent)

def base_z_dist_log_prob(z):
latent = self._unpack_local_latent(z)
assert isinstance(latent, dict)
with handlers.block():
with warnings.catch_warnings():
warnings.simplefilter("ignore")
scale = N / subsample_size
return (
-potential_energy(
subsample_guide,
local_args,
local_kwargs,
{**local_guide_params, **latent},
)
/ scale
)

# The log_prob of z_0 will be broadcasted to `subsample_size` because this statement
# is run under the subsample plate. Hence we divide the log_prob by `subsample_size`.
numpyro.factor(
"{}_z_0_factor".format(self.prefix),
base_z_dist_log_prob(z_0) / subsample_size,
)
else:
z_0_loc_init = jnp.zeros((N, D))
z_0_loc = numpyro.param(
"{}_z_0_loc".format(self.prefix), z_0_loc_init, event_dim=1
)
z_0_scale_init = jnp.ones((N, D)) * self.init_scale
z_0_scale = numpyro.param(
"{}_z_0_scale".format(self.prefix),
z_0_scale_init,
constraint=constraints.positive,
event_dim=1,
)
base_z_dist = dist.Normal(z_0_loc, z_0_scale).to_event(1)
assert base_z_dist.shape() == (subsample_size, D)
z_0 = numpyro.sample(
"{}_z_0".format(self.prefix),
base_z_dist,
infer={"is_auxiliary": True},
)

def base_z_dist_log_prob(x):
return base_z_dist.log_prob(x).sum()

momentum_dist = dist.Normal(0, mass_matrix).to_event(1)
eps = numpyro.sample(
Expand All @@ -1396,7 +1460,6 @@ def base_z_dist_log_prob(x):
infer={"is_auxiliary": True},
)

local_kwargs["_subsample_idx"] = {plate_name: idx}
local_log_density = make_local_log_density(*local_args, **local_kwargs)

def scan_body(carry, eps_beta):
Expand Down
2 changes: 2 additions & 0 deletions numpyro/infer/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,8 @@ def unconstrain_fn(model, model_args, model_kwargs, params):
def _unconstrain_reparam(params, site):
name = site["name"]
if name in params:
if site["type"] != "sample":
return params[name]
p = params[name]
support = site["fn"].support
with helpful_support_errors(site):
Expand Down
6 changes: 5 additions & 1 deletion test/infer/test_autoguide.py
Original file line number Diff line number Diff line change
Expand Up @@ -837,7 +837,11 @@ def model():
batch = numpyro.subsample(data, event_dim=0)
numpyro.sample("x", dist.Normal(batch, 1))

guide = AutoSemiDAIS(model, model, None)
def create_plates():
return numpyro.plate("N", 10, subsample_size=5, dim=-1)

local_guide = AutoNormal(model, create_plates=create_plates)
martinjankowiak marked this conversation as resolved.
Show resolved Hide resolved
guide = AutoSemiDAIS(model, model, None, local_guide=local_guide)
svi = SVI(model, guide, optim.Adam(0.01), Trace_ELBO())
svi_result = svi.run(random.PRNGKey(0), 10)
samples = guide.sample_posterior(
Expand Down