diff --git a/docs/source/conf.py b/docs/source/conf.py
index 0d4da39fb..1a6e68a5c 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -113,7 +113,7 @@
 #
 # This is also used if you do content translation via gettext catalogs.
 # Usually you set "language" from the command line for these cases.
-language = None
+language = "en"
 
 # List of patterns, relative to source directory, that match files and
 # directories to ignore when looking for source files.
diff --git a/examples/stein_bnn.py b/examples/stein_bnn.py
index 280466dc2..f113bc435 100644
--- a/examples/stein_bnn.py
+++ b/examples/stein_bnn.py
@@ -159,7 +159,7 @@ def main(args):
         guide=stein.guide,
         params=stein.get_params(result.state),
         num_samples=100,
-        guide_sites=stein.guide_param_names,
+        guide_sites=stein.guide_sites,
     )
     xte, _, _ = normalize(
         data.xte, xtr_mean, xtr_std
diff --git a/examples/stein_dmm.py b/examples/stein_dmm.py
index 3b013bafb..140e886aa 100644
--- a/examples/stein_dmm.py
+++ b/examples/stein_dmm.py
@@ -317,7 +317,7 @@ def main(args):
         guide,
         params=results.params,
         num_samples=1,
-        guide_sites=steinvi.guide_param_names,
+        guide_sites=steinvi.guide_sites,
     )
     seqs, rev_seqs, lengths = load_data("valid")
     pred_notes = pred(
diff --git a/numpyro/contrib/einstein/__init__.py b/numpyro/contrib/einstein/__init__.py
index 774b92562..57990a488 100644
--- a/numpyro/contrib/einstein/__init__.py
+++ b/numpyro/contrib/einstein/__init__.py
@@ -1,6 +1,7 @@
 # Copyright Contributors to the Pyro project.
 # SPDX-License-Identifier: Apache-2.0
 
+from numpyro.contrib.einstein.mixture_guide_predictive import MixtureGuidePredictive
 from numpyro.contrib.einstein.stein_kernels import (
     GraphicalKernel,
     IMQKernel,
@@ -23,4 +24,5 @@
     "GraphicalKernel",
     "MixtureKernel",
     "ProbabilityProductKernel",
+    "MixtureGuidePredictive",
 ]
diff --git a/numpyro/contrib/einstein/mixture_guide_predictive.py b/numpyro/contrib/einstein/mixture_guide_predictive.py
index 4fb11fde0..3eca7513b 100644
--- a/numpyro/contrib/einstein/mixture_guide_predictive.py
+++ b/numpyro/contrib/einstein/mixture_guide_predictive.py
@@ -13,8 +13,37 @@
 
 
 class MixtureGuidePredictive:
-    """
-    For single mixture component use numpyro.infer.Predictive.
+    """(EXPERIMENTAL INTERFACE) This class constructs the predictive distribution for
+    :class:`numpyro.contrib.einstein.steinvi.SteinVI`.
+
+    .. Note:: For a single mixture component, use :class:`numpyro.infer.Predictive`.
+
+    .. warning::
+        The `MixtureGuidePredictive` is experimental and will likely be replaced by
+        :class:`numpyro.infer.util.Predictive` in the future.
+
+    :param Callable model: Python callable containing Pyro primitives.
+    :param Callable guide: Python callable containing Pyro primitives to get posterior samples of sites.
+    :param Dict params: Dictionary of values for param sites of the model/guide.
+    :param Sequence guide_sites: Names of sites that contribute to the Stein mixture.
+    :param Optional[int] num_samples: Number of samples to draw from the predictive distribution.
+    :param Optional[Sequence[str]] return_sites: Sites to return. By default, only sample sites not present
+        in the guide are returned.
+    :param str mixture_assignment_sitename: Name of the site for mixture component assignment for sites not in the
+        Stein mixture.
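+
+    **Example** (a minimal usage sketch; assumes a fitted
+    :class:`~numpyro.contrib.einstein.steinvi.SteinVI` instance ``stein``
+    and its run result ``result``)::
+
+        pred = MixtureGuidePredictive(
+            model,
+            guide=stein.guide,
+            params=stein.get_params(result.state),
+            num_samples=100,
+            guide_sites=stein.guide_sites,
+        )
+        samples = pred(random.PRNGKey(1), data)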
""" def __init__( @@ -25,8 +41,6 @@ def __init__( guide_sites: Sequence, num_samples: Optional[int] = None, return_sites: Optional[Sequence[str]] = None, - infer_discrete: bool = False, - parallel: bool = False, mixture_assignment_sitename="mixture_assignments", ): self.model_predictive = partial( @@ -37,11 +51,11 @@ def __init__( }, num_samples=num_samples, return_sites=return_sites, - infer_discrete=infer_discrete, - parallel=parallel, + infer_discrete=False, + parallel=False, ) self._batch_shape = (num_samples,) - self.parallel = parallel + self.parallel = False self.guide_params = { name: param for name, param in params.items() if name in guide_sites } diff --git a/numpyro/contrib/einstein/steinvi.py b/numpyro/contrib/einstein/steinvi.py index 8d9701d40..424d09f79 100644 --- a/numpyro/contrib/einstein/steinvi.py +++ b/numpyro/contrib/einstein/steinvi.py @@ -9,6 +9,7 @@ from typing import Callable from jax import grad, jacfwd, numpy as jnp, random, vmap +from jax.random import KeyArray from jax.tree_util import tree_map from numpyro import handlers @@ -24,6 +25,7 @@ from numpyro.distributions.transforms import IdentityTransform from numpyro.infer.autoguide import AutoGuide from numpyro.infer.util import _guess_max_plate_nesting, transform_fn +from numpyro.optim import _NumPyroOptim from numpyro.util import fori_collect, ravel_pytree SteinVIState = namedtuple("SteinVIState", ["optim_state", "rng_key"]) @@ -35,36 +37,68 @@ def _numel(shape): class SteinVI: - """Variational inference with stein mixtures. + """Variational inference with Stein mixtures. - :param model: Python callable with Pyro primitives for the model. + + **Example:** + + .. doctest:: + + >>> from jax import random + >>> import jax.numpy as jnp + >>> import numpyro + >>> import numpyro.distributions as dist + >>> from numpyro.distributions import constraints + >>> from numpyro.contrib.einstein import MixtureGuidePredictive, SteinVI, RBFKernel + + >>> def model(data): + ... f = numpyro.sample("latent_fairness", dist.Beta(10, 10)) + ... with numpyro.plate("N", data.shape[0] if data is not None else 10): + ... numpyro.sample("obs", dist.Bernoulli(f), obs=data) + + >>> def guide(data): + ... alpha_q = numpyro.param("alpha_q", 15., constraint=constraints.positive) + ... beta_q = numpyro.param("beta_q", lambda rng_key: random.exponential(rng_key), + ... constraint=constraints.positive) + ... numpyro.sample("latent_fairness", dist.Beta(alpha_q, beta_q)) + + >>> data = jnp.concatenate([jnp.ones(6), jnp.zeros(4)]) + >>> optimizer = numpyro.optim.Adam(step_size=0.0005) + >>> stein = SteinVI(model, guide, optimizer, kernel_fn=RBFKernel()) + >>> stein_result = stein.run(random.PRNGKey(0), 2000, data) + >>> params = stein_result.params + >>> # use guide to make predictive + >>> predictive = MixtureGuidePredictive(model, guide, params, num_samples=1000, guide_sites=stein.guide_sites) + >>> samples = predictive(random.PRNGKey(1), data=None) + + :param Callable model: Python callable with Pyro primitives for the model. :param guide: Python callable with Pyro primitives for the guide (recognition network). - :param optim: an instance of :class:`~numpyro.optim._NumpyroOptim`. - :param kernel_fn: Function that produces a logarithm of the statistical kernel to use with Stein inference - :param num_stein_particles: number of particles for Stein inference. 
-        (More particles give more mixture components and therefore likely capture more of the posterior distribution)
-    :param num_elbo_particles: number of particles for to approximate the attractive force gradient.
+    :param _NumPyroOptim optim: An instance of :class:`~numpyro.optim._NumPyroOptim`.
+    :param SteinKernel kernel_fn: Function that produces a logarithm of the statistical kernel to use with Stein
+        mixture inference.
+    :param num_stein_particles: Number of particles (i.e., mixture components) in the Stein mixture.
+    :param num_elbo_particles: Number of Monte Carlo draws used to approximate the attractive force gradient.
+        More particles give better gradient approximations.
+    :param float loss_temperature: Scaling factor of the attractive force.
+    :param float repulsion_temperature: Scaling factor of the repulsive force (non-linear Stein).
+    :param Callable non_mixture_guide_params_fn: Predicate on names of parameters in the guide that should be
+        optimized classically, without Stein (e.g., parameters of large neural networks or other transformations).
+    :param static_kwargs: Static keyword arguments for the model / guide, i.e., arguments that remain constant
+        during inference.
     """
 
     def __init__(
         self,
-        model,
-        guide,
-        optim,
+        model: Callable,
+        guide: Callable,
+        optim: _NumPyroOptim,
         kernel_fn: SteinKernel,
         num_stein_particles: int = 10,
         num_elbo_particles: int = 10,
         loss_temperature: float = 1.0,
         repulsion_temperature: float = 1.0,
-        classic_guide_params_fn: Callable[[str], bool] = lambda name: False,
+        non_mixture_guide_params_fn: Callable[[str], bool] = lambda name: False,
         enum=True,
         **static_kwargs,
     ):
@@ -82,8 +116,8 @@ def __init__(
         self.loss_temperature = loss_temperature
         self.repulsion_temperature = repulsion_temperature
         self.enum = enum
-        self.model_params_fn = classic_guide_params_fn
-        self.guide_param_names = None
+        self.non_mixture_params_fn = non_mixture_guide_params_fn
+        self.guide_sites = None
         self.constrain_fn = None
         self.uconstrain_fn = None
         self.particle_transform_fn = None
@@ -178,14 +212,17 @@ def _reinit(seed):
 
     def _svgd_loss_and_grads(self, rng_key, unconstr_params, *args, **kwargs):
         # 0. Separate model and guide parameters, since only guide parameters are updated using Stein
-        model_uparams = {
-            p: v
-            for p, v in unconstr_params.items()
-            if p not in self.guide_param_names or self.model_params_fn(p)
-        }
+        non_mixture_uparams = (
+            {  # Includes any marked guide parameters and all model parameters
+                p: v
+                for p, v in unconstr_params.items()
+                if p not in self.guide_sites or self.non_mixture_params_fn(p)
+            }
+        )
         stein_uparams = {
-            p: v for p, v in unconstr_params.items() if p not in model_uparams
+            p: v for p, v in unconstr_params.items() if p not in non_mixture_uparams
         }
+
         # 1. Collect each guide parameter into monolithic particles that capture correlations
         # between parameter values across each individual particle
         stein_particles, unravel_pytree, unravel_pytree_batched = batch_ravel_pytree(
@@ -197,7 +234,9 @@ def _svgd_loss_and_grads(self, rng_key, unconstr_params, *args, **kwargs):
         attractive_key, classic_key = random.split(rng_key)
 
         # 2. Calculate gradients for each particle
-        def kernel_particles_loss_fn(rng_key, particles):
+        def kernel_particles_loss_fn(
+            rng_key, particles
+        ):  # TODO: rewrite using def to utilize jax caching
             particle_keys = random.split(rng_key, self.stein_loss.stein_num_particles)
             grads = vmap(
                 lambda i: grad(
@@ -215,7 +254,7 @@ def kernel_particles_loss_fn(rng_key, particles):
                         select_index=i,
                         model_args=args,
                         model_kwargs=kwargs,
-                        param_map=self.constrain_fn(model_uparams),
+                        param_map=self.constrain_fn(non_mixture_uparams),
                     )
                 )(
                     random.split(
@@ -237,13 +276,16 @@ def particle_transform_fn(particle):
             ctparticle, _ = ravel_pytree(ctparams)
             return tparticle, ctparticle
 
+        # 2.1 Lift particles to constraint space
         tstein_particles, ctstein_particles = vmap(particle_transform_fn)(
             stein_particles
         )
 
+        # 2.2 Compute particle gradients (for attractive force)
         particle_ljp_grads = kernel_particles_loss_fn(attractive_key, ctstein_particles)
 
-        classic_param_grads = grad(
+        # 2.3 Compute non-mixture parameter gradients
+        non_mixture_param_grads = grad(
             lambda cps: -self.stein_loss.loss(
                 classic_key,
                 self.constrain_fn(cps),
@@ -253,14 +295,14 @@ def particle_transform_fn(particle):
                 *args,
                 **kwargs,
             )
-        )(model_uparams)
+        )(non_mixture_uparams)
 
-        # 3. Calculate kernel on monolithic particle
-        kernel = self.kernel_fn.compute(  # TODO: Fix to use Stein loss
+        # 3. Calculate kernel of particles
+        kernel = self.kernel_fn.compute(
             stein_particles, particle_info, kernel_particles_loss_fn
         )
 
-        # 4. Calculate the attractive force and repulsive force on the monolithic particles
+        # 4. Calculate the attractive force and repulsive force on the particles
         attractive_force = vmap(
             lambda y: jnp.sum(
                 vmap(
@@ -317,16 +359,17 @@ def _update_force(attr_force, rep_force, jac):
         stein_param_grads = unravel_pytree_batched(particle_grads)
 
         # 6. Return loss and gradients (based on parameter forces)
-        res_grads = tree_map(lambda x: -x, {**classic_param_grads, **stein_param_grads})
+        res_grads = tree_map(
+            lambda x: -x, {**non_mixture_param_grads, **stein_param_grads}
+        )
         return jnp.linalg.norm(particle_grads), res_grads
 
-    def init(self, rng_key, *args, **kwargs):
-        """
-        :param jax.random.PRNGKey rng_key: random number generator seed.
-        :param args: arguments to the model / guide (these can possibly vary during
-            the course of fitting).
-        :param kwargs: keyword arguments to the model / guide (these can possibly vary
-            during the course of fitting).
+    def init(self, rng_key: KeyArray, *args, **kwargs):
+        """Register random variable transformations and constraints, and determine the initial positions of the particles.
+
+        :param KeyArray rng_key: Random number generator seed.
+        :param args: Arguments to the model / guide.
+        :param kwargs: Keyword arguments to the model / guide.
         :return: initial :data:`SteinVIState`
         """
         rng_key, kernel_seed, model_seed, guide_seed = random.split(rng_key, 4)
@@ -373,7 +416,7 @@ def init(self, rng_key, *args, **kwargs):
             )
             if site["name"] in guide_init_params:
                 pval, _ = guide_init_params[site["name"]]
-                if self.model_params_fn(site["name"]):
+                if self.non_mixture_params_fn(site["name"]):
                     pval = tree_map(lambda x: x[0], pval)
             else:
                 pval = site["value"]
@@ -384,7 +427,7 @@ def init(self, rng_key, *args, **kwargs):
         if should_enum:
             mpn = _guess_max_plate_nesting(model_trace)
             self._inference_model = enum(config_enumerate(self.model), -mpn - 1)
-        self.guide_param_names = guide_param_names
+        self.guide_sites = guide_param_names
        self.constrain_fn = partial(transform_fn, inv_transforms)
        self.uconstrain_fn = partial(transform_fn, transforms)
        self.particle_transforms = particle_transforms
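
A short sketch of the renamed `non_mixture_guide_params_fn` argument introduced above (it reuses `model`, `guide`, and `data` from the SteinVI doctest; the predicate `keep_out_of_mixture` and the `nn_` naming convention are hypothetical):

    from jax import random
    import numpyro
    from numpyro.contrib.einstein import RBFKernel, SteinVI

    # Guide parameters matched by this predicate are optimized classically,
    # outside the Stein mixture; the argument replaces `classic_guide_params_fn`.
    def keep_out_of_mixture(name):
        return name.startswith("nn_")  # hypothetical naming convention

    stein = SteinVI(
        model,
        guide,
        numpyro.optim.Adam(step_size=0.0005),
        kernel_fn=RBFKernel(),
        non_mixture_guide_params_fn=keep_out_of_mixture,
    )
    result = stein.run(random.PRNGKey(0), 2000, data)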