diff --git a/docs/source/contrib.rst b/docs/source/contrib.rst index 97ebe3e1a..a7b69db6a 100644 --- a/docs/source/contrib.rst +++ b/docs/source/contrib.rst @@ -23,12 +23,11 @@ SVGD is well suited for capturing correlations between latent variables as a par The technique preserves the scalability of traditional VI approaches while offering the flexibility and modeling scope of methods such as Markov chain Monte Carlo (MCMC). SVGD is good at capturing multi-modality [3][4]. -``numpyro.contrib.einstein`` is a framework for particle-based inference using the ELBO-within-Stein algorithm. +``numpyro.contrib.einstein`` is a framework for particle-based inference using the Stein mixture algorithm. The framework works on Stein mixtures, a restricted mixture of guide programs parameterized by Stein particles. Similarly to how SVGD works, Stein mixtures can approximate model posteriors by moving the Stein particles according to the Stein forces. Because the Stein particles parameterize a guide, they capture a neighborhood rather than a -single point. This property means Stein mixtures significantly reduce the number of particles needed to represent -high dimensional models. +single point. ``numpyro.contrib.einstein`` mimics the interface from ``numpyro.infer.svi``, so trying SteinVI requires minimal change to the code for existing models inferred with SVI. For primary usage, see the @@ -40,9 +39,8 @@ The framework currently supports several kernels, including: - `LinearKernel` - `RandomFeatureKernel` - `MixtureKernel` -- `PrecondMatrixKernel` -- `HessianPrecondMatrix` - `GraphicalKernel` +- `ProbabilityProductKernel` For example, usage see: @@ -68,10 +66,11 @@ SteinVI Interface SteinVI Kernels --------------- -.. autoclass:: numpyro.contrib.einstein.kernels.RBFKernel -.. autoclass:: numpyro.contrib.einstein.kernels.LinearKernel -.. autoclass:: numpyro.contrib.einstein.kernels.RandomFeatureKernel -.. autoclass:: numpyro.contrib.einstein.kernels.MixtureKernel -.. autoclass:: numpyro.contrib.einstein.kernels.PrecondMatrixKernel -.. autoclass:: numpyro.contrib.einstein.kernels.GraphicalKernel +.. autoclass:: numpyro.contrib.einstein.stein_kernels.RBFKernel +.. autoclass:: numpyro.contrib.einstein.stein_kernels.LinearKernel +.. autoclass:: numpyro.contrib.einstein.stein_kernels.RandomFeatureKernel +.. autoclass:: numpyro.contrib.einstein.stein_kernels.MixtureKernel +.. autoclass:: numpyro.contrib.einstein.stein_kernels.GraphicalKernel +.. autoclass:: numpyro.contrib.einstein.stein_kernels.ProbabilityProductKernel + diff --git a/examples/stein_bnn.py b/examples/stein_bnn.py index 7f115c418..280466dc2 100644 --- a/examples/stein_bnn.py +++ b/examples/stein_bnn.py @@ -23,11 +23,13 @@ import jax.numpy as jnp import numpyro -from numpyro.contrib.einstein import RBFKernel, SteinVI +from numpyro import deterministic +from numpyro.contrib.einstein import IMQKernel, SteinVI +from numpyro.contrib.einstein.mixture_guide_predictive import MixtureGuidePredictive from numpyro.distributions import Gamma, Normal from numpyro.examples.datasets import BOSTON_HOUSING, load_dataset -from numpyro.infer import Predictive, Trace_ELBO, init_to_uniform -from numpyro.infer.autoguide import AutoDelta +from numpyro.infer import init_to_uniform +from numpyro.infer.autoguide import AutoNormal from numpyro.optim import Adagrad DataState = namedtuple("data", ["xtr", "xte", "ytr", "yte"]) @@ -36,7 +38,7 @@ def load_data() -> DataState: _, fetch = load_dataset(BOSTON_HOUSING, shuffle=False) x, y = fetch() - xtr, xte, ytr, yte = train_test_split(x, y, train_size=0.90) + xtr, xte, ytr, yte = train_test_split(x, y, train_size=0.90, random_state=1) return DataState(*map(partial(jnp.array, dtype=float), (xtr, xte, ytr, yte))) @@ -105,10 +107,12 @@ def model(x, y=None, hidden_dim=50, subsample_size=100): else: batch_y = y + loc_y = deterministic("y_pred", jnp.maximum(batch_x @ w1 + b1, 0) @ w2 + b2) + numpyro.sample( "y", Normal( - jnp.maximum(batch_x @ w1 + b1, 0) @ w2 + b2, 1.0 / jnp.sqrt(prec_obs) + loc_y, 1.0 / jnp.sqrt(prec_obs) ), # 1 hidden layer with ReLU activation obs=batch_y, ) @@ -124,14 +128,17 @@ def main(args): rng_key, inf_key = random.split(inf_key) + guide = AutoNormal(model, init_loc_fn=partial(init_to_uniform, radius=0.1)) + stein = SteinVI( model, - AutoDelta(model, init_loc_fn=partial(init_to_uniform, radius=0.1)), + guide, Adagrad(0.05), - Trace_ELBO(20), # estimate elbo with 20 particles (not stein particles!) - RBFKernel(), + IMQKernel(), + # ProbabilityProductKernel(guide=guide, scale=1.), repulsion_temperature=args.repulsion, - num_particles=args.num_particles, + num_stein_particles=args.num_stein_particles, + num_elbo_particles=args.num_elbo_particles, ) start = time() @@ -147,33 +154,31 @@ def main(args): ) time_taken = time() - start - pred = Predictive( + pred = MixtureGuidePredictive( model, guide=stein.guide, params=stein.get_params(result.state), - num_samples=200, - batch_ndims=1, # stein particle dimension + num_samples=100, + guide_sites=stein.guide_param_names, ) xte, _, _ = normalize( data.xte, xtr_mean, xtr_std ) # use train data statistics when accessing generalization preds = pred( pred_key, xte, subsample_size=xte.shape[0], hidden_dim=args.hidden_dim - )["y"] + )["y_pred"] - y_pred = jnp.mean(preds, 1) * ytr_std + ytr_mean + y_pred = preds * ytr_std + ytr_mean rmse = jnp.sqrt(jnp.mean((y_pred.mean(0) - data.yte) ** 2)) print(rf"Time taken: {datetime.timedelta(seconds=int(time_taken))}") print(rf"RMSE: {rmse:.2f}") # compute mean prediction and confidence interval around median - mean_prediction = jnp.mean(y_pred, 0) + mean_prediction = y_pred.mean(0) ran = np.arange(mean_prediction.shape[0]) - percentiles = np.percentile( - preds.reshape(-1, xte.shape[0]) * ytr_std + ytr_mean, [5.0, 95.0], axis=0 - ) + percentiles = np.percentile(preds * ytr_std + ytr_mean, [5.0, 95.0], axis=0) # make plots fig, ax = plt.subplots(figsize=(8, 6), constrained_layout=True) @@ -199,7 +204,8 @@ def main(args): parser.add_argument("--max-iter", type=int, default=1000) parser.add_argument("--repulsion", type=float, default=1.0) parser.add_argument("--verbose", type=bool, default=True) - parser.add_argument("--num-particles", type=int, default=100) + parser.add_argument("--num-elbo-particles", type=int, default=50) + parser.add_argument("--num-stein-particles", type=int, default=5) parser.add_argument("--progress-bar", type=bool, default=True) parser.add_argument("--rng-key", type=int, default=142) parser.add_argument("--device", default="cpu", choices=["gpu", "cpu"]) diff --git a/examples/stein_dmm.py b/examples/stein_dmm.py index f8aa88dc0..3b013bafb 100644 --- a/examples/stein_dmm.py +++ b/examples/stein_dmm.py @@ -28,10 +28,10 @@ import numpyro from numpyro.contrib.einstein import SteinVI -from numpyro.contrib.einstein.kernels import RBFKernel +from numpyro.contrib.einstein.mixture_guide_predictive import MixtureGuidePredictive +from numpyro.contrib.einstein.stein_kernels import RBFKernel import numpyro.distributions as dist from numpyro.examples.datasets import JSB_CHORALES, load_dataset -from numpyro.infer import Predictive, Trace_ELBO from numpyro.optim import optax_to_numpyro @@ -293,17 +293,17 @@ def vis_tune(i, tunes, lengths, name="stein_dmm.pdf"): def main(args): inf_key, pred_key = random.split(random.PRNGKey(seed=args.rng_seed), 2) - vi = SteinVI( + steinvi = SteinVI( model, guide, optax_to_numpyro(chain(adam(1e-2))), - Trace_ELBO(), RBFKernel(), - num_particles=args.num_particles, + num_elbo_particles=args.num_elbo_particles, + num_stein_particles=args.num_stein_particles, ) seqs, rev_seqs, lengths = load_data() - results = vi.run( + results = steinvi.run( inf_key, args.max_iter, seqs, @@ -312,11 +312,12 @@ def main(args): gru_dim=args.gru_dim, subsample_size=args.subsample_size, ) - pred = Predictive( + pred = MixtureGuidePredictive( model, + guide, params=results.params, num_samples=1, - batch_ndims=1, + guide_sites=steinvi.guide_param_names, ) seqs, rev_seqs, lengths = load_data("valid") pred_notes = pred( @@ -328,11 +329,12 @@ def main(args): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("--subsample-size", type=int, default=77) - parser.add_argument("--max-iter", type=int, default=1000) + parser.add_argument("--subsample-size", type=int, default=10) + parser.add_argument("--max-iter", type=int, default=100) parser.add_argument("--repulsion", type=float, default=1.0) parser.add_argument("--verbose", type=bool, default=True) - parser.add_argument("--num-particles", type=int, default=5) + parser.add_argument("--num-stein-particles", type=int, default=5) + parser.add_argument("--num-elbo-particles", type=int, default=5) parser.add_argument("--progress-bar", type=bool, default=True) parser.add_argument("--gru-dim", type=int, default=150) parser.add_argument("--rng-key", type=int, default=142) diff --git a/numpyro/contrib/einstein/__init__.py b/numpyro/contrib/einstein/__init__.py index 9b98c3dfe..774b92562 100644 --- a/numpyro/contrib/einstein/__init__.py +++ b/numpyro/contrib/einstein/__init__.py @@ -1,26 +1,26 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 -from numpyro.contrib.einstein.kernels import ( +from numpyro.contrib.einstein.stein_kernels import ( GraphicalKernel, - HessianPrecondMatrix, IMQKernel, LinearKernel, - PrecondMatrix, - PrecondMatrixKernel, + MixtureKernel, + ProbabilityProductKernel, RandomFeatureKernel, RBFKernel, ) +from numpyro.contrib.einstein.stein_loss import SteinLoss from numpyro.contrib.einstein.steinvi import SteinVI __all__ = [ "SteinVI", + "SteinLoss", "RBFKernel", - "PrecondMatrix", "IMQKernel", "LinearKernel", "RandomFeatureKernel", - "HessianPrecondMatrix", "GraphicalKernel", - "PrecondMatrixKernel", + "MixtureKernel", + "ProbabilityProductKernel", ] diff --git a/numpyro/contrib/einstein/mixture_guide_predictive.py b/numpyro/contrib/einstein/mixture_guide_predictive.py new file mode 100644 index 000000000..4fb11fde0 --- /dev/null +++ b/numpyro/contrib/einstein/mixture_guide_predictive.py @@ -0,0 +1,108 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + +from functools import partial +from typing import Callable, Dict, Optional, Sequence + +from jax import numpy as jnp, random, tree_map, vmap +from jax.tree_util import tree_flatten + +from numpyro.handlers import substitute +from numpyro.infer import Predictive +from numpyro.infer.util import _predictive + + +class MixtureGuidePredictive: + """ + For single mixture component use numpyro.infer.Predictive. + """ + + def __init__( + self, + model: Callable, + guide: Callable, + params: Dict, + guide_sites: Sequence, + num_samples: Optional[int] = None, + return_sites: Optional[Sequence[str]] = None, + infer_discrete: bool = False, + parallel: bool = False, + mixture_assignment_sitename="mixture_assignments", + ): + self.model_predictive = partial( + Predictive, + model=model, + params={ + name: param for name, param in params.items() if name not in guide_sites + }, + num_samples=num_samples, + return_sites=return_sites, + infer_discrete=infer_discrete, + parallel=parallel, + ) + self._batch_shape = (num_samples,) + self.parallel = parallel + self.guide_params = { + name: param for name, param in params.items() if name in guide_sites + } + + self.guide = guide + self.return_sites = return_sites + self.num_mixture_components = jnp.shape(tree_flatten(params)[0][0])[0] + self.mixture_assignment_sitename = mixture_assignment_sitename + + def _call_with_params(self, rng_key, params, args, kwargs): + rng_key, guide_rng_key = random.split(rng_key) + # use return_sites='' as a special signal to return all sites + guide = substitute(self.guide, params) + samples = _predictive( + guide_rng_key, + guide, + {}, + self._batch_shape, + return_sites="", + parallel=self.parallel, + model_args=args, + model_kwargs=kwargs, + ) + return samples + + def __call__(self, rng_key, *args, **kwargs): + guide_key, assign_key, model_key = random.split(rng_key, 3) + + samples_guide = vmap( + lambda key, params: self._call_with_params( + key, params=params, args=args, kwargs=kwargs + ), + in_axes=0, + out_axes=1, + )(random.split(guide_key, self.num_mixture_components), self.guide_params) + + assigns = random.randint( + assign_key, + shape=self._batch_shape, + minval=0, + maxval=self.num_mixture_components, + ) + predictive_assign = tree_map( + lambda arr: vmap(lambda i, assign: arr[i, assign])( + jnp.arange(self._batch_shape[0]), assigns + ), + samples_guide, + ) + predictive_model = self.model_predictive(posterior_samples=predictive_assign) + samples_model = predictive_model(model_key, *args, **kwargs) + if self.return_sites is not None: + samples_guide = { + name: value + for name, value in samples_guide.items() + if name in self.return_sites + } + else: + samples_guide = {} + + return { + self.mixture_assignment_sitename: assigns, + **samples_guide, + **samples_model, # use samples from model if site in model and guide + } diff --git a/numpyro/contrib/einstein/kernels.py b/numpyro/contrib/einstein/stein_kernels.py similarity index 78% rename from numpyro/contrib/einstein/kernels.py rename to numpyro/contrib/einstein/stein_kernels.py index 0600b29a9..983c7d778 100644 --- a/numpyro/contrib/einstein/kernels.py +++ b/numpyro/contrib/einstein/stein_kernels.py @@ -12,20 +12,9 @@ import jax.scipy.linalg import jax.scipy.stats -from numpyro.contrib.einstein.util import median_bandwidth, sqrth_and_inv_sqrth -import numpyro.distributions as dist - - -class PrecondMatrix(ABC): - @abstractmethod - def compute(self, particles: jnp.ndarray, loss_fn: Callable[[jnp.ndarray], float]): - """ - Computes a preconditioning matrix for a given set of particles and a loss function - - :param particles: The Stein particles to compute the preconditioning matrix from - :param loss_fn: Loss function given particles - """ - raise NotImplementedError +from numpyro.contrib.einstein.stein_util import median_bandwidth +from numpyro.distributions import biject_to +from numpyro.infer.autoguide import AutoNormal class SteinKernel(ABC): @@ -320,89 +309,6 @@ def init(self, rng_key, particles_shape): kf.init(krng_key, particles_shape) -class HessianPrecondMatrix(PrecondMatrix): - """ - Calculates the constant precondition matrix based on the negative Hessian of the loss from [1]. - - **References:** - - 1. *Stein Variational Gradient Descent with Matrix-Valued Kernels* by Wang, Tang, Bajaj and Liu - """ - - def compute(self, particles, loss_fn): - hessian = -jax.vmap(jax.hessian(loss_fn))(particles) - return hessian - - -class PrecondMatrixKernel(SteinKernel): - """ - Calculates the const preconditioned kernel - :math:`k(x,y) = Q^{-\\frac{1}{2}}k(Q^{\\frac{1}{2}}x, Q^{\\frac{1}{2}}y)Q^{-\\frac{1}{2}},` - or anchor point preconditioned kernel - :math:`k(x,y) = \\sum_{l=1}^m k_{Q_l}(x,y)w_l(x)w_l(y)` - both from [1]. - - **References:** - - 1. "Stein Variational Gradient Descent with Matrix-Valued Kernels" by Wang, Tang, Bajaj and Liu - - :param precond_matrix_fn: The constant preconditioning matrix - :param inner_kernel_fn: The inner kernel function - :param precond_mode: How to use the precondition matrix, either constant ('const') - or as mixture with anchor points ('anchor_points') - """ - - def __init__( - self, - precond_matrix_fn: PrecondMatrix, - inner_kernel_fn: SteinKernel, - precond_mode="anchor_points", - ): - assert inner_kernel_fn.mode == "matrix" - assert precond_mode == "const" or precond_mode == "anchor_points" - self.precond_matrix_fn = precond_matrix_fn - self.inner_kernel_fn = inner_kernel_fn - self.precond_mode = precond_mode - - @property - def mode(self): - return "matrix" - - def compute(self, particles, particle_info, loss_fn): - qs = self.precond_matrix_fn.compute(particles, loss_fn) - if self.precond_mode == "const": - qs = jnp.expand_dims(jnp.mean(qs, axis=0), axis=0) - qs_inv = jnp.linalg.inv(qs) - qs_sqrt, qs_inv, qs_inv_sqrt = sqrth_and_inv_sqrth(qs) - inner_kernel = self.inner_kernel_fn.compute(particles, particle_info, loss_fn) - - def kernel(x, y): - if self.precond_mode == "const": - wxs = jnp.array([1.0]) - wys = jnp.array([1.0]) - else: - wxs = jax.nn.softmax( - jax.vmap( - lambda z, q_inv: dist.MultivariateNormal(z, q_inv).log_prob(x) - )(particles, qs_inv) - ) - wys = jax.nn.softmax( - jax.vmap( - lambda z, q_inv: dist.MultivariateNormal(z, q_inv).log_prob(y) - )(particles, qs_inv) - ) - return jnp.sum( - jax.vmap( - lambda qs, qis, wx, wy: wx - * wy - * (qis @ inner_kernel(qs @ x, qs @ y) @ qis.transpose()) - )(qs_sqrt, qs_inv_sqrt, wxs, wys), - axis=0, - ) - - return kernel - - class GraphicalKernel(SteinKernel): """ Calculates graphical kernel :math:`k(x,y) = diag({K_l(x_l,y_l)})` for local kernels @@ -467,3 +373,64 @@ def kernel(x, y): return jax.scipy.linalg.block_diag(*kernel_res) return kernel + + +class ProbabilityProductKernel(SteinKernel): + def __init__(self, guide, scale=1.0): + self._mode = "norm" + self.guide = guide + self.scale = scale + assert isinstance(guide, AutoNormal), "PPK only implemented for AutoNormal" + + def compute( + self, + particles: jnp.ndarray, + particle_info: Dict[str, Tuple[int, int]], + loss_fn: Callable[[jnp.ndarray], float], + ): + loc_idx = jnp.concatenate( + [ + jnp.arange(*idx) + for name, idx in particle_info.items() + if name.endswith(f"{self.guide.prefix}_loc") + ] + ) + scale_idx = jnp.concatenate( + [ + jnp.arange(*idx) + for name, idx in particle_info.items() + if name.endswith(f"{self.guide.prefix}_scale") + ] + ) + + def kernel(x, y): + biject = biject_to(self.guide.scale_constraint) + x_loc = x[loc_idx] + x_scale = biject(x[scale_idx]) + x_quad = (x_loc / x_scale) ** 2 + + y_loc = y[loc_idx] + y_scale = biject(y[scale_idx]) + y_quad = (y_loc / y_scale) ** 2 + + cross_loc = x_loc * x_scale**-2 + y_loc * y_scale**-2 + cross_var = 1 / (y_scale**-2 + x_scale**-2) + cross_quad = cross_loc**2 * cross_var + + quad = jnp.exp(-self.scale / 2 * (x_quad + y_quad - cross_quad)) + + norm = ( + (2 * jnp.pi) ** ((1 - 2 * self.scale) * 1 / 2) + * self.scale ** (-1 / 2) + * cross_var ** (1 / 2) + * x_scale ** (-self.scale) + * y_scale ** (-self.scale) + ) + + return jnp.linalg.norm(norm * quad) + + return kernel + + @property + def mode(self): + return self._mode diff --git a/numpyro/contrib/einstein/stein_loss.py b/numpyro/contrib/einstein/stein_loss.py new file mode 100644 index 000000000..31f69424e --- /dev/null +++ b/numpyro/contrib/einstein/stein_loss.py @@ -0,0 +1,101 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + +from jax import numpy as jnp, random, vmap +from jax.nn import logsumexp + +from numpyro.contrib.einstein.stein_util import batch_ravel_pytree +from numpyro.handlers import replay, seed +from numpyro.infer.util import log_density +from numpyro.util import _validate_model, check_model_guide_match + + +class SteinLoss: + def __init__(self, elbo_num_particles=1, stein_num_particles=1): + self.elbo_num_particles = elbo_num_particles + self.stein_num_particles = stein_num_particles + + def single_particle_loss( + self, + rng_key, + model, + guide, + selected_particle, + unravel_pytree, + flat_particles, + select_index, + model_args, + model_kwargs, + param_map, + ): + guide_key, model_key = random.split(rng_key, 2) + + # 2. Draw from selected mixture component + guide_keys = random.split(guide_key, self.stein_num_particles) + + seeded_chosen = seed(guide, guide_keys[select_index]) + log_chosen_density, chosen_trace = log_density( + seeded_chosen, model_args, model_kwargs, {**param_map, **selected_particle} + ) + + # 3. Score mixture guide + def log_component_density(i): + log_cdensity, component_trace = log_density( + replay(seed(guide, guide_key[i]), chosen_trace), + model_args, + model_kwargs, + {**param_map, **unravel_pytree(flat_particles[i])}, + ) + # Validate + check_model_guide_match(component_trace, chosen_trace) + return log_cdensity + + log_guide_density = logsumexp( + vmap(log_component_density)(jnp.arange(self.stein_num_particles)) + ) + + # 4. Score model + seeded_model = seed(model, model_key) + log_model_density, model_trace = log_density( + replay(seeded_model, chosen_trace), + model_args, + model_kwargs, + {**param_map, **selected_particle}, + ) + + # Validation + check_model_guide_match(model_trace, chosen_trace) + _validate_model(model_trace, plate_warning="loose") + elbo = log_model_density - log_guide_density + return elbo + + def loss(self, rng_key, param_map, model, guide, particles, *args, **kwargs): + if not particles: + raise ValueError("Stein mixture undefined for empty guide.") + + flat_particles, unravel_pytree, _ = batch_ravel_pytree(particles, nbatch_dims=1) + + select_key, score_key = random.split(rng_key) + assigns = random.randint( + select_key, + (self.elbo_num_particles,), + minval=0, + maxval=self.stein_num_particles, + ) + score_keys = random.split(score_key, self.elbo_num_particles) + elbos = vmap( + lambda key, assign: self.single_particle_loss( + rng_key=key, + model=model, + guide=guide, + selected_particle=unravel_pytree(flat_particles[assign]), + unravel_pytree=unravel_pytree, + flat_particles=flat_particles, + select_index=assign, + model_args=args, + model_kwargs=kwargs, + param_map=param_map, + ) + - jnp.log(self.stein_num_particles) + )(score_keys, assigns) + return -elbos.mean() diff --git a/numpyro/contrib/einstein/util.py b/numpyro/contrib/einstein/stein_util.py similarity index 82% rename from numpyro/contrib/einstein/util.py rename to numpyro/contrib/einstein/stein_util.py index ae0d71548..e8cb80372 100644 --- a/numpyro/contrib/einstein/util.py +++ b/numpyro/contrib/einstein/stein_util.py @@ -25,21 +25,6 @@ def sqrth(m): return msqrt -def sqrth_and_inv_sqrth(m): - """ - Given a positive definite matrix, get its Hermitian square root, its inverse, - and the Hermitian square root of its inverse. - """ - mlambda, mvec = jnp.linalg.eigh(m) - mvec_t = jnp.swapaxes(mvec, -2, -1) - mlambdasqrt = jnp.maximum(mlambda, 1e-5) ** 0.5 - msqrt = (mvec * jnp.expand_dims(mlambdasqrt, -2)) @ mvec_t - mlambdasqrt_inv = jnp.maximum(1 / mlambdasqrt, 1e-5**0.5) - minv_sqrt = (mvec * jnp.expand_dims(mlambdasqrt_inv, -2)) @ mvec_t - minv = minv_sqrt @ jnp.swapaxes(minv_sqrt, -2, -1) - return msqrt, minv, minv_sqrt - - def all_pairs_eucl_dist(a, b): a_sqr = jnp.sum(a**2, 1)[None, :] b_sqr = jnp.sum(b**2, 1)[:, None] @@ -48,6 +33,8 @@ def all_pairs_eucl_dist(a, b): def median_bandwidth(particles, factor_fn): + if particles.shape[0] == 1: + return 1.0 # Median produces NaN for single particle dists = all_pairs_eucl_dist(particles, particles) bandwidth = ( jnp.median(dists) ** 2 * factor_fn(particles.shape[0]) diff --git a/numpyro/contrib/einstein/steinvi.py b/numpyro/contrib/einstein/steinvi.py index 07a920578..8d9701d40 100644 --- a/numpyro/contrib/einstein/steinvi.py +++ b/numpyro/contrib/einstein/steinvi.py @@ -8,14 +8,16 @@ import operator from typing import Callable -import jax -import jax.numpy as jnp -import jax.random +from jax import grad, jacfwd, numpy as jnp, random, vmap from jax.tree_util import tree_map from numpyro import handlers -from numpyro.contrib.einstein.kernels import SteinKernel -from numpyro.contrib.einstein.util import batch_ravel_pytree, get_parameter_transform +from numpyro.contrib.einstein.stein_kernels import SteinKernel +from numpyro.contrib.einstein.stein_loss import SteinLoss +from numpyro.contrib.einstein.stein_util import ( + batch_ravel_pytree, + get_parameter_transform, +) from numpyro.contrib.funsor import config_enumerate, enum from numpyro.distributions import Distribution, Normal from numpyro.distributions.constraints import real @@ -33,19 +35,19 @@ def _numel(shape): class SteinVI: - """Stein variational inference for stein mixtures. + """Variational inference with stein mixtures. :param model: Python callable with Pyro primitives for the model. :param guide: Python callable with Pyro primitives for the guide (recognition network). :param optim: an instance of :class:`~numpyro.optim._NumpyroOptim`. - :param loss: ELBO loss, i.e. negative Evidence Lower Bound, to minimize. :param kernel_fn: Function that produces a logarithm of the statistical kernel to use with Stein inference - :param num_particles: number of particles for Stein inference. - (More particles capture more of the posterior distribution) + :param num_stein_particles: number of particles for Stein inference. + (More particles give more mixture components and therefore likely capture more of the posterior distribution) + :param num_elbo_particles: number of particles for to approximate the attractive force gradient. + (More particles give better gradient approximations) :param loss_temperature: scaling of loss factor :param repulsion_temperature: scaling of repulsive forces (Non-linear Stein) - :param enum: whether to apply automatic marginalization of discrete variables :param classic_guide_param_fn: predicate on names of parameters in guide which should be optimized classically without Stein (E.g. parameters for large normal networks or other transformation) :param static_kwargs: Static keyword arguments for the model / guide, i.e. arguments @@ -57,9 +59,9 @@ def __init__( model, guide, optim, - loss, kernel_fn: SteinKernel, - num_particles: int = 10, + num_stein_particles: int = 10, + num_elbo_particles: int = 10, loss_temperature: float = 1.0, repulsion_temperature: float = 1.0, classic_guide_params_fn: Callable[[str], bool] = lambda name: False, @@ -70,14 +72,17 @@ def __init__( self.model = model self.guide = guide self.optim = optim - self.loss = loss + self.stein_loss = SteinLoss( # TODO: @OlaRonning handle enum + elbo_num_particles=num_elbo_particles, + stein_num_particles=num_stein_particles, + ) self.kernel_fn = kernel_fn self.static_kwargs = static_kwargs - self.num_particles = num_particles + self.num_particles = num_stein_particles self.loss_temperature = loss_temperature self.repulsion_temperature = repulsion_temperature self.enum = enum - self.classic_guide_params_fn = classic_guide_params_fn + self.model_params_fn = classic_guide_params_fn self.guide_param_names = None self.constrain_fn = None self.uconstrain_fn = None @@ -92,15 +97,15 @@ def _apply_kernel(self, kernel, x, y, v): def _kernel_grad(self, kernel, x, y): if self.kernel_fn.mode == "norm": - return jax.grad(lambda x: kernel(x, y))(x) + return grad(lambda x: kernel(x, y))(x) elif self.kernel_fn.mode == "vector": - return jax.vmap(lambda i: jax.grad(lambda x: kernel(x, y)[i])(x)[i])( + return vmap(lambda i: grad(lambda x: kernel(x, y)[i])(x)[i])( jnp.arange(x.shape[0]) ) else: - return jax.vmap( + return vmap( lambda a: jnp.sum( - jax.vmap(lambda b: jax.grad(lambda x: kernel(x, y)[a, b])(x)[b])( + vmap(lambda b: grad(lambda x: kernel(x, y)[a, b])(x)[b])( jnp.arange(x.shape[0]) ) ) @@ -139,7 +144,7 @@ def extract_info(site): isinstance(inner_guide, AutoGuide) and "_".join((inner_guide.prefix, "loc")) in name ): - site_key, particle_seed = jax.random.split(particle_seed) + site_key, particle_seed = random.split(particle_seed) unconstrained_shape = transform.inverse_shape(value.shape) init_value = jnp.expand_dims( transform.inv(value), 0 @@ -153,14 +158,14 @@ def extract_info(site): else: site_fn = site["fn"] site_args = site["args"] - site_key, particle_seed = jax.random.split(particle_seed) + site_key, particle_seed = random.split(particle_seed) def _reinit(seed): with handlers.seed(rng_seed=seed): return site_fn(*site_args) - init_value = jax.vmap(_reinit)( - jax.random.split(particle_seed, self.num_particles) + init_value = vmap(_reinit)( + random.split(particle_seed, self.num_particles) ) return init_value, constraint @@ -173,13 +178,13 @@ def _reinit(seed): def _svgd_loss_and_grads(self, rng_key, unconstr_params, *args, **kwargs): # 0. Separate model and guide parameters, since only guide parameters are updated using Stein - classic_uparams = { + model_uparams = { p: v for p, v in unconstr_params.items() - if p not in self.guide_param_names or self.classic_guide_params_fn(p) + if p not in self.guide_param_names or self.model_params_fn(p) } stein_uparams = { - p: v for p, v in unconstr_params.items() if p not in classic_uparams + p: v for p, v in unconstr_params.items() if p not in model_uparams } # 1. Collect each guide parameter into monolithic particles that capture correlations # between parameter values across each individual particle @@ -189,67 +194,84 @@ def _svgd_loss_and_grads(self, rng_key, unconstr_params, *args, **kwargs): particle_info, _ = self._calc_particle_info( stein_uparams, stein_particles.shape[0] ) - - # 2. Calculate loss and gradients for each parameter - def scaled_loss(rng_key, classic_params, stein_params): - params = {**classic_params, **stein_params} - loss_val = self.loss.loss( - rng_key, - params, - handlers.scale(self._inference_model, self.loss_temperature), - self.guide, - *args, - **kwargs, - ) - return -loss_val - - def kernel_particle_loss_fn(ps): - return scaled_loss( - rng_key, - self.constrain_fn(classic_uparams), - self.constrain_fn(unravel_pytree(ps)), - ) + attractive_key, classic_key = random.split(rng_key) + + # 2. Calculate gradients for each particle + def kernel_particles_loss_fn(rng_key, particles): + particle_keys = random.split(rng_key, self.stein_loss.stein_num_particles) + grads = vmap( + lambda i: grad( + lambda particle: ( + vmap( + lambda elbo_key: self.stein_loss.single_particle_loss( + rng_key=elbo_key, + model=handlers.scale( + self._inference_model, self.loss_temperature + ), + guide=self.guide, + selected_particle=unravel_pytree(particle), + unravel_pytree=unravel_pytree, + flat_particles=particles, + select_index=i, + model_args=args, + model_kwargs=kwargs, + param_map=self.constrain_fn(model_uparams), + ) + )( + random.split( + particle_keys[i], self.stein_loss.elbo_num_particles + ) + ) + ).mean() + )(particles[i]) + )(jnp.arange(self.stein_loss.stein_num_particles)) + + return grads def particle_transform_fn(particle): params = unravel_pytree(particle) tparams = self.particle_transform_fn(params) + ctparams = self.constrain_fn(tparams) tparticle, _ = ravel_pytree(tparams) - return tparticle + ctparticle, _ = ravel_pytree(ctparams) + return tparticle, ctparticle - tstein_particles = jax.vmap(particle_transform_fn)(stein_particles) + tstein_particles, ctstein_particles = vmap(particle_transform_fn)( + stein_particles + ) - loss, particle_ljp_grads = jax.vmap( - jax.value_and_grad(kernel_particle_loss_fn) - )(tstein_particles) - classic_param_grads = jax.vmap( - lambda ps: jax.grad( - lambda cps: scaled_loss( - rng_key, - self.constrain_fn(cps), - self.constrain_fn(unravel_pytree(ps)), - ) - )(classic_uparams) - )(stein_particles) - classic_param_grads = tree_map(partial(jnp.mean, axis=0), classic_param_grads) + particle_ljp_grads = kernel_particles_loss_fn(attractive_key, ctstein_particles) + + classic_param_grads = grad( + lambda cps: -self.stein_loss.loss( + classic_key, + self.constrain_fn(cps), + handlers.scale(self._inference_model, self.loss_temperature), + self.guide, + unravel_pytree_batched(ctstein_particles), + *args, + **kwargs, + ) + )(model_uparams) # 3. Calculate kernel on monolithic particle - kernel = self.kernel_fn.compute( - stein_particles, particle_info, kernel_particle_loss_fn + kernel = self.kernel_fn.compute( # TODO: Fix to use Stein loss + stein_particles, particle_info, kernel_particles_loss_fn ) # 4. Calculate the attractive force and repulsive force on the monolithic particles - attractive_force = jax.vmap( + attractive_force = vmap( lambda y: jnp.sum( - jax.vmap( + vmap( lambda x, x_ljp_grad: self._apply_kernel(kernel, x, y, x_ljp_grad) )(tstein_particles, particle_ljp_grads), axis=0, ) )(tstein_particles) - repulsive_force = jax.vmap( + repulsive_force = vmap( lambda y: jnp.sum( - jax.vmap( + vmap( lambda x: self.repulsion_temperature * self._kernel_grad(kernel, x, y) )(tstein_particles), @@ -261,7 +283,7 @@ def single_particle_grad(particle, attr_forces, rep_forces): def _nontrivial_jac(var_name, var): if isinstance(self.particle_transforms[var_name], IdentityTransform): return None - return jax.jacfwd(self.particle_transforms[var_name].inv)(var) + return jacfwd(self.particle_transforms[var_name].inv)(var) def _update_force(attr_force, rep_force, jac): force = attr_force.reshape(-1) + rep_force.reshape(-1) @@ -285,7 +307,7 @@ def _update_force(attr_force, rep_force, jac): return jac_particle particle_grads = ( - jax.vmap(single_particle_grad)( + vmap(single_particle_grad)( stein_particles, attractive_force, repulsive_force ) / self.num_particles @@ -296,7 +318,7 @@ def _update_force(attr_force, rep_force, jac): # 6. Return loss and gradients (based on parameter forces) res_grads = tree_map(lambda x: -x, {**classic_param_grads, **stein_param_grads}) - return -jnp.mean(loss), res_grads + return jnp.linalg.norm(particle_grads), res_grads def init(self, rng_key, *args, **kwargs): """ @@ -307,7 +329,7 @@ def init(self, rng_key, *args, **kwargs): during the course of fitting). :return: initial :data:`SteinVIState` """ - rng_key, kernel_seed, model_seed, guide_seed = jax.random.split(rng_key, 4) + rng_key, kernel_seed, model_seed, guide_seed = random.split(rng_key, 4) model_init = handlers.seed(self.model, model_seed) guide_init = handlers.seed(self.guide, guide_seed) guide_trace = handlers.trace(guide_init).get_trace( @@ -316,7 +338,7 @@ def init(self, rng_key, *args, **kwargs): model_trace = handlers.trace(model_init).get_trace( *args, **kwargs, **self.static_kwargs ) - rng_key, particle_seed = jax.random.split(rng_key) + rng_key, particle_seed = random.split(rng_key) guide_init_params = self._find_init_params( particle_seed, self.guide, guide_trace ) @@ -351,7 +373,7 @@ def init(self, rng_key, *args, **kwargs): ) if site["name"] in guide_init_params: pval, _ = guide_init_params[site["name"]] - if self.classic_guide_params_fn(site["name"]): + if self.model_params_fn(site["name"]): pval = tree_map(lambda x: x[0], pval) else: pval = site["value"] @@ -398,7 +420,7 @@ def update(self, state: SteinVIState, *args, **kwargs): during the course of fitting). :return: tuple of `(state, loss)`. """ - rng_key, rng_key_mcmc, rng_key_step = jax.random.split(state.rng_key, num=3) + rng_key, rng_key_mcmc, rng_key_step = random.split(state.rng_key, num=3) params = self.optim.get_params(state.optim_state) optim_state = state.optim_state loss_val, grads = self._svgd_loss_and_grads( @@ -434,6 +456,9 @@ def bodyfn(_i, info): progbar=progress_bar, transform=collect_fn, return_last_val=True, + diagnostics_fn=lambda state: f"norm Stein force: {state[1]:.3f}" + if progress_bar + else None, ) state = last_res[0] return SteinVIRunResult(self.get_params(state), state, auxiliaries) @@ -445,12 +470,12 @@ def evaluate(self, state, *args, **kwargs): :param args: arguments to the model / guide (these can possibly vary during the course of fitting). :param kwargs: keyword arguments to the model / guide. - :return: evaluate loss given the current parameter values (held within `state.optim_state`). + :return: normed stein force given the current parameter values (held within `state.optim_state`). """ # we split to have the same seed as `update_fn` given a state - _, _, rng_key_eval = jax.random.split(state.rng_key, num=3) + _, _, rng_key_eval = random.split(state.rng_key, num=3) params = self.optim.get_params(state.optim_state) - loss_val, _ = self._svgd_loss_and_grads( + normed_stein_force, _ = self._svgd_loss_and_grads( rng_key_eval, params, *args, **kwargs, **self.static_kwargs ) - return loss_val + return normed_stein_force diff --git a/test/contrib/einstein/test_mixture_guide_predictive.py b/test/contrib/einstein/test_mixture_guide_predictive.py new file mode 100644 index 000000000..e14a79631 --- /dev/null +++ b/test/contrib/einstein/test_mixture_guide_predictive.py @@ -0,0 +1,57 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + +from numpy.testing import assert_allclose + +from jax import random +import jax.numpy as jnp + +import numpyro +from numpyro.contrib.einstein.mixture_guide_predictive import MixtureGuidePredictive +import numpyro.distributions as dist +from numpyro.distributions import constraints + + +def test_predictive_with_particles(): + num_samples = 20 + fdim = 3 + num_data = 10 + mixture_assignment_sitename = "assigns" + num_particles = 3 + + def model(x, y=None): + latent = numpyro.sample("latent", dist.Normal(0.0, jnp.ones(fdim)).to_event(1)) + with numpyro.plate("data", x.shape[0]): + numpyro.sample("y", dist.Normal(x * latent, 1.0).to_event(1), obs=y) + + def guide(x, y=None): + latent_loc = numpyro.param( + "latent_loc", jnp.ones(fdim), constraint=constraints.real + ) + assert latent_loc.ndim == 1 + numpyro.sample("latent", dist.Normal(latent_loc, 0.1).to_event(1)) + + params = jnp.array([[-100, -100, -100.0], [0, 0, 0], [100, 100, 100]]) + x = dist.Normal(jnp.full(fdim, 10), 1.0).sample(random.PRNGKey(0), (num_data,)) + + predictions = MixtureGuidePredictive( + model, + guide=guide, + params={"latent_loc": params}, + num_samples=num_samples, + guide_sites=["latent_loc"], + mixture_assignment_sitename=mixture_assignment_sitename, + )(random.PRNGKey(0), x) + assert predictions["y"].shape == (num_samples, num_data, fdim) + assert mixture_assignment_sitename in predictions + assert jnp.max(predictions[mixture_assignment_sitename]) <= num_particles - 1 + assert 0 <= jnp.min(predictions[mixture_assignment_sitename]) + + # Check we can recover assignments from predictions + pred_assigns = jnp.argmin( + jnp.linalg.norm(predictions["y"][:, :, None] - params, axis=-1), axis=-1 + ) + actual_assigns = jnp.repeat( + predictions[mixture_assignment_sitename][:, None], num_data, 1 + ) + assert_allclose(pred_assigns, actual_assigns) diff --git a/test/contrib/einstein/test_einstein_kernels.py b/test/contrib/einstein/test_stein_kernels.py similarity index 83% rename from test/contrib/einstein/test_einstein_kernels.py rename to test/contrib/einstein/test_stein_kernels.py index 94008783a..d87af2e5b 100644 --- a/test/contrib/einstein/test_einstein_kernels.py +++ b/test/contrib/einstein/test_stein_kernels.py @@ -11,17 +11,14 @@ from jax import numpy as jnp, random from numpyro.contrib.einstein import SteinVI -from numpyro.contrib.einstein.kernels import ( +from numpyro.contrib.einstein.stein_kernels import ( GraphicalKernel, - HessianPrecondMatrix, IMQKernel, LinearKernel, MixtureKernel, - PrecondMatrixKernel, RandomFeatureKernel, RBFKernel, ) -from numpyro.infer import Trace_ELBO from numpyro.optim import Adam T = namedtuple("TestSteinKernel", ["kernel", "particle_info", "loss_fn", "kval"]) @@ -66,18 +63,6 @@ lambda x: x, {"matrix": np.array([[0.040711474, 0.0], [0.0, 0.040711474]])}, ), - T( - lambda mode: PrecondMatrixKernel( - HessianPrecondMatrix(), RBFKernel(mode="matrix"), precond_mode="const" - ), - lambda d: {}, - lambda x: -0.02 / 12 * x[0] ** 4 - 0.5 / 12 * x[1] ** 4 - x[0] * x[1], - { - "matrix": np.array( - [[2.3780507e-04, -1.6688075e-05], [-1.6688075e-05, 1.2849815e-05]] - ) - }, - ), # -hess = [[.02x_0^2 1] [1 .5x_1^2]] ] PARTICLES = [(PARTICLES_2D, TPARTICLES_2D)] @@ -118,7 +103,7 @@ def test_apply_kernel( kernel_fn.init(random.PRNGKey(0), particles.shape) kernel_fn = kernel_fn.compute(particles, particle_info(d), loss_fn) v = np.ones_like(kval[mode]) - stein = SteinVI(id, id, Adam(1.0), Trace_ELBO(), kernel(mode)) + stein = SteinVI(id, id, Adam(1.0), kernel(mode)) value = stein._apply_kernel(kernel_fn, *tparticles, v) kval_ = copy(kval) if mode == "matrix": diff --git a/test/contrib/einstein/test_stein_loss.py b/test/contrib/einstein/test_stein_loss.py new file mode 100644 index 000000000..2f220c752 --- /dev/null +++ b/test/contrib/einstein/test_stein_loss.py @@ -0,0 +1,97 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + +from numpy.testing import assert_allclose +from pytest import fail + +from jax import numpy as jnp, random, value_and_grad + +import numpyro +from numpyro.contrib.einstein.stein_loss import SteinLoss +from numpyro.contrib.einstein.stein_util import batch_ravel_pytree +import numpyro.distributions as dist +from numpyro.infer import Trace_ELBO + + +def test_single_particle_loss(): + def model(x): + numpyro.sample("obs", dist.Normal(0, 1), obs=x) + + def guide(x): + pass + + try: + SteinLoss(elbo_num_particles=10, stein_num_particles=1).loss( + random.PRNGKey(0), {}, model, guide, {}, 2.0 + ) + fail() + except ValueError: + pass + + +def test_stein_elbo(): + def model(x): + numpyro.sample("x", dist.Normal(0, 1)) + numpyro.sample("obs", dist.Normal(0, 1), obs=x) + + def guide(x): + numpyro.sample("x", dist.Normal(0, 1)) + + def elbo_loss_fn(x, param): + return Trace_ELBO(num_particles=1).loss( + random.PRNGKey(0), param, model, guide, x + ) + + def stein_loss_fn(x, particles): + return SteinLoss(elbo_num_particles=1, stein_num_particles=1).loss( + random.PRNGKey(0), {}, model, guide, particles, x + ) + + elbo_loss, elbo_grad = value_and_grad(elbo_loss_fn)(2.0, {"x": 1.0}) + stein_loss, stein_grad = value_and_grad(stein_loss_fn)(2.0, {"x": jnp.array([1.0])}) + assert_allclose(elbo_loss, stein_loss, rtol=1e-6) + assert_allclose(elbo_grad, stein_grad, rtol=1e-6) + + +def test_stein_particle_loss(): + def model(x): + numpyro.sample("x", dist.Normal(0, 1)) + numpyro.sample("obs", dist.Normal(0, 1), obs=x) + + def guide(x): + numpyro.sample("x", dist.Normal(0, 1)) + + def stein_loss_fn(x, particles, chosen_particle, assign): + return SteinLoss( + elbo_num_particles=1, stein_num_particles=3 + ).single_particle_loss( + random.PRNGKey(0), + model, + guide, + chosen_particle, + unravel_pytree, + particles, + assign, + (x,), + {}, + {}, + ) + + xs = jnp.array([-1, 0.5, 3.0]) + num_particles = xs.shape[0] + particles = {"x": xs} + + flat_particles, unravel_pytree, _ = batch_ravel_pytree(particles, nbatch_dims=1) + losses, grads = [], [] + for i in range(num_particles): + chosen_particle = unravel_pytree(flat_particles[i]) + loss, grad = value_and_grad(stein_loss_fn)( + 2.0, flat_particles, chosen_particle, i + ) + losses.append(loss) + grads.append(grad) + + assert jnp.abs(losses[0] - losses[1]) > 0.1 + assert jnp.abs(losses[1] - losses[2]) > 0.1 + assert_allclose(grads[0], grads[1]) + assert_allclose(grads[1], grads[2]) diff --git a/test/contrib/einstein/test_steinvi.py b/test/contrib/einstein/test_steinvi.py index 169d893b7..1c0214fd2 100644 --- a/test/contrib/einstein/test_steinvi.py +++ b/test/contrib/einstein/test_steinvi.py @@ -13,12 +13,15 @@ import numpyro from numpyro import handlers -from numpyro.contrib.einstein import GraphicalKernel, RBFKernel, SteinVI, kernels -from numpyro.contrib.einstein.kernels import ( - HessianPrecondMatrix, - MixtureKernel, - PrecondMatrixKernel, +from numpyro.contrib.einstein import ( + GraphicalKernel, + IMQKernel, + LinearKernel, + RandomFeatureKernel, + RBFKernel, + SteinVI, ) +from numpyro.contrib.einstein.stein_kernels import MixtureKernel import numpyro.distributions as dist from numpyro.distributions import Bernoulli, Normal, Poisson from numpyro.distributions.transforms import AffineTransform @@ -41,11 +44,11 @@ from numpyro.optim import Adam KERNELS = [ - kernels.RBFKernel(), - kernels.LinearKernel(), - kernels.IMQKernel(), - kernels.GraphicalKernel(), - kernels.RandomFeatureKernel(), + RBFKernel(), + LinearKernel(), + IMQKernel(), + GraphicalKernel(), + RandomFeatureKernel(), ] np.set_printoptions(precision=100) @@ -57,13 +60,6 @@ def __init__(self, mode): super().__init__(mode=mode, local_kernel_fns={"p1": RBFKernel("norm")}) -class WrappedPrecondMatrixKernel(PrecondMatrixKernel): - def __init__(self, mode): - super().__init__( - HessianPrecondMatrix(), RBFKernel(mode=mode), precond_mode="const" - ) - - class WrappedMixtureKernel(MixtureKernel): def __init__(self, mode): super().__init__( @@ -130,7 +126,6 @@ def test_steinvi_smoke(kernel, auto_guide, init_loc_fn, problem): model, auto_guide(model, init_loc_fn=init_loc_fn), Adam(1e-1), - Trace_ELBO(), kernel, ) stein.run(random.PRNGKey(0), 1, *data) @@ -156,7 +151,7 @@ def test_get_params(kernel, auto_guide, init_loc_fn, problem): Trace_ELBO(), ) - stein = SteinVI(model, guide, optim, elbo, kernel) + stein = SteinVI(model, guide, optim, kernel) stein_params = stein.get_params(stein.init(random.PRNGKey(0), *data)) svi = SVI(model, guide, optim, elbo) @@ -211,9 +206,8 @@ def model(obs): model, auto_class(model, init_loc_fn=init_loc_fn()), Adam(1.0), - Trace_ELBO(), RBFKernel(), - num_particles=num_particles, + num_stein_particles=num_particles, ) state = steinvi.init(stein_key, obs) init_params = steinvi.get_params(state) @@ -234,35 +228,6 @@ def model(obs): assert_array_approx_equal(init_value, np.full(expected_shape, 0.0)) -def test_svgd_loss_and_grads(): - true_coefs, data, model = uniform_normal() - guide = AutoDelta(model) - loss = Trace_ELBO() - stein_uparams = { - "alpha_auto_loc": np.array( - [ - -1.2, - ] - ), - "loc_base_auto_loc": np.array( - [ - 1.53, - ] - ), - } - stein = SteinVI(model, guide, Adam(0.1), loss, RBFKernel()) - stein.init(random.PRNGKey(0), *data) - svi = SVI(model, guide, Adam(0.1), loss) - svi.init(random.PRNGKey(0), *data) - expected_loss = loss.loss( - random.PRNGKey(1), svi.constrain_fn(stein_uparams), model, guide, *data - ) - stein_loss, stein_grad = stein._svgd_loss_and_grads( - random.PRNGKey(1), stein_uparams, *data - ) - assert expected_loss == stein_loss - - @pytest.mark.parametrize("length", [1, 2, 3, 6]) @pytest.mark.parametrize("depth", [1, 3, 5]) @pytest.mark.parametrize("t", [list, tuple]) # add dict, set @@ -276,7 +241,7 @@ def nest(v, d): sizes = Poisson(5).sample(seed, (length, nrandom.randint(0, 10))) + 1 total_size = sum(map(lambda size: size.prod(), sizes)) uparam = t(nest(np.empty(tuple(size)), nrandom.randint(0, depth)) for size in sizes) - stein = SteinVI(id, id, Adam(1.0), Trace_ELBO(), RBFKernel()) + stein = SteinVI(id, id, Adam(1.0), RBFKernel()) assert stein._param_size(uparam) == total_size, f"Failed for seed {seed}" @@ -316,7 +281,7 @@ def test_calc_particle_info_nested(): for i in range(num_params) } - stein = SteinVI(id, id, Adam(1.0), Trace_ELBO(), RBFKernel()) + stein = SteinVI(id, id, Adam(1.0), RBFKernel()) pinfo, _ = stein._calc_particle_info(uparams, num_particles) start = 0 tot_size = sum(map(lambda size: size.prod(), sizes)) // num_particles diff --git a/test/contrib/einstein/test_einstein_util.py b/test/contrib/einstein/test_steinvi_util.py similarity index 73% rename from test/contrib/einstein/test_einstein_util.py rename to test/contrib/einstein/test_steinvi_util.py index bb38993dc..532f7e18c 100644 --- a/test/contrib/einstein/test_einstein_util.py +++ b/test/contrib/einstein/test_steinvi_util.py @@ -11,12 +11,7 @@ from jax import numpy as jnp from jax.tree_util import tree_flatten, tree_map -from numpyro.contrib.einstein.util import ( - batch_ravel_pytree, - posdef, - sqrth, - sqrth_and_inv_sqrth, -) +from numpyro.contrib.einstein.stein_util import batch_ravel_pytree, posdef, sqrth pd_matrices = [ np.array( @@ -68,25 +63,6 @@ def test_sqrth_shape(batch_shape): assert_allclose(s @ np.swapaxes(s, -2, -1), m, rtol=1e-5) -@pytest.mark.parametrize("m", pd_matrices) -def test_sqrt_inv_sqrth(m): - msqrt, minv, minv_sqrt = sqrth_and_inv_sqrth(m) - assert_allclose(msqrt, scipy.linalg.sqrtm(m), atol=1e-5) - assert_allclose(minv, np.linalg.inv(m), atol=1e-4) - assert_allclose(minv_sqrt, np.linalg.inv(scipy.linalg.sqrtm(m)), atol=1e-5) - - -@pytest.mark.parametrize("batch_shape", [(), (2,), (3, 1)]) -def test_sqrth_and_inv_sqrth_shape(batch_shape): - dim = 4 - x = np.random.normal(size=batch_shape + (dim, dim + 1)) - m = x @ np.swapaxes(x, -2, -1) - s, i, si = sqrth_and_inv_sqrth(m) - assert_allclose(s @ np.swapaxes(s, -2, -1), m, rtol=1e-5) - assert_allclose(i, np.linalg.inv(m), rtol=1e-5) - assert_allclose(si @ np.swapaxes(si, -2, -1), i, rtol=1e-5) - - @pytest.mark.parametrize( "pytree", [ diff --git a/test/test_examples.py b/test/test_examples.py index 388793bc8..ebb034af9 100644 --- a/test/test_examples.py +++ b/test/test_examples.py @@ -65,8 +65,8 @@ ), "sparse_regression.py --num-samples 10 --num-warmup 10 --num-data 10 --num-dimensions 10", "ssbvm_mixture.py --num-samples 10 --num-warmup 10", - "stein_bnn.py --max-iter 10 --subsample-size 10 --num-particles 5", - "stein_dmm.py --max-iter 5 --subsample-size 77 --gru-dim 10", + "stein_dmm.py --max-iter 5 --subsample-size 5 --gru-dim 10", + "stein_bnn.py --max-iter 10 --subsample-size 10", "stochastic_volatility.py --num-samples 100 --num-warmup 100", "toy_mixture_model_discrete_enumeration.py --num-steps=1", "ucbadmit.py --num-chains 2",