Skip to content

Commit

Permalink
Stein mixture (#1601)
Browse files Browse the repository at this point in the history
* rewrote stein mixture and added mixture guide predictive.

* added test examples

* added mixture guide predictive test.

* fixed sample ordering in `mixture_guide_predictive.py` so model samples are returned when sites are in guide and model.
  • Loading branch information
OlaRonning authored Jun 10, 2023
1 parent 523162f commit 0b9e0f0
Show file tree
Hide file tree
Showing 15 changed files with 609 additions and 334 deletions.
21 changes: 10 additions & 11 deletions docs/source/contrib.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,11 @@ SVGD is well suited for capturing correlations between latent variables as a par
The technique preserves the scalability of traditional VI approaches while offering the flexibility and modeling scope
of methods such as Markov chain Monte Carlo (MCMC). SVGD is good at capturing multi-modality [3][4].

``numpyro.contrib.einstein`` is a framework for particle-based inference using the ELBO-within-Stein algorithm.
``numpyro.contrib.einstein`` is a framework for particle-based inference using the Stein mixture algorithm.
The framework works on Stein mixtures, a restricted mixture of guide programs parameterized by Stein particles.
Similarly to how SVGD works, Stein mixtures can approximate model posteriors by moving the Stein particles according
to the Stein forces. Because the Stein particles parameterize a guide, they capture a neighborhood rather than a
single point. This property means Stein mixtures significantly reduce the number of particles needed to represent
high dimensional models.
single point.

``numpyro.contrib.einstein`` mimics the interface from ``numpyro.infer.svi``, so trying SteinVI requires minimal
change to the code for existing models inferred with SVI. For primary usage, see the
Expand All @@ -40,9 +39,8 @@ The framework currently supports several kernels, including:
- `LinearKernel`
- `RandomFeatureKernel`
- `MixtureKernel`
- `PrecondMatrixKernel`
- `HessianPrecondMatrix`
- `GraphicalKernel`
- `ProbabilityProductKernel`

For example usage, see:

Expand All @@ -68,10 +66,11 @@ SteinVI Interface

SteinVI Kernels
---------------
.. autoclass:: numpyro.contrib.einstein.kernels.RBFKernel
.. autoclass:: numpyro.contrib.einstein.kernels.LinearKernel
.. autoclass:: numpyro.contrib.einstein.kernels.RandomFeatureKernel
.. autoclass:: numpyro.contrib.einstein.kernels.MixtureKernel
.. autoclass:: numpyro.contrib.einstein.kernels.PrecondMatrixKernel
.. autoclass:: numpyro.contrib.einstein.kernels.GraphicalKernel
.. autoclass:: numpyro.contrib.einstein.stein_kernels.RBFKernel
.. autoclass:: numpyro.contrib.einstein.stein_kernels.LinearKernel
.. autoclass:: numpyro.contrib.einstein.stein_kernels.RandomFeatureKernel
.. autoclass:: numpyro.contrib.einstein.stein_kernels.MixtureKernel
.. autoclass:: numpyro.contrib.einstein.stein_kernels.GraphicalKernel
.. autoclass:: numpyro.contrib.einstein.stein_kernels.ProbabilityProductKernel


44 changes: 25 additions & 19 deletions examples/stein_bnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,13 @@
import jax.numpy as jnp

import numpyro
from numpyro.contrib.einstein import RBFKernel, SteinVI
from numpyro import deterministic
from numpyro.contrib.einstein import IMQKernel, SteinVI
from numpyro.contrib.einstein.mixture_guide_predictive import MixtureGuidePredictive
from numpyro.distributions import Gamma, Normal
from numpyro.examples.datasets import BOSTON_HOUSING, load_dataset
from numpyro.infer import Predictive, Trace_ELBO, init_to_uniform
from numpyro.infer.autoguide import AutoDelta
from numpyro.infer import init_to_uniform
from numpyro.infer.autoguide import AutoNormal
from numpyro.optim import Adagrad

DataState = namedtuple("data", ["xtr", "xte", "ytr", "yte"])
Expand All @@ -36,7 +38,7 @@
def load_data() -> DataState:
_, fetch = load_dataset(BOSTON_HOUSING, shuffle=False)
x, y = fetch()
xtr, xte, ytr, yte = train_test_split(x, y, train_size=0.90)
xtr, xte, ytr, yte = train_test_split(x, y, train_size=0.90, random_state=1)

return DataState(*map(partial(jnp.array, dtype=float), (xtr, xte, ytr, yte)))

Expand Down Expand Up @@ -105,10 +107,12 @@ def model(x, y=None, hidden_dim=50, subsample_size=100):
else:
batch_y = y

loc_y = deterministic("y_pred", jnp.maximum(batch_x @ w1 + b1, 0) @ w2 + b2)

numpyro.sample(
"y",
Normal(
jnp.maximum(batch_x @ w1 + b1, 0) @ w2 + b2, 1.0 / jnp.sqrt(prec_obs)
loc_y, 1.0 / jnp.sqrt(prec_obs)
), # 1 hidden layer with ReLU activation
obs=batch_y,
)
Expand All @@ -124,14 +128,17 @@ def main(args):

rng_key, inf_key = random.split(inf_key)

guide = AutoNormal(model, init_loc_fn=partial(init_to_uniform, radius=0.1))

stein = SteinVI(
model,
AutoDelta(model, init_loc_fn=partial(init_to_uniform, radius=0.1)),
guide,
Adagrad(0.05),
Trace_ELBO(20), # estimate elbo with 20 particles (not stein particles!)
RBFKernel(),
IMQKernel(),
# ProbabilityProductKernel(guide=guide, scale=1.),
repulsion_temperature=args.repulsion,
num_particles=args.num_particles,
num_stein_particles=args.num_stein_particles,
num_elbo_particles=args.num_elbo_particles,
)
start = time()

Expand All @@ -147,33 +154,31 @@ def main(args):
)
time_taken = time() - start

pred = Predictive(
pred = MixtureGuidePredictive(
model,
guide=stein.guide,
params=stein.get_params(result.state),
num_samples=200,
batch_ndims=1, # stein particle dimension
num_samples=100,
guide_sites=stein.guide_param_names,
)
xte, _, _ = normalize(
data.xte, xtr_mean, xtr_std
) # use train data statistics when assessing generalization
preds = pred(
pred_key, xte, subsample_size=xte.shape[0], hidden_dim=args.hidden_dim
)["y"]
)["y_pred"]

y_pred = jnp.mean(preds, 1) * ytr_std + ytr_mean
y_pred = preds * ytr_std + ytr_mean
rmse = jnp.sqrt(jnp.mean((y_pred.mean(0) - data.yte) ** 2))

print(rf"Time taken: {datetime.timedelta(seconds=int(time_taken))}")
print(rf"RMSE: {rmse:.2f}")

# compute mean prediction and confidence interval around median
mean_prediction = jnp.mean(y_pred, 0)
mean_prediction = y_pred.mean(0)

ran = np.arange(mean_prediction.shape[0])
percentiles = np.percentile(
preds.reshape(-1, xte.shape[0]) * ytr_std + ytr_mean, [5.0, 95.0], axis=0
)
percentiles = np.percentile(preds * ytr_std + ytr_mean, [5.0, 95.0], axis=0)

# make plots
fig, ax = plt.subplots(figsize=(8, 6), constrained_layout=True)
Expand All @@ -199,7 +204,8 @@ def main(args):
parser.add_argument("--max-iter", type=int, default=1000)
parser.add_argument("--repulsion", type=float, default=1.0)
parser.add_argument("--verbose", type=bool, default=True)
parser.add_argument("--num-particles", type=int, default=100)
parser.add_argument("--num-elbo-particles", type=int, default=50)
parser.add_argument("--num-stein-particles", type=int, default=5)
parser.add_argument("--progress-bar", type=bool, default=True)
parser.add_argument("--rng-key", type=int, default=142)
parser.add_argument("--device", default="cpu", choices=["gpu", "cpu"])
Expand Down
24 changes: 13 additions & 11 deletions examples/stein_dmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,10 @@

import numpyro
from numpyro.contrib.einstein import SteinVI
from numpyro.contrib.einstein.kernels import RBFKernel
from numpyro.contrib.einstein.mixture_guide_predictive import MixtureGuidePredictive
from numpyro.contrib.einstein.stein_kernels import RBFKernel
import numpyro.distributions as dist
from numpyro.examples.datasets import JSB_CHORALES, load_dataset
from numpyro.infer import Predictive, Trace_ELBO
from numpyro.optim import optax_to_numpyro


Expand Down Expand Up @@ -293,17 +293,17 @@ def vis_tune(i, tunes, lengths, name="stein_dmm.pdf"):
def main(args):
inf_key, pred_key = random.split(random.PRNGKey(seed=args.rng_seed), 2)

vi = SteinVI(
steinvi = SteinVI(
model,
guide,
optax_to_numpyro(chain(adam(1e-2))),
Trace_ELBO(),
RBFKernel(),
num_particles=args.num_particles,
num_elbo_particles=args.num_elbo_particles,
num_stein_particles=args.num_stein_particles,
)

seqs, rev_seqs, lengths = load_data()
results = vi.run(
results = steinvi.run(
inf_key,
args.max_iter,
seqs,
Expand All @@ -312,11 +312,12 @@ def main(args):
gru_dim=args.gru_dim,
subsample_size=args.subsample_size,
)
pred = Predictive(
pred = MixtureGuidePredictive(
model,
guide,
params=results.params,
num_samples=1,
batch_ndims=1,
guide_sites=steinvi.guide_param_names,
)
seqs, rev_seqs, lengths = load_data("valid")
pred_notes = pred(
Expand All @@ -328,11 +329,12 @@ def main(args):

if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--subsample-size", type=int, default=77)
parser.add_argument("--max-iter", type=int, default=1000)
parser.add_argument("--subsample-size", type=int, default=10)
parser.add_argument("--max-iter", type=int, default=100)
parser.add_argument("--repulsion", type=float, default=1.0)
parser.add_argument("--verbose", type=bool, default=True)
parser.add_argument("--num-particles", type=int, default=5)
parser.add_argument("--num-stein-particles", type=int, default=5)
parser.add_argument("--num-elbo-particles", type=int, default=5)
parser.add_argument("--progress-bar", type=bool, default=True)
parser.add_argument("--gru-dim", type=int, default=150)
parser.add_argument("--rng-key", type=int, default=142)
Expand Down
14 changes: 7 additions & 7 deletions numpyro/contrib/einstein/__init__.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,26 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

from numpyro.contrib.einstein.kernels import (
from numpyro.contrib.einstein.stein_kernels import (
GraphicalKernel,
HessianPrecondMatrix,
IMQKernel,
LinearKernel,
PrecondMatrix,
PrecondMatrixKernel,
MixtureKernel,
ProbabilityProductKernel,
RandomFeatureKernel,
RBFKernel,
)
from numpyro.contrib.einstein.stein_loss import SteinLoss
from numpyro.contrib.einstein.steinvi import SteinVI

__all__ = [
"SteinVI",
"SteinLoss",
"RBFKernel",
"PrecondMatrix",
"IMQKernel",
"LinearKernel",
"RandomFeatureKernel",
"HessianPrecondMatrix",
"GraphicalKernel",
"PrecondMatrixKernel",
"MixtureKernel",
"ProbabilityProductKernel",
]
108 changes: 108 additions & 0 deletions numpyro/contrib/einstein/mixture_guide_predictive.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

from functools import partial
from typing import Callable, Dict, Optional, Sequence

from jax import numpy as jnp, random, tree_map, vmap
from jax.tree_util import tree_flatten

from numpyro.handlers import substitute
from numpyro.infer import Predictive
from numpyro.infer.util import _predictive


class MixtureGuidePredictive:
    """
    Predictive sampler for a model whose guide is a mixture of guide programs.

    The guide parameters (the entries of ``params`` named in ``guide_sites``)
    are mapped over along their leading axis, one slice per mixture component.
    On each call the guide is sampled under every component, a component index
    is drawn uniformly at random for each of the ``num_samples`` draws, and the
    guide samples selected that way are passed to the model as posterior samples.

    For single mixture component use numpyro.infer.Predictive.

    :param model: model program, forwarded to ``numpyro.infer.Predictive``.
    :param guide: guide program; its parameters are substituted per component.
    :param params: mapping from parameter name to value. Names in
        ``guide_sites`` are treated as guide parameters; all remaining entries
        are forwarded to ``Predictive`` as ``params``.
    :param guide_sites: names of the parameters in ``params`` that belong to
        the guide.
    :param num_samples: number of predictive draws. Must be given: the
        ``None`` default cannot be used because the draw count sizes the
        mixture assignments.
    :param return_sites: forwarded to ``Predictive``; also selects which guide
        sites are included in the result.
    :param infer_discrete: forwarded to ``Predictive``.
    :param parallel: forwarded to ``Predictive`` and to the guide sampling.
    :param mixture_assignment_sitename: key under which the drawn component
        indices are returned.
    :raises ValueError: if ``num_samples`` is ``None`` or ``params`` is empty.
    """

    def __init__(
        self,
        model: Callable,
        guide: Callable,
        params: Dict,
        guide_sites: Sequence,
        num_samples: Optional[int] = None,
        return_sites: Optional[Sequence[str]] = None,
        infer_discrete: bool = False,
        parallel: bool = False,
        mixture_assignment_sitename="mixture_assignments",
    ):
        # Fail fast: with num_samples=None the batch shape would be (None,),
        # which only surfaces later as an obscure error inside __call__.
        if num_samples is None:
            raise ValueError(
                "MixtureGuidePredictive requires `num_samples` to be an integer."
            )

        guide_params = {
            name: param for name, param in params.items() if name in guide_sites
        }
        model_params = {
            name: param for name, param in params.items() if name not in guide_sites
        }

        # Predictive is only instantiated in __call__, once the posterior
        # samples (the assigned guide samples) are available.
        self.model_predictive = partial(
            Predictive,
            model=model,
            params=model_params,
            num_samples=num_samples,
            return_sites=return_sites,
            infer_discrete=infer_discrete,
            parallel=parallel,
        )
        self._batch_shape = (num_samples,)
        self.parallel = parallel
        self.guide_params = guide_params

        self.guide = guide
        self.return_sites = return_sites
        self.num_mixture_components = self._count_mixture_components(
            guide_params, params
        )
        self.mixture_assignment_sitename = mixture_assignment_sitename

    @staticmethod
    def _count_mixture_components(guide_params, params):
        """Return the leading-axis size that indexes the mixture components.

        The size is read from the guide parameters because those are the
        arrays mapped over in ``__call__``. Reading the first leaf of all
        ``params`` could pick up a non-guide parameter instead, since pytree
        flattening orders dict entries by key. When there are no guide
        parameters the first leaf of ``params`` is used as a fallback.
        """
        leaves = tree_flatten(guide_params)[0] or tree_flatten(params)[0]
        if not leaves:
            raise ValueError(
                "`params` must contain at least one parameter to infer the "
                "number of mixture components."
            )
        return jnp.shape(leaves[0])[0]

    def _call_with_params(self, rng_key, params, args, kwargs):
        """Sample every site of the guide with ``params`` substituted in."""
        rng_key, guide_rng_key = random.split(rng_key)
        guide = substitute(self.guide, params)
        samples = _predictive(
            guide_rng_key,
            guide,
            {},
            self._batch_shape,
            # use return_sites='' as a special signal to return all sites
            return_sites="",
            parallel=self.parallel,
            model_args=args,
            model_kwargs=kwargs,
        )
        return samples

    def __call__(self, rng_key, *args, **kwargs):
        """Draw predictive samples.

        :param rng_key: PRNG key; split into separate keys for guide sampling,
            mixture assignment and model sampling.
        :param args: positional arguments passed to both guide and model.
        :param kwargs: keyword arguments passed to both guide and model.
        :return: dict holding the mixture assignments (under
            ``mixture_assignment_sitename``), the guide sites listed in
            ``return_sites`` (none when ``return_sites`` is ``None``) and the
            samples returned by the model's ``Predictive``.
        """
        guide_key, assign_key, model_key = random.split(rng_key, 3)

        # One guide run per mixture component; out_axes=1 places the component
        # axis second so arrays can be indexed as arr[sample, component].
        samples_guide = vmap(
            lambda key, params: self._call_with_params(
                key, params=params, args=args, kwargs=kwargs
            ),
            in_axes=0,
            out_axes=1,
        )(random.split(guide_key, self.num_mixture_components), self.guide_params)

        # Uniformly pick one component for each predictive draw.
        assigns = random.randint(
            assign_key,
            shape=self._batch_shape,
            minval=0,
            maxval=self.num_mixture_components,
        )
        # For draw i keep only the guide sample of its assigned component.
        predictive_assign = tree_map(
            lambda arr: vmap(lambda i, assign: arr[i, assign])(
                jnp.arange(self._batch_shape[0]), assigns
            ),
            samples_guide,
        )
        predictive_model = self.model_predictive(posterior_samples=predictive_assign)
        samples_model = predictive_model(model_key, *args, **kwargs)

        # NOTE(review): the guide sites returned here are taken from the
        # per-component samples, not from `predictive_assign`, so they appear
        # to keep the component axis - confirm this is intended.
        if self.return_sites is not None:
            samples_guide = {
                name: value
                for name, value in samples_guide.items()
                if name in self.return_sites
            }
        else:
            samples_guide = {}

        return {
            self.mixture_assignment_sitename: assigns,
            **samples_guide,
            **samples_model,  # use samples from model if site in model and guide
        }
Loading

0 comments on commit 0b9e0f0

Please sign in to comment.