Docs/stein mixtures #1605

Merged · 4 commits · Jun 15, 2023
2 changes: 1 addition & 1 deletion docs/source/conf.py
@@ -113,7 +113,7 @@
#
# This is also used if you do content translation via gettext catalogs.
# Usually you set "language" from the command line for these cases.
-language = None
+language = "en"

# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
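Context for this one-line change (hedged, since the diff itself gives no rationale): Sphinx 5 and later warn when `language` is `None` and fall back to `"en"`, so setting it explicitly keeps the build quiet. As the comment above notes, translated builds can still override the value from the command line, e.g. `sphinx-build -b html -D language=de docs/source docs/build`.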
2 changes: 1 addition & 1 deletion examples/stein_bnn.py
@@ -159,7 +159,7 @@ def main(args):
guide=stein.guide,
params=stein.get_params(result.state),
num_samples=100,
-guide_sites=stein.guide_param_names,
+guide_sites=stein.guide_sites,
)
xte, _, _ = normalize(
data.xte, xtr_mean, xtr_std
2 changes: 1 addition & 1 deletion examples/stein_dmm.py
@@ -317,7 +317,7 @@ def main(args):
guide,
params=results.params,
num_samples=1,
-guide_sites=steinvi.guide_param_names,
+guide_sites=steinvi.guide_sites,
)
seqs, rev_seqs, lengths = load_data("valid")
pred_notes = pred(
2 changes: 2 additions & 0 deletions numpyro/contrib/einstein/__init__.py
@@ -1,6 +1,7 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

+from numpyro.contrib.einstein.mixture_guide_predictive import MixtureGuidePredictive
from numpyro.contrib.einstein.stein_kernels import (
GraphicalKernel,
IMQKernel,
@@ -23,4 +24,5 @@
"GraphicalKernel",
"MixtureKernel",
"ProbabilityProductKernel",
"MixtureGuidePredictive",
]
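With the re-export above in place, the class is importable straight from the contrib package, which is the path the new docstrings and examples use:

```python
# Import path exposed by this diff; previously only the long module path worked.
from numpyro.contrib.einstein import MixtureGuidePredictive
```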
28 changes: 21 additions & 7 deletions numpyro/contrib/einstein/mixture_guide_predictive.py
@@ -13,8 +13,24 @@


class MixtureGuidePredictive:
"""
For single mixture component use numpyro.infer.Predictive.
"""(EXPERIMENTAL INTERFACE) This class constructs the predictive distribution for
:class:`numpyro.contrib.einstein.steinvi.SteinVi`

.. Note:: For single mixture component use numpyro.infer.Predictive.

.. warning::
The `MixtureGuidePredictive` is experimental and will likely be replaced by
:class:`numpyro.infer.util.Predictive` in the future.

:param Callable model: Python callable containing Pyro primitives.
:param Callable guide: Python callable containing Pyro primitives to get posterior samples of sites.
:param Dict params: Dictionary of values for param sites of model/guide
:param Sequence guide_sites: Names of sites that contribute to the Stein mixture.
:param Optional[int] num_samples:
:param Optional[Sequence[str]] return_sites: Sites to return. By default, only sample sites not present
in the guide are returned.
:param str mixture_assignment_sitename: Name of site for mixture component assignment for sites not in the Stein
mixture.
"""

def __init__(
@@ -25,8 +41,6 @@ def __init__(
guide_sites: Sequence,
num_samples: Optional[int] = None,
return_sites: Optional[Sequence[str]] = None,
-infer_discrete: bool = False,
-parallel: bool = False,
mixture_assignment_sitename="mixture_assignments",
):
self.model_predictive = partial(
@@ -37,11 +51,11 @@ def __init__(
},
num_samples=num_samples,
return_sites=return_sites,
-infer_discrete=infer_discrete,
-parallel=parallel,
+infer_discrete=False,
+parallel=False,
)
self._batch_shape = (num_samples,)
-self.parallel = parallel
+self.parallel = False
self.guide_params = {
name: param for name, param in params.items() if name in guide_sites
}
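To summarize the constructor changes for users: `infer_discrete` and `parallel` are no longer accepted and are pinned to `False` internally, and the mixture sites are passed via `guide_sites`. A minimal usage sketch mirroring the updated `examples/stein_bnn.py` above, assuming `model`, `data`, a fitted `stein = SteinVI(...)` instance, and its `result = stein.run(...)` output:

```python
from jax import random
from numpyro.contrib.einstein import MixtureGuidePredictive

# `model`, `stein`, `result`, and `data` are assumed from a prior SteinVI run.
pred = MixtureGuidePredictive(
    model,
    guide=stein.guide,
    params=stein.get_params(result.state),
    num_samples=100,
    guide_sites=stein.guide_sites,  # renamed from `guide_param_names` in this PR
)
samples = pred(random.PRNGKey(1), data)
```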
127 changes: 85 additions & 42 deletions numpyro/contrib/einstein/steinvi.py
@@ -9,6 +9,7 @@
from typing import Callable

from jax import grad, jacfwd, numpy as jnp, random, vmap
+from jax.random import KeyArray
from jax.tree_util import tree_map

from numpyro import handlers
@@ -24,6 +25,7 @@
from numpyro.distributions.transforms import IdentityTransform
from numpyro.infer.autoguide import AutoGuide
from numpyro.infer.util import _guess_max_plate_nesting, transform_fn
+from numpyro.optim import _NumPyroOptim
from numpyro.util import fori_collect, ravel_pytree

SteinVIState = namedtuple("SteinVIState", ["optim_state", "rng_key"])
@@ -35,36 +37,68 @@ def _numel(shape):


class SteinVI:
"""Variational inference with stein mixtures.
"""Variational inference with Stein mixtures.

-:param model: Python callable with Pyro primitives for the model.

+**Example:**

+.. doctest::

+>>> from jax import random
+>>> import jax.numpy as jnp
+>>> import numpyro
+>>> import numpyro.distributions as dist
+>>> from numpyro.distributions import constraints
+>>> from numpyro.contrib.einstein import MixtureGuidePredictive, SteinVI, RBFKernel

+>>> def model(data):
+...     f = numpyro.sample("latent_fairness", dist.Beta(10, 10))
+...     with numpyro.plate("N", data.shape[0] if data is not None else 10):
+...         numpyro.sample("obs", dist.Bernoulli(f), obs=data)

+>>> def guide(data):
+...     alpha_q = numpyro.param("alpha_q", 15., constraint=constraints.positive)
+...     beta_q = numpyro.param("beta_q", lambda rng_key: random.exponential(rng_key),
+...                            constraint=constraints.positive)
+...     numpyro.sample("latent_fairness", dist.Beta(alpha_q, beta_q))

+>>> data = jnp.concatenate([jnp.ones(6), jnp.zeros(4)])
+>>> optimizer = numpyro.optim.Adam(step_size=0.0005)
+>>> stein = SteinVI(model, guide, optimizer, kernel_fn=RBFKernel())
+>>> stein_result = stein.run(random.PRNGKey(0), 2000, data)
+>>> params = stein_result.params
+>>> # use guide to make predictive
+>>> predictive = MixtureGuidePredictive(model, guide, params, num_samples=1000, guide_sites=stein.guide_sites)
+>>> samples = predictive(random.PRNGKey(1), data=None)

+:param Callable model: Python callable with Pyro primitives for the model.
:param guide: Python callable with Pyro primitives for the guide
(recognition network).
-:param optim: an instance of :class:`~numpyro.optim._NumpyroOptim`.
-:param kernel_fn: Function that produces a logarithm of the statistical kernel to use with Stein inference
-:param num_stein_particles: number of particles for Stein inference.
-(More particles give more mixture components and therefore likely capture more of the posterior distribution)
-:param num_elbo_particles: number of particles for to approximate the attractive force gradient.
+:param _NumPyroOptim optim: An instance of :class:`~numpyro.optim._NumPyroOptim`.
+:param SteinKernel kernel_fn: Function that produces a logarithm of the statistical kernel to use with Stein mixture
+inference.
+:param num_stein_particles: Number of particles (i.e., mixture components) in the Stein mixture.
+:param num_elbo_particles: Number of Monte Carlo draws used to approximate the attractive force gradient.
+(More particles give better gradient approximations.)
-:param loss_temperature: scaling of loss factor
-:param repulsion_temperature: scaling of repulsive forces (Non-linear Stein)
-:param classic_guide_param_fn: predicate on names of parameters in guide which should be optimized classically
-without Stein (E.g. parameters for large normal networks or other transformation)
-:param static_kwargs: Static keyword arguments for the model / guide, i.e. arguments
-that remain constant during fitting.
+:param float loss_temperature: Scaling factor of the attractive force.
+:param float repulsion_temperature: Scaling factor of the repulsive force (non-linear Stein).
+:param Callable non_mixture_guide_params_fn: Predicate on names of parameters in the guide which should be optimized
+classically without Stein (e.g., parameters for large neural networks or other transformations).
+:param static_kwargs: Static keyword arguments for the model / guide, i.e. arguments that remain constant
+during inference.
"""

def __init__(
self,
-model,
-guide,
-optim,
+model: Callable,
+guide: Callable,
+optim: _NumPyroOptim,
kernel_fn: SteinKernel,
num_stein_particles: int = 10,
num_elbo_particles: int = 10,
loss_temperature: float = 1.0,
repulsion_temperature: float = 1.0,
-classic_guide_params_fn: Callable[[str], bool] = lambda name: False,
+non_mixture_guide_params_fn: Callable[[str], bool] = lambda name: False,
enum=True,
**static_kwargs,
):
@@ -82,8 +116,8 @@ def __init__(
self.loss_temperature = loss_temperature
self.repulsion_temperature = repulsion_temperature
self.enum = enum
-self.model_params_fn = classic_guide_params_fn
-self.guide_param_names = None
+self.non_mixture_params_fn = non_mixture_guide_params_fn
+self.guide_sites = None
self.constrain_fn = None
self.uconstrain_fn = None
self.particle_transform_fn = None
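An aside on the rename just above: `non_mixture_params_fn` (constructor argument `non_mixture_guide_params_fn`) is the user-facing hook for keeping selected guide parameters out of the Stein mixture and optimizing them classically. A hedged sketch of such a predicate; the `nn_` name prefix is hypothetical:

```python
import numpyro
from numpyro.contrib.einstein import SteinVI, RBFKernel

# `model` and `guide` as in the SteinVI docstring example above. Guide
# parameters whose names start with "nn_" (illustrative convention) are
# optimized directly instead of being replicated across mixture particles.
stein = SteinVI(
    model,
    guide,
    numpyro.optim.Adam(step_size=1e-3),
    kernel_fn=RBFKernel(),
    non_mixture_guide_params_fn=lambda name: name.startswith("nn_"),
)
```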
@@ -178,14 +212,17 @@ def _reinit(seed):

def _svgd_loss_and_grads(self, rng_key, unconstr_params, *args, **kwargs):
# 0. Separate model and guide parameters, since only guide parameters are updated using Stein
-model_uparams = {
-p: v
-for p, v in unconstr_params.items()
-if p not in self.guide_param_names or self.model_params_fn(p)
-}
+non_mixture_uparams = (
+{  # Includes any marked guide parameters and all model parameters
+p: v
+for p, v in unconstr_params.items()
+if p not in self.guide_sites or self.non_mixture_params_fn(p)
+}
+)
stein_uparams = {
-p: v for p, v in unconstr_params.items() if p not in model_uparams
+p: v for p, v in unconstr_params.items() if p not in non_mixture_uparams
}

# 1. Collect each guide parameter into monolithic particles that capture correlations
# between parameter values across each individual particle
stein_particles, unravel_pytree, unravel_pytree_batched = batch_ravel_pytree(
@@ -197,7 +234,9 @@ def _svgd_loss_and_grads(self, rng_key, unconstr_params, *args, **kwargs):
attractive_key, classic_key = random.split(rng_key)

# 2. Calculate gradients for each particle
-def kernel_particles_loss_fn(rng_key, particles):
+def kernel_particles_loss_fn(
+rng_key, particles
+):  # TODO: rewrite using def to utilize jax caching
particle_keys = random.split(rng_key, self.stein_loss.stein_num_particles)
grads = vmap(
lambda i: grad(
@@ -215,7 +254,7 @@ def kernel_particles_loss_fn(rng_key, particles):
select_index=i,
model_args=args,
model_kwargs=kwargs,
-param_map=self.constrain_fn(model_uparams),
+param_map=self.constrain_fn(non_mixture_uparams),
)
)(
random.split(
@@ -237,13 +276,16 @@ def particle_transform_fn(particle):
ctparticle, _ = ravel_pytree(ctparams)
return tparticle, ctparticle

+# 2.1 Lift particles to constrained space
tstein_particles, ctstein_particles = vmap(particle_transform_fn)(
stein_particles
)

+# 2.2 Compute particle gradients (for attractive force)
particle_ljp_grads = kernel_particles_loss_fn(attractive_key, ctstein_particles)

-classic_param_grads = grad(
+# 2.3 Compute non-mixture parameter gradients
+non_mixture_param_grads = grad(
lambda cps: -self.stein_loss.loss(
classic_key,
self.constrain_fn(cps),
@@ -253,14 +295,14 @@
*args,
**kwargs,
)
-)(model_uparams)
+)(non_mixture_uparams)

-# 3. Calculate kernel on monolithic particle
-kernel = self.kernel_fn.compute(  # TODO: Fix to use Stein loss
+# 3. Calculate kernel of particles
+kernel = self.kernel_fn.compute(
stein_particles, particle_info, kernel_particles_loss_fn
)

-# 4. Calculate the attractive force and repulsive force on the monolithic particles
+# 4. Calculate the attractive force and repulsive force on the particles
attractive_force = vmap(
lambda y: jnp.sum(
vmap(
@@ -317,16 +359,17 @@ def _update_force(attr_force, rep_force, jac):
stein_param_grads = unravel_pytree_batched(particle_grads)

# 6. Return loss and gradients (based on parameter forces)
-res_grads = tree_map(lambda x: -x, {**classic_param_grads, **stein_param_grads})
+res_grads = tree_map(
+lambda x: -x, {**non_mixture_param_grads, **stein_param_grads}
+)
return jnp.linalg.norm(particle_grads), res_grads

-def init(self, rng_key, *args, **kwargs):
-"""
-:param jax.random.PRNGKey rng_key: random number generator seed.
-:param args: arguments to the model / guide (these can possibly vary during
-the course of fitting).
-:param kwargs: keyword arguments to the model / guide (these can possibly vary
-during the course of fitting).
+def init(self, rng_key: KeyArray, *args, **kwargs):
+"""Register random variable transformations and constraints, and determine the initial positions of the particles.

+:param KeyArray rng_key: Random number generator seed.
+:param args: Arguments to the model / guide.
+:param kwargs: Keyword arguments to the model / guide.
:return: initial :data:`SteinVIState`
"""
rng_key, kernel_seed, model_seed, guide_seed = random.split(rng_key, 4)
@@ -373,7 +416,7 @@ def init(self, rng_key, *args, **kwargs):
)
if site["name"] in guide_init_params:
pval, _ = guide_init_params[site["name"]]
if self.model_params_fn(site["name"]):
if self.non_mixture_params_fn(site["name"]):
pval = tree_map(lambda x: x[0], pval)
else:
pval = site["value"]
@@ -384,7 +427,7 @@
if should_enum:
mpn = _guess_max_plate_nesting(model_trace)
self._inference_model = enum(config_enumerate(self.model), -mpn - 1)
-self.guide_param_names = guide_param_names
+self.guide_sites = guide_param_names
self.constrain_fn = partial(transform_fn, inv_transforms)
self.uconstrain_fn = partial(transform_fn, transforms)
self.particle_transforms = particle_transforms
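Finally, for readers tracing steps 2 through 5 of `_svgd_loss_and_grads`: they specialize the classic SVGD update to Stein mixtures, where each particle feels a kernel-weighted attractive force toward high posterior density plus a repulsive force from the kernel gradient. A self-contained sketch of the textbook update, with an illustrative RBF kernel and a generic `log_p` standing in for the Stein loss used in the diff:

```python
import jax.numpy as jnp
from jax import grad, vmap

def rbf_kernel(x, y, bandwidth=1.0):
    # Illustrative kernel only; SteinVI's kernels come from stein_kernels.py.
    return jnp.exp(-jnp.sum((x - y) ** 2) / (2.0 * bandwidth**2))

def svgd_forces(particles, log_p):
    # phi(x_i) = mean_j [ k(x_j, x_i) * grad log_p(x_j)   (attractive)
    #                     + grad_{x_j} k(x_j, x_i) ]      (repulsive)
    scores = vmap(grad(log_p))(particles)  # (n, d) array of score vectors

    def phi(x_i):
        attractive = vmap(lambda x_j, s_j: rbf_kernel(x_j, x_i) * s_j)(particles, scores)
        repulsive = vmap(lambda x_j: grad(rbf_kernel)(x_j, x_i))(particles)
        return jnp.mean(attractive + repulsive, axis=0)

    return vmap(phi)(particles)  # one update direction per particle

# Example: pull 5 two-dimensional particles toward a standard normal.
# forces = svgd_forces(jnp.arange(10.0).reshape(5, 2), lambda x: -0.5 * jnp.sum(x**2))
```

Steps 3 and 4 in the diff compute analogous kernel and force terms, with ELBO-based gradients in place of `grad(log_p)` and the `loss_temperature` / `repulsion_temperature` scalings applied, before step 5 folds everything into per-particle parameter updates.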