# Refactoring SteinVI (pyro-ppl#1883)
1. Added a specialized constructor for SVGD (see the usage sketch after this list).
    - The SVGD constructor does not require a guide because the guide is always a Delta distribution.
    - It correctly sets the scaling of the attractive force to `1/m` for `m` particles.
    - It does not let users change `num_elbo_particles`, which is always 1 for SVGD.
2. Added a `setup_run` method to SteinVI.
    - The method encapsulates the setup of a SteinVI run, which subclasses can override.
3. Added a constructor for ASVGD.
    - ASVGD introduces an annealing schedule on the attractive force (see the schedule sketch below).
    - It inherits from SVGD and overrides `setup_run` to change the `loss_temperature`.
4. Removed the [Jacobian projection from the Stein force](https://github.com/pyro-ppl/numpyro/blob/master/numpyro/contrib/einstein/steinvi.py#L337-L339).
    - The projection is unnecessary because we do not allow the `AutoIAFNormal`, `AutoBNAFNormal`, `AutoDAIS`, `AutoSemiDAIS`, and `AutoSurrogateLikelihoodDAIS` guides.
    - This simplifies the force computation to attractive + repulsive (see the force sketch below).
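
A minimal usage sketch of the two new constructors. This is illustrative only: the model, optimizer, kernel, and particle count are placeholders, and any keyword not named in the items above is an assumption rather than the exact API.

```python
from jax import numpy as jnp, random

import numpyro
import numpyro.distributions as dist
from numpyro.contrib.einstein import ASVGD, SVGD, RBFKernel
from numpyro.optim import Adagrad


def model(y=None):
    loc = numpyro.sample("loc", dist.Normal(0.0, 1.0))
    numpyro.sample("y", dist.Normal(loc, 1.0), obs=y)


y = jnp.array([1.2, 2.1, 1.8])

# No guide argument: SVGD always uses a Delta guide, and the number of
# ELBO particles is fixed to one.
svgd = SVGD(model, Adagrad(0.5), RBFKernel(), num_stein_particles=20)
svgd_result = svgd.run(random.PRNGKey(0), 1000, y)

# ASVGD inherits the SVGD interface and adds the annealing schedule.
asvgd = ASVGD(model, Adagrad(0.5), RBFKernel(), num_stein_particles=20)
asvgd_result = asvgd.run(random.PRNGKey(1), 1000, y)
```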
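The schedule ASVGD installs via `setup_run` is the cyclical annealing from the ASVGD paper; below is a sketch of that schedule. The functional form is the paper's, and the argument names are assumptions, not necessarily the library's exact code.

```python
def cyclical_annealing(t, num_steps, num_cycles=10, trans_speed=10):
    # Loss temperature at iteration t: within each cycle the temperature
    # ramps from ~0 (repulsion dominates, particles explore) to 1 (the
    # attractive force pulls particles onto the target).
    cycle_len = num_steps // num_cycles
    return ((t % cycle_len + 1) / cycle_len) ** trans_speed
```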
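For reference, with the projection gone the force in item 4 reduces to the familiar SVGD update, `phi(x_i) = 1/m * sum_j [k(x_j, x_i) * grad log p(x_j) + grad_{x_j} k(x_j, x_i)]`. A self-contained JAX sketch of this simplified force (not the library's implementation; `logp` and the RBF kernel are placeholders):

```python
import jax
import jax.numpy as jnp


def rbf(x, y, h=1.0):
    # Simple RBF kernel; stands in for any kernel from the einstein module.
    return jnp.exp(-jnp.sum((x - y) ** 2) / (2.0 * h))


def stein_force(particles, logp, kernel=rbf):
    # phi(x_i) = 1/m sum_j [ k(x_j, x_i) grad logp(x_j) + grad_{x_j} k(x_j, x_i) ]
    m = particles.shape[0]
    score = jax.vmap(jax.grad(logp))(particles)  # (m, d) scores at each particle

    def pairwise(f):
        return jax.vmap(lambda x: jax.vmap(lambda y: f(x, y))(particles))(particles)

    k = pairwise(kernel)  # (m, m) kernel matrix, symmetric for RBF
    grad_k = pairwise(jax.grad(kernel, argnums=0))  # (m, m, d)
    attractive = k @ score / m  # kernel-smoothed scores
    repulsive = grad_k.sum(0) / m  # sum over j of grad_{x_j} k(x_j, x_i)
    return attractive + repulsive


# Example: ten 2D particles targeting a standard normal.
particles = jax.random.normal(jax.random.PRNGKey(0), (10, 2))
force = stein_force(particles, lambda z: -0.5 * jnp.sum(z**2))
```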

### Misc changes
1. The normalization factor in the `ProbabilityProductKernel` has been removed. The kernel is still a proper kernel; without the normalization it avoids vanishing/exploding values when the guide variances deviate from 1 in high-dimensional models (see the sketch after this list).
2. Added an `rng_key` argument to the kernel. This is convenient when experimenting with kernels.
3. Removed the `enum` parameter, as it is currently unsupported.
4. Kernel smoothing of the attractive force and the repulsion is now computed on particles in unconstrained space, consistent with particles moving in unconstrained space. **NB: constrained particles are removed from the tests.**
5. Added the manual computations for the kernel tests as comments.
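
For illustration, a sketch of the unnormalized probability product between two diagonal Gaussians. This is an assumed Bhattacharyya-type form for intuition only; the actual kernel operates on the guide's parameters and, per item 2, now also accepts an `rng_key`.

```python
import jax.numpy as jnp


def unnorm_prob_product_kernel(mu1, sig1, mu2, sig2):
    # Probability product of N(mu1, diag(sig1**2)) and N(mu2, diag(sig2**2))
    # with the determinant prefactor dropped. The prefactor is the part that
    # vanishes/explodes in high dimension when the variances drift from 1.
    var = sig1**2 + sig2**2
    return jnp.exp(-0.25 * jnp.sum((mu1 - mu2) ** 2 / var))
```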
OlaRonning authored Oct 19, 2024
1 parent d867c54 commit 8ace34f
Showing 11 changed files with 755 additions and 405 deletions.
### `docs/source/contrib.rst` (3 additions & 4 deletions)

```diff
@@ -43,21 +43,20 @@ The framework currently supports several kernels, including:
 - `RandomFeatureKernel`
 - `MixtureKernel`
 - `GraphicalKernel`
-- `ProbabilityProductKernel`

 For example, usage see:

-- The `Bayesian neural network example <https://num.pyro.ai/en/latest/examples/stein_bnn.html>`_
+- The `Bayesian neural network example <https://num.pyro.ai/en/latest/examples/stein_bnn.html>`_.

 **References**

 1. *Stein's Method Meets Statistics: A Review of Some Recent Developments* (2021)
    Andreas Anastasiou, Alessandro Barp, François-Xavier Briol, Bruno Ebner,
    Robert E. Gaunt, Fatemeh Ghaderinezhad, Jackson Gorham, Arthur Gretton,
    Christophe Ley, Qiang Liu, Lester Mackey, Chris. J. Oates, Gesine Reinert,
-   Yvik Swan. https://arxiv.org/abs/2105.03481
+   Yvik Swan.

-2. *Stein variational gradient descent: A general-purpose Bayesian inference algorithm* (2016)
+2. *Stein Variational Gradient Descent: A General-Purpose Bayesian Inference Algorithm* (2016)
    Qiang Liu, Dilin Wang. NeurIPS

 3. *Nonlinear Stein Variational Gradient Descent for Learning Diversified Mixture Models* (2019)
```
### `examples/stein_bnn.py` (37 additions & 50 deletions)

```diff
@@ -19,14 +19,10 @@
 import numpy as np
 from sklearn.model_selection import train_test_split

-import jax
-from jax import random
-import jax.numpy as jnp
-
-import numpyro
-from numpyro import deterministic
-from numpyro.contrib.einstein import IMQKernel, SteinVI
-from numpyro.contrib.einstein.mixture_guide_predictive import MixtureGuidePredictive
+from jax import config, nn, numpy as jnp, random
+
+from numpyro import deterministic, plate, sample, set_platform, subsample
+from numpyro.contrib.einstein import MixtureGuidePredictive, RBFKernel, SteinVI
 from numpyro.distributions import Gamma, Normal
 from numpyro.examples.datasets import BOSTON_HOUSING, load_dataset
 from numpyro.infer import init_to_uniform
@@ -54,23 +50,23 @@ def normalize(val, mean=None, std=None):
     return (val - mean) / std, mean, std


-def model(x, y=None, hidden_dim=50, subsample_size=100):
+def model(x, y=None, hidden_dim=50, sub_size=100):
     """BNN described in section 5 of [1].
     **References:**
-        1. *Stein variational gradient descent: A general purpose bayesian inference algorithm*
+        1. *Stein Variational Gradient Descent: A General Purpose Bayesian Inference Algorithm*
         Qiang Liu and Dilin Wang (2016).
     """

-    prec_nn = numpyro.sample(
+    prec_nn = sample(
         "prec_nn", Gamma(1.0, 0.1)
     )  # hyper prior for precision of nn weights and biases

     n, m = x.shape

-    with numpyro.plate("l1_hidden", hidden_dim, dim=-1):
+    with plate("l1_hidden", hidden_dim, dim=-1):
         # prior l1 bias term
-        b1 = numpyro.sample(
+        b1 = sample(
             "nn_b1",
             Normal(
                 0.0,
@@ -79,38 +75,33 @@ def model(x, y=None, hidden_dim=50, subsample_size=100):
         )
         assert b1.shape == (hidden_dim,)

-    with numpyro.plate("l1_feat", m, dim=-2):
-        w1 = numpyro.sample(
+    with plate("l1_feat", m, dim=-2):
+        w1 = sample(
             "nn_w1", Normal(0.0, 1.0 / jnp.sqrt(prec_nn))
         )  # prior on l1 weights
         assert w1.shape == (m, hidden_dim)

-    with numpyro.plate("l2_hidden", hidden_dim, dim=-1):
-        w2 = numpyro.sample(
+    with plate("l2_hidden", hidden_dim, dim=-1):
+        w2 = sample(
             "nn_w2", Normal(0.0, 1.0 / jnp.sqrt(prec_nn))
         )  # prior on output weights

-    b2 = numpyro.sample(
+    b2 = sample(
         "nn_b2", Normal(0.0, 1.0 / jnp.sqrt(prec_nn))
     )  # prior on output bias term

     # precision prior on observations
-    prec_obs = numpyro.sample("prec_obs", Gamma(1.0, 0.1))
-    with numpyro.plate(
-        "data",
-        x.shape[0],
-        subsample_size=subsample_size,
-        dim=-1,
-    ):
-        batch_x = numpyro.subsample(x, event_dim=1)
+    prec_obs = sample("prec_obs", Gamma(1.0, 0.1))
+    with plate("data", x.shape[0], subsample_size=sub_size, dim=-1):
+        batch_x = subsample(x, event_dim=1)
         if y is not None:
-            batch_y = numpyro.subsample(y, event_dim=0)
+            batch_y = subsample(y, event_dim=0)
         else:
             batch_y = y

-        loc_y = deterministic("y_pred", jnp.maximum(batch_x @ w1 + b1, 0) @ w2 + b2)
+        loc_y = deterministic("y_pred", nn.relu(batch_x @ w1 + b1) @ w2 + b2)

-        numpyro.sample(
+        sample(
             "y",
             Normal(
                 loc_y, 1.0 / jnp.sqrt(prec_obs)
@@ -123,34 +114,33 @@ def main(args):
     data = load_data()

     inf_key, pred_key, data_key = random.split(random.PRNGKey(args.rng_key), 3)
-    # normalize data and labels to zero mean unit variance!
+    # Normalize features to zero mean unit variance.
     x, xtr_mean, xtr_std = normalize(data.xtr)
-    y, ytr_mean, ytr_std = normalize(data.ytr)

     rng_key, inf_key = random.split(inf_key)

+    # We find that SteinVI benefits from a small radius when inferring BNNs.
     guide = AutoNormal(model, init_loc_fn=partial(init_to_uniform, radius=0.1))

     stein = SteinVI(
         model,
         guide,
-        Adagrad(0.05),
-        IMQKernel(),
-        # ProbabilityProductKernel(guide=guide, scale=1.),
+        Adagrad(0.5),
+        RBFKernel(),
         repulsion_temperature=args.repulsion,
         num_stein_particles=args.num_stein_particles,
         num_elbo_particles=args.num_elbo_particles,
     )
     start = time()

-    # use keyword params for static (shape etc.)!
+    # Use keyword params for static (shape etc.)
     result = stein.run(
         rng_key,
         args.max_iter,
         x,
-        y,
+        data.ytr,
         hidden_dim=args.hidden_dim,
-        subsample_size=args.subsample_size,
+        sub_size=args.subsample_size,
         progress_bar=args.progress_bar,
     )
     time_taken = time() - start
@@ -164,39 +154,36 @@ def main(args):
     )
     xte, _, _ = normalize(
         data.xte, xtr_mean, xtr_std
-    )  # use train data statistics when accessing generalization
-    preds = pred(
-        pred_key, xte, subsample_size=xte.shape[0], hidden_dim=args.hidden_dim
-    )["y_pred"]
+    )  # Use train data statistics when accessing generalization.
+    n = xte.shape[0]
+    y_preds = pred(pred_key, xte, sub_size=n, hidden_dim=args.hidden_dim)["y_pred"]

-    y_pred = preds * ytr_std + ytr_mean
-    rmse = jnp.sqrt(jnp.mean((y_pred.mean(0) - data.yte) ** 2))
+    mean_pred = y_preds.mean(0)
+    rmse = jnp.sqrt(jnp.mean((mean_pred - data.yte) ** 2))

     print(rf"Time taken: {datetime.timedelta(seconds=int(time_taken))}")
     print(rf"RMSE: {rmse:.2f}")

-    # compute mean prediction and confidence interval around median
-    mean_prediction = y_pred.mean(0)
-
-    ran = np.arange(mean_prediction.shape[0])
-    percentiles = np.percentile(preds * ytr_std + ytr_mean, [5.0, 95.0], axis=0)
+    percentiles = jnp.percentile(y_preds, jnp.array([5.0, 95.0]), axis=0)

     # make plots
     fig, ax = plt.subplots(figsize=(8, 6), constrained_layout=True)
+    ran = np.arange(mean_pred.shape[0])
     ax.add_collection(
         LineCollection(
             zip(zip(ran, percentiles[0]), zip(ran, percentiles[1])), colors="lightblue"
         )
     )
     ax.plot(data.yte, "kx", label="y true")
-    ax.plot(mean_prediction, "ko", label="y pred")
+    ax.plot(mean_pred, "ko", label="y pred")
     ax.set(xlabel="example", ylabel="y", title="Mean predictions with 90% CI")
     ax.legend()
     fig.savefig("stein_bnn.pdf")


 if __name__ == "__main__":
-    jax.config.update("jax_debug_nans", True)
+    config.update("jax_debug_nans", True)

     parser = argparse.ArgumentParser()
     parser.add_argument("--subsample-size", type=int, default=100)
@@ -212,6 +199,6 @@ def main(args):

     args = parser.parse_args()

-    numpyro.set_platform(args.device)
+    set_platform(args.device)

     main(args)
```
### `numpyro/contrib/einstein/__init__.py` (9 additions & 7 deletions)

```diff
@@ -12,17 +12,19 @@
     RBFKernel,
 )
 from numpyro.contrib.einstein.stein_loss import SteinLoss
-from numpyro.contrib.einstein.steinvi import SteinVI
+from numpyro.contrib.einstein.steinvi import ASVGD, SVGD, SteinVI

 __all__ = [
-    "SteinVI",
-    "SteinLoss",
-    "RBFKernel",
+    "ASVGD",
+    "GraphicalKernel",
     "IMQKernel",
     "LinearKernel",
-    "RandomFeatureKernel",
-    "GraphicalKernel",
+    "MixtureGuidePredictive",
     "MixtureKernel",
+    "RandomFeatureKernel",
+    "RBFKernel",
     "ProbabilityProductKernel",
-    "MixtureGuidePredictive",
+    "SVGD",
+    "SteinVI",
+    "SteinLoss",
 ]
```
### `numpyro/contrib/einstein/mixture_guide_predictive.py` (18 additions & 5 deletions)

```diff
@@ -4,21 +4,26 @@
 from collections.abc import Callable, Sequence
 from functools import partial
 from typing import Optional
+import warnings

-import jax
-from jax import numpy as jnp, random, vmap
+from jax import numpy as jnp, random, tree, vmap

 from numpyro.handlers import substitute
 from numpyro.infer import Predictive
+from numpyro.infer.autoguide import AutoGuide
 from numpyro.infer.util import _predictive
+from numpyro.util import find_stack_level


 class MixtureGuidePredictive:
     """(EXPERIMENTAL INTERFACE) This class constructs the predictive distribution for
-    :class:`numpyro.contrib.einstein.steinvi.SteinVi`
+    :class:`numpyro.contrib.einstein.steinvi.SteinVi`.
     .. Note:: For single mixture component use numpyro.infer.Predictive.
+    .. Note:: For :class:`numpyro.contrib.einstein.steinvi.SVGD` and :class:`numpyro.contrib.einstein.steinvi.ASVGD` use
+        :class:`numpyro.infer.util.Predictive`.
     .. warning::
         The `MixtureGuidePredictive` is experimental and will likely be replaced by
         :class:`numpyro.infer.util.Predictive` in the future.
@@ -44,6 +49,14 @@ def __init__(
         return_sites: Optional[Sequence[str]] = None,
         mixture_assignment_sitename="mixture_assignments",
     ):
+        if isinstance(guide, AutoGuide):
+            guide_name = guide.__class__.__name__
+            if guide_name == "AutoDelta":
+                warnings.warn(
+                    "Use numpyro.infer.Predictive with `batch_ndims=1` for ASVGD and SVGD.",
+                    stacklevel=find_stack_level(),
+                )
+
         self.model_predictive = partial(
             Predictive,
             model=model,
@@ -63,7 +76,7 @@ def __init__(

         self.guide = guide
         self.return_sites = return_sites
-        self.num_mixture_components = jnp.shape(jax.tree.flatten(params)[0][0])[0]
+        self.num_mixture_components = jnp.shape(tree.flatten(params)[0][0])[0]
         self.mixture_assignment_sitename = mixture_assignment_sitename

     def _call_with_params(self, rng_key, params, args, kwargs):
@@ -99,7 +112,7 @@ def __call__(self, rng_key, *args, **kwargs):
             minval=0,
             maxval=self.num_mixture_components,
         )
-        predictive_assign = jax.tree.map(
+        predictive_assign = tree.map(
             lambda arr: vmap(lambda i, assign: arr[i, assign])(
                 jnp.arange(self._batch_shape[0]), assigns
             ),
```