Stein mixture #1601

Merged · 4 commits · Jun 10, 2023
21 changes: 10 additions & 11 deletions docs/source/contrib.rst
@@ -23,12 +23,11 @@ SVGD is well suited for capturing correlations between latent variables as a par
The technique preserves the scalability of traditional VI approaches while offering the flexibility and modeling scope
of methods such as Markov chain Monte Carlo (MCMC). SVGD is good at capturing multi-modality [3][4].

-``numpyro.contrib.einstein`` is a framework for particle-based inference using the ELBO-within-Stein algorithm.
+``numpyro.contrib.einstein`` is a framework for particle-based inference using the Stein mixture algorithm.
The framework works on Stein mixtures, a restricted mixture of guide programs parameterized by Stein particles.
Similarly to how SVGD works, Stein mixtures can approximate model posteriors by moving the Stein particles according
to the Stein forces. Because the Stein particles parameterize a guide, they capture a neighborhood rather than a
-single point. This property means Stein mixtures significantly reduce the number of particles needed to represent
-high dimensional models.
+single point.

``numpyro.contrib.einstein`` mimics the interface from ``numpyro.infer.svi``, so trying SteinVI requires minimal
change to the code for existing models inferred with SVI. For primary usage, see the
@@ -40,9 +39,8 @@ The framework currently supports several kernels, including:
- `LinearKernel`
- `RandomFeatureKernel`
- `MixtureKernel`
-- `PrecondMatrixKernel`
-- `HessianPrecondMatrix`
- `GraphicalKernel`
+- `ProbabilityProductKernel`

For example usage, see:

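(The linked examples are collapsed in this view. As a quick orientation, here is a minimal sketch of the post-PR interface, pieced together from the example diffs below; the toy regression model and data are illustrative, not taken from the PR.)

from jax import random
import jax.numpy as jnp

import numpyro
import numpyro.distributions as dist
from numpyro.contrib.einstein import RBFKernel, SteinVI
from numpyro.infer.autoguide import AutoNormal
from numpyro.optim import Adagrad

def model(x, y=None):
    # Toy Bayesian linear regression.
    w = numpyro.sample("w", dist.Normal(0.0, 1.0))
    b = numpyro.sample("b", dist.Normal(0.0, 1.0))
    with numpyro.plate("data", x.shape[0]):
        numpyro.sample("y", dist.Normal(w * x + b, 0.1), obs=y)

guide = AutoNormal(model)
stein = SteinVI(
    model,
    guide,
    Adagrad(0.05),
    RBFKernel(),
    num_stein_particles=5,   # mixture components moved by the Stein forces
    num_elbo_particles=10,   # Monte Carlo draws per loss evaluation
)

x = jnp.linspace(-1.0, 1.0, 100)
y = 2.0 * x + 0.3
result = stein.run(random.PRNGKey(0), 1000, x, y)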
@@ -68,10 +66,11 @@ SteinVI Interface

SteinVI Kernels
---------------
-.. autoclass:: numpyro.contrib.einstein.kernels.RBFKernel
-.. autoclass:: numpyro.contrib.einstein.kernels.LinearKernel
-.. autoclass:: numpyro.contrib.einstein.kernels.RandomFeatureKernel
-.. autoclass:: numpyro.contrib.einstein.kernels.MixtureKernel
-.. autoclass:: numpyro.contrib.einstein.kernels.PrecondMatrixKernel
-.. autoclass:: numpyro.contrib.einstein.kernels.GraphicalKernel
+.. autoclass:: numpyro.contrib.einstein.stein_kernels.RBFKernel
+.. autoclass:: numpyro.contrib.einstein.stein_kernels.LinearKernel
+.. autoclass:: numpyro.contrib.einstein.stein_kernels.RandomFeatureKernel
+.. autoclass:: numpyro.contrib.einstein.stein_kernels.MixtureKernel
+.. autoclass:: numpyro.contrib.einstein.stein_kernels.GraphicalKernel
+.. autoclass:: numpyro.contrib.einstein.stein_kernels.ProbabilityProductKernel


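Background note, not part of the diff: the "Stein forces" referenced in the docs above are, in plain SVGD (Liu and Wang, 2016, the docs' reference [3]), the empirical estimate of the optimal perturbation applied to each particle:

$$ z_i \leftarrow z_i + \epsilon\,\hat{\phi}^*(z_i), \qquad \hat{\phi}^*(z) = \frac{1}{n}\sum_{j=1}^{n}\Big[k(z_j, z)\,\nabla_{z_j}\log p(z_j) + \nabla_{z_j}k(z_j, z)\Big] $$

The first (attractive) term pulls particles toward high posterior density; the second (repulsive) term keeps them spread out. The `repulsion_temperature` argument seen in the examples below scales the repulsive term. In a Stein mixture, the same style of update moves particles that each parameterize a guide program rather than a single point.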
44 changes: 25 additions & 19 deletions examples/stein_bnn.py
@@ -23,11 +23,13 @@
import jax.numpy as jnp

import numpyro
-from numpyro.contrib.einstein import RBFKernel, SteinVI
+from numpyro import deterministic
+from numpyro.contrib.einstein import IMQKernel, SteinVI
+from numpyro.contrib.einstein.mixture_guide_predictive import MixtureGuidePredictive
from numpyro.distributions import Gamma, Normal
from numpyro.examples.datasets import BOSTON_HOUSING, load_dataset
-from numpyro.infer import Predictive, Trace_ELBO, init_to_uniform
-from numpyro.infer.autoguide import AutoDelta
+from numpyro.infer import init_to_uniform
+from numpyro.infer.autoguide import AutoNormal
from numpyro.optim import Adagrad

DataState = namedtuple("data", ["xtr", "xte", "ytr", "yte"])
@@ -36,7 +38,7 @@
def load_data() -> DataState:
_, fetch = load_dataset(BOSTON_HOUSING, shuffle=False)
x, y = fetch()
-xtr, xte, ytr, yte = train_test_split(x, y, train_size=0.90)
+xtr, xte, ytr, yte = train_test_split(x, y, train_size=0.90, random_state=1)

return DataState(*map(partial(jnp.array, dtype=float), (xtr, xte, ytr, yte)))

@@ -105,10 +107,12 @@ def model(x, y=None, hidden_dim=50, subsample_size=100):
else:
batch_y = y

+loc_y = deterministic("y_pred", jnp.maximum(batch_x @ w1 + b1, 0) @ w2 + b2)
+
numpyro.sample(
"y",
Normal(
-jnp.maximum(batch_x @ w1 + b1, 0) @ w2 + b2, 1.0 / jnp.sqrt(prec_obs)
+loc_y, 1.0 / jnp.sqrt(prec_obs)
), # 1 hidden layer with ReLU activation
obs=batch_y,
)
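The new deterministic site is what makes the predictive change further down work: recording the network output as "y_pred" lets the example read off the mean prediction directly, instead of averaging noisy draws of the observed site "y".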
@@ -124,14 +128,17 @@ def main(args):

rng_key, inf_key = random.split(inf_key)

+guide = AutoNormal(model, init_loc_fn=partial(init_to_uniform, radius=0.1))
+
stein = SteinVI(
model,
-AutoDelta(model, init_loc_fn=partial(init_to_uniform, radius=0.1)),
+guide,
Adagrad(0.05),
-Trace_ELBO(20),  # estimate elbo with 20 particles (not stein particles!)
-RBFKernel(),
+IMQKernel(),
+# ProbabilityProductKernel(guide=guide, scale=1.),
repulsion_temperature=args.repulsion,
-num_particles=args.num_particles,
+num_stein_particles=args.num_stein_particles,
+num_elbo_particles=args.num_elbo_particles,
)
start = time()

@@ -147,33 +154,31 @@ def main(args):
)
time_taken = time() - start

-pred = Predictive(
+pred = MixtureGuidePredictive(
model,
guide=stein.guide,
params=stein.get_params(result.state),
-num_samples=200,
-batch_ndims=1,  # stein particle dimension
+num_samples=100,
+guide_sites=stein.guide_param_names,
)
xte, _, _ = normalize(
data.xte, xtr_mean, xtr_std
) # use train data statistics when accessing generalization
preds = pred(
pred_key, xte, subsample_size=xte.shape[0], hidden_dim=args.hidden_dim
)["y"]
)["y_pred"]

-y_pred = jnp.mean(preds, 1) * ytr_std + ytr_mean
+y_pred = preds * ytr_std + ytr_mean
rmse = jnp.sqrt(jnp.mean((y_pred.mean(0) - data.yte) ** 2))

print(rf"Time taken: {datetime.timedelta(seconds=int(time_taken))}")
print(rf"RMSE: {rmse:.2f}")

# compute mean prediction and confidence interval around median
-mean_prediction = jnp.mean(y_pred, 0)
+mean_prediction = y_pred.mean(0)

ran = np.arange(mean_prediction.shape[0])
-percentiles = np.percentile(
-preds.reshape(-1, xte.shape[0]) * ytr_std + ytr_mean, [5.0, 95.0], axis=0
-)
+percentiles = np.percentile(preds * ytr_std + ytr_mean, [5.0, 95.0], axis=0)

# make plots
fig, ax = plt.subplots(figsize=(8, 6), constrained_layout=True)
@@ -199,7 +204,8 @@ def main(args):
parser.add_argument("--max-iter", type=int, default=1000)
parser.add_argument("--repulsion", type=float, default=1.0)
parser.add_argument("--verbose", type=bool, default=True)
parser.add_argument("--num-particles", type=int, default=100)
parser.add_argument("--num-elbo-particles", type=int, default=50)
parser.add_argument("--num-stein-particles", type=int, default=5)
parser.add_argument("--progress-bar", type=bool, default=True)
parser.add_argument("--rng-key", type=int, default=142)
parser.add_argument("--device", default="cpu", choices=["gpu", "cpu"])
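Note the split of the old --num-particles flag: --num-stein-particles sets the number of mixture components (the Stein particles), while --num-elbo-particles sets the number of Monte Carlo draws used per loss evaluation, replacing the Trace_ELBO(20) argument the old code passed to SteinVI explicitly.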
24 changes: 13 additions & 11 deletions examples/stein_dmm.py
@@ -28,10 +28,10 @@

import numpyro
from numpyro.contrib.einstein import SteinVI
-from numpyro.contrib.einstein.kernels import RBFKernel
+from numpyro.contrib.einstein.mixture_guide_predictive import MixtureGuidePredictive
+from numpyro.contrib.einstein.stein_kernels import RBFKernel
import numpyro.distributions as dist
from numpyro.examples.datasets import JSB_CHORALES, load_dataset
-from numpyro.infer import Predictive, Trace_ELBO
from numpyro.optim import optax_to_numpyro


@@ -293,17 +293,17 @@ def vis_tune(i, tunes, lengths, name="stein_dmm.pdf"):
def main(args):
inf_key, pred_key = random.split(random.PRNGKey(seed=args.rng_seed), 2)

-vi = SteinVI(
+steinvi = SteinVI(
model,
guide,
optax_to_numpyro(chain(adam(1e-2))),
-Trace_ELBO(),
RBFKernel(),
-num_particles=args.num_particles,
+num_elbo_particles=args.num_elbo_particles,
+num_stein_particles=args.num_stein_particles,
)

seqs, rev_seqs, lengths = load_data()
-results = vi.run(
+results = steinvi.run(
inf_key,
args.max_iter,
seqs,
@@ -312,11 +312,12 @@ def main(args):
gru_dim=args.gru_dim,
subsample_size=args.subsample_size,
)
-pred = Predictive(
+pred = MixtureGuidePredictive(
model,
guide,
params=results.params,
num_samples=1,
-batch_ndims=1,
+guide_sites=steinvi.guide_param_names,
)
seqs, rev_seqs, lengths = load_data("valid")
pred_notes = pred(
@@ -328,11 +329,12 @@

if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--subsample-size", type=int, default=77)
parser.add_argument("--max-iter", type=int, default=1000)
parser.add_argument("--subsample-size", type=int, default=10)
parser.add_argument("--max-iter", type=int, default=100)
parser.add_argument("--repulsion", type=float, default=1.0)
parser.add_argument("--verbose", type=bool, default=True)
parser.add_argument("--num-particles", type=int, default=5)
parser.add_argument("--num-stein-particles", type=int, default=5)
parser.add_argument("--num-elbo-particles", type=int, default=5)
parser.add_argument("--progress-bar", type=bool, default=True)
parser.add_argument("--gru-dim", type=int, default=150)
parser.add_argument("--rng-key", type=int, default=142)
14 changes: 7 additions & 7 deletions numpyro/contrib/einstein/__init__.py
@@ -1,26 +1,26 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

-from numpyro.contrib.einstein.kernels import (
+from numpyro.contrib.einstein.stein_kernels import (
GraphicalKernel,
-HessianPrecondMatrix,
IMQKernel,
LinearKernel,
-PrecondMatrix,
-PrecondMatrixKernel,
MixtureKernel,
+ProbabilityProductKernel,
RandomFeatureKernel,
RBFKernel,
)
from numpyro.contrib.einstein.stein_loss import SteinLoss
from numpyro.contrib.einstein.steinvi import SteinVI

__all__ = [
"SteinVI",
"SteinLoss",
"RBFKernel",
"PrecondMatrix",
"IMQKernel",
"LinearKernel",
"RandomFeatureKernel",
"HessianPrecondMatrix",
"GraphicalKernel",
"PrecondMatrixKernel",
"MixtureKernel",
"ProbabilityProductKernel",
]
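A quick import sanity sketch against the renamed module; this simply mirrors the new __all__ above.

from numpyro.contrib.einstein import (
    GraphicalKernel,
    IMQKernel,
    LinearKernel,
    MixtureKernel,
    ProbabilityProductKernel,
    RandomFeatureKernel,
    RBFKernel,
    SteinLoss,
    SteinVI,
)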
108 changes: 108 additions & 0 deletions numpyro/contrib/einstein/mixture_guide_predictive.py
@@ -0,0 +1,108 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

from functools import partial
from typing import Callable, Dict, Optional, Sequence

from jax import numpy as jnp, random, tree_map, vmap
from jax.tree_util import tree_flatten

from numpyro.handlers import substitute
from numpyro.infer import Predictive
from numpyro.infer.util import _predictive


class MixtureGuidePredictive:
Review thread:

Member (Author): @fehiepsi I believe we can remove the batch_ndims=1 from Predictive. Let me know what you think.

Member: It is useful in some cases, e.g. to do predictive with (num_chains, num_samples) or for autoguide posterior samples with multiple sample shapes.

Member: I feel that this is a bit complicated. I think the Predictive class is pretty flexible. Complicated logic to handle the mixture stuff should belong to the guide, not the Predictive (correct me if I'm wrong).

Member (Author): Hmm, at the core I need the mixture assignments to sample from the model. I specialized the ELBO to SteinLoss to account for this. I suppose the guide is the shared element for both cases (loss and predictive sampling), so hiding the complications there seems more modular. I'll try to lift the guide to a mixture guide (and have it as a SteinVI attribute); then the lifted guide would work with the current Predictive and ELBO. It could also simplify some logic in SteinVI. I'll sketch it and get back.

"""
For a single mixture component, use numpyro.infer.Predictive.
"""

def __init__(
self,
model: Callable,
guide: Callable,
params: Dict,
guide_sites: Sequence,
num_samples: Optional[int] = None,
return_sites: Optional[Sequence[str]] = None,
infer_discrete: bool = False,
parallel: bool = False,
mixture_assignment_sitename="mixture_assignments",
):
self.model_predictive = partial(
Predictive,
model=model,
params={
name: param for name, param in params.items() if name not in guide_sites
},
num_samples=num_samples,
return_sites=return_sites,
infer_discrete=infer_discrete,
parallel=parallel,
)
self._batch_shape = (num_samples,)
self.parallel = parallel
self.guide_params = {
name: param for name, param in params.items() if name in guide_sites
}

self.guide = guide
self.return_sites = return_sites
self.num_mixture_components = jnp.shape(tree_flatten(params)[0][0])[0]
self.mixture_assignment_sitename = mixture_assignment_sitename

def _call_with_params(self, rng_key, params, args, kwargs):
rng_key, guide_rng_key = random.split(rng_key)
# use return_sites='' as a special signal to return all sites
guide = substitute(self.guide, params)
samples = _predictive(
guide_rng_key,
guide,
{},
self._batch_shape,
return_sites="",
parallel=self.parallel,
model_args=args,
model_kwargs=kwargs,
)
return samples

def __call__(self, rng_key, *args, **kwargs):
guide_key, assign_key, model_key = random.split(rng_key, 3)

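# Sample every mixture component's guide in parallel; axis 0 of
# self.guide_params indexes the components (Stein particles).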
samples_guide = vmap(
lambda key, params: self._call_with_params(
key, params=params, args=args, kwargs=kwargs
),
in_axes=0,
out_axes=1,
)(random.split(guide_key, self.num_mixture_components), self.guide_params)

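# Assign each predictive draw to a component uniformly at random,
# matching the equal component weights of the Stein mixture.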
assigns = random.randint(
assign_key,
shape=self._batch_shape,
minval=0,
maxval=self.num_mixture_components,
)
predictive_assign = tree_map(
lambda arr: vmap(lambda i, assign: arr[i, assign])(
jnp.arange(self._batch_shape[0]), assigns
),
samples_guide,
)
predictive_model = self.model_predictive(posterior_samples=predictive_assign)
samples_model = predictive_model(model_key, *args, **kwargs)
if self.return_sites is not None:
samples_guide = {
name: value
for name, value in samples_guide.items()
if name in self.return_sites
}
else:
samples_guide = {}

return {
self.mixture_assignment_sitename: assigns,
**samples_guide,
**samples_model, # use samples from model if site in model and guide
}
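For reference, a sketch of how the examples in this PR drive the class; model, guide, stein, result, and x_test are assumed to be set up as in examples/stein_bnn.py.

from jax import random

pred = MixtureGuidePredictive(
    model,
    guide=stein.guide,
    params=stein.get_params(result.state),
    num_samples=100,
    guide_sites=stein.guide_param_names,
)
samples = pred(random.PRNGKey(1), x_test)
# One mixture assignment is drawn per predictive sample, so the returned
# dict includes the "mixture_assignments" site alongside the model sites.
print(samples["mixture_assignments"].shape)  # (100,)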