Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support model without global variables in AutoSemiDAIS #1610

Merged
merged 4 commits into from
Jun 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 36 additions & 24 deletions numpyro/infer/autoguide.py
Original file line number Diff line number Diff line change
Expand Up @@ -1119,8 +1119,8 @@ def local_model(theta):
numpyro.sample("obs", dist.Normal(0.0, tau), obs=jnp.ones(2))

model = lambda: local_model(global_model())
base_guide = AutoNormal(global_model)
guide = AutoSemiDAIS(model, local_model, base_guide, K=4)
global_guide = AutoNormal(global_model)
guide = AutoSemiDAIS(model, local_model, global_guide, K=4)
svi = SVI(model, guide, ...)

# sample posterior for particular data subset {3, 7}
Expand All @@ -1131,8 +1131,9 @@ def local_model(theta):
:param callable local_model: The portion of `model` that includes the local latent variables only.
The signature of `local_model` should be the return type of the global model with global latent
variables only.
:param callable base_guide: A guide for the global latent variables, e.g. an autoguide.
:param callable global_guide: A guide for the global latent variables, e.g. an autoguide.
The return type should be a dictionary of latent sample sites names and corresponding samples.
If there is no global variable in the model, we can set this to None.
:param str prefix: A prefix that will be prefixed to all internal sites.
:param int K: A positive integer that controls the number of HMC steps used.
Defaults to 4.
Expand All @@ -1150,7 +1151,7 @@ def __init__(
self,
model,
local_model,
base_guide,
global_guide,
*,
prefix="auto",
K=4,
Expand All @@ -1175,7 +1176,7 @@ def __init__(
raise ValueError("init_scale must be positive.")

self.local_model = local_model
self.base_guide = base_guide
self.global_guide = global_guide
self.eta_init = eta_init
self.eta_max = eta_max
self.gamma_init = gamma_init
Expand Down Expand Up @@ -1237,25 +1238,30 @@ def _setup_prototype(self, *args, **kwargs):
self._local_latent_dim = jnp.size(local_init_latent) // plate_subsample_size
self._local_plate = (plate_name, plate_full_size, plate_subsample_size)

rng_key = numpyro.prng_key()
with handlers.block(), handlers.seed(rng_seed=rng_key):
global_output = self.base_guide.model(*args, **kwargs)
if self.global_guide is not None:
with handlers.block(), handlers.seed(rng_seed=0):
local_args = (self.global_guide.model(*args, **kwargs),)
local_kwargs = {}
else:
local_args = args
local_kwargs = kwargs.copy()

with handlers.block():
local_kwargs["_subsample_idx"] = {
plate_name: subsample_plates[plate_name]["value"]
}
(
_,
self._local_potential_fn_gen,
self._local_postprecess_fn,
_,
) = initialize_model(
numpyro.prng_key(),
random.PRNGKey(0),
partial(_subsample_model, self.local_model),
init_strategy=self.init_loc_fn,
dynamic_args=True,
model_args=(global_output,),
model_kwargs={
"_subsample_idx": {
plate_name: subsample_plates[plate_name]["value"]
}
},
model_args=local_args,
model_kwargs=local_kwargs,
)

def __call__(self, *args, **kwargs):
Expand Down Expand Up @@ -1309,12 +1315,19 @@ def fn(x):

return fn

global_latents = self.base_guide(*args, **kwargs)
rng_key = numpyro.prng_key()
with handlers.block(), handlers.seed(rng_seed=rng_key), handlers.substitute(
data=global_latents
):
global_output = self.base_guide.model(*args, **kwargs)
if self.global_guide is not None:
global_latents = self.global_guide(*args, **kwargs)
rng_key = numpyro.prng_key()
with handlers.block(), handlers.seed(rng_seed=rng_key), handlers.substitute(
data=global_latents
):
global_outputs = self.global_guide.model(*args, **kwargs)
local_args = (global_ouputs,)
local_kwargs = {}
else:
global_latents = {}
local_args = args
local_kwargs = kwargs.copy()

plate_name, N, subsample_size = self._local_plate
D, K = self._local_latent_dim, self.K
Expand Down Expand Up @@ -1383,9 +1396,8 @@ def base_z_dist_log_prob(x):
infer={"is_auxiliary": True},
)

local_log_density = make_local_log_density(
global_output, _subsample_idx={plate_name: idx}
)
local_kwargs["_subsample_idx"] = {plate_name: idx}
local_log_density = make_local_log_density(*local_args, **local_kwargs)

def scan_body(carry, eps_beta):
eps, beta = eps_beta
Expand Down
17 changes: 17 additions & 0 deletions test/infer/test_autoguide.py
Original file line number Diff line number Diff line change
Expand Up @@ -829,6 +829,23 @@ def model():
assert samples["sigma"].shape == (5,) and samples["log_sigma"].shape == (5, 2)


def test_autosemidais_local_only():
data = jnp.linspace(0, 1, 10)

def model():
with numpyro.plate("N", 10, subsample_size=5, dim=-1):
batch = numpyro.subsample(data, event_dim=0)
numpyro.sample("x", dist.Normal(batch, 1))

guide = AutoSemiDAIS(model, model, None)
svi = SVI(model, guide, optim.Adam(0.01), Trace_ELBO())
svi_result = svi.run(random.PRNGKey(0), 10)
samples = guide.sample_posterior(
random.PRNGKey(1), svi_result.params, sample_shape=(100,)
)
assert samples["x"].shape == (100, 5)


def test_autosemidais_inadmissible_smoke():
def global_model():
return numpyro.sample("theta", dist.Normal(0, 1))
Expand Down