Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Uses init_loc_fn to initialize mixture particles #1612

Merged
merged 8 commits on Jul 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 62 additions & 51 deletions numpyro/contrib/einstein/steinvi.py
Original file line number Diff line number Diff line change
@@ -2,6 +2,7 @@
# SPDX-License-Identifier: Apache-2.0

from collections import namedtuple
from copy import deepcopy
import functools
from functools import partial
from itertools import chain
@@ -20,8 +21,7 @@
get_parameter_transform,
)
from numpyro.contrib.funsor import config_enumerate, enum
from numpyro.distributions import Distribution, Normal
from numpyro.distributions.constraints import real
from numpyro.distributions import Distribution
from numpyro.distributions.transforms import IdentityTransform
from numpyro.infer.autoguide import AutoGuide
from numpyro.infer.util import _guess_max_plate_nesting, transform_fn
@@ -102,6 +102,38 @@ def __init__(
enum=True,
**static_kwargs,
):
if isinstance(guide, AutoGuide):
not_comptaible_guides = [
"AutoIAFNormal",
"AutoBNAFNormal",
"AutoDAIS",
"AutoSemiDAIS",
"AutoSurrogateLikelihoodDAIS",
]
guide_name = guide.__class__.__name__
assert guide_name not in not_comptaible_guides, (
f"SteinVI currently not compatible with {guide_name}. "
f"If you have a use case, feel free to open an issue."
)

init_loc_error_message = (
"SteinVI is not compatible with init_to_feasible, init_to_value, "
"and init_to_uniform with radius=0. If you have a use case, "
"feel free to open an issue."
)
if isinstance(guide.init_loc_fn, partial):
init_fn_name = guide.init_loc_fn.func.__name__
if init_fn_name == "init_to_uniform":
assert (
guide.init_loc_fn.keywords.get("radius", None) != 0
), init_loc_error_message
else:
init_fn_name = guide.init_loc_fn.__name__
assert init_fn_name not in [
"init_to_feasible",
"init_to_value",
], init_loc_error_message

self._inference_model = model
self.model = model
self.guide = guide
@@ -112,7 +144,7 @@ def __init__(
)
self.kernel_fn = kernel_fn
self.static_kwargs = static_kwargs
self.num_particles = num_stein_particles
self.num_stein_particles = num_stein_particles
self.loss_temperature = loss_temperature
self.repulsion_temperature = repulsion_temperature
self.enum = enum
@@ -167,48 +199,21 @@ def _calc_particle_info(self, uparams, num_particles, start_index=0):
start_index = end_index
return res, end_index

def _find_init_params(self, particle_seed, inner_guide, inner_guide_trace):
def extract_info(site):
nonlocal particle_seed
name = site["name"]
value = site["value"]
constraint = site["kwargs"].get("constraint", real)
transform = get_parameter_transform(site)
if (
isinstance(inner_guide, AutoGuide)
and "_".join((inner_guide.prefix, "loc")) in name
):
site_key, particle_seed = random.split(particle_seed)
unconstrained_shape = transform.inverse_shape(value.shape)
init_value = jnp.expand_dims(
transform.inv(value), 0
) + Normal( # Add gaussian noise
scale=0.1
).sample(
particle_seed, (self.num_particles, *unconstrained_shape)
)
init_value = transform(init_value)

else:
site_fn = site["fn"]
site_args = site["args"]
site_key, particle_seed = random.split(particle_seed)
def _find_init_params(self, particle_seed, inner_guide, model_args, model_kwargs):
def local_trace(key):
guide = deepcopy(inner_guide)

def _reinit(seed):
with handlers.seed(rng_seed=seed):
return site_fn(*site_args)
with handlers.seed(rng_seed=key), handlers.trace() as mixture_trace:
guide(*model_args, **model_kwargs)

init_value = vmap(_reinit)(
random.split(particle_seed, self.num_particles)
)
return init_value, constraint
init_params = {
name: site["value"]
for name, site in mixture_trace.items()
if site.get("type") == "param"
}
return init_params

init_params = {
name: extract_info(site)
for name, site in inner_guide_trace.items()
if site.get("type") == "param"
}
return init_params
return vmap(local_trace)(random.split(particle_seed, self.num_stein_particles))

def _svgd_loss_and_grads(self, rng_key, unconstr_params, *args, **kwargs):
# 0. Separate model and guide parameters, since only guide parameters are updated using Stein
@@ -352,7 +357,7 @@ def _update_force(attr_force, rep_force, jac):
vmap(single_particle_grad)(
stein_particles, attractive_force, repulsive_force
)
/ self.num_particles
/ self.num_stein_particles
)

# 5. Decompose the monolithic particle forces back to concrete parameter values
@@ -372,19 +377,25 @@ def init(self, rng_key: KeyArray, *args, **kwargs):
:param kwargs: Keyword arguments to the model / guide.
:return: initial :data:`SteinVIState`
"""
rng_key, kernel_seed, model_seed, guide_seed = random.split(rng_key, 4)
model_init = handlers.seed(self.model, model_seed)
guide_init = handlers.seed(self.guide, guide_seed)
guide_trace = handlers.trace(guide_init).get_trace(
*args, **kwargs, **self.static_kwargs

rng_key, kernel_seed, model_seed, guide_seed, particle_seed = random.split(
rng_key, 5
)

model_init = handlers.seed(self.model, model_seed)
model_trace = handlers.trace(model_init).get_trace(
*args, **kwargs, **self.static_kwargs
)
rng_key, particle_seed = random.split(rng_key)

guide_init_params = self._find_init_params(
particle_seed, self.guide, guide_trace
particle_seed, self.guide, args, kwargs
)

guide_init = handlers.seed(self.guide, guide_seed)
guide_trace = handlers.trace(guide_init).get_trace(
*args, **kwargs, **self.static_kwargs
)

params = {}
transforms = {}
inv_transforms = {}
@@ -415,7 +426,7 @@ def init(self, rng_key: KeyArray, *args, **kwargs):
"particle_transform", IdentityTransform()
)
if site["name"] in guide_init_params:
pval, _ = guide_init_params[site["name"]]
pval = guide_init_params[site["name"]]
if self.non_mixture_params_fn(site["name"]):
pval = tree_map(lambda x: x[0], pval)
else:
Expand Down
Loading