diff --git a/blackjax/__init__.py b/blackjax/__init__.py
index 4029e7e4c..c15a96eb4 100644
--- a/blackjax/__init__.py
+++ b/blackjax/__init__.py
@@ -23,6 +23,7 @@
 from .smc.tempered import tempered_smc
 from .vi.meanfield_vi import meanfield_vi
 from .vi.pathfinder import pathfinder
+from .vi.svgd import svgd
 
 __all__ = [
     "__version__",
@@ -50,6 +51,7 @@
     "tempered_smc",
     "meanfield_vi",  # variational inference
     "pathfinder",
+    "svgd",
     "ess",  # diagnostics
     "rhat",
 ]
diff --git a/blackjax/vi/__init__.py b/blackjax/vi/__init__.py
index 131fd6c3a..796e5c0e6 100644
--- a/blackjax/vi/__init__.py
+++ b/blackjax/vi/__init__.py
@@ -1,3 +1,3 @@
-from . import meanfield_vi, pathfinder
+from . import meanfield_vi, pathfinder, svgd
 
-__all__ = ["pathfinder", "meanfield_vi"]
+__all__ = ["pathfinder", "meanfield_vi", "svgd"]
diff --git a/blackjax/vi/svgd.py b/blackjax/vi/svgd.py
new file mode 100644
index 000000000..4f2ffbd36
--- /dev/null
+++ b/blackjax/vi/svgd.py
@@ -0,0 +1,167 @@
+import functools
+from typing import Any, Callable, Dict, NamedTuple
+
+import jax
+import jax.numpy as jnp
+import optax
+from jax.flatten_util import ravel_pytree
+
+from blackjax.base import MCMCSamplingAlgorithm
+from blackjax.types import PyTree
+
+__all__ = ["svgd", "rbf_kernel", "update_median_heuristic"]
+
+
+class SVGDState(NamedTuple):
+    particles: PyTree
+    kernel_parameters: Dict[str, PyTree]
+    opt_state: Any
+
+
+def init(
+    initial_particles: PyTree,
+    kernel_parameters: Dict[str, Any],
+    optimizer: optax.GradientTransformation,
+) -> SVGDState:
+    """
+    Initializes the Stein Variational Gradient Descent algorithm.
+
+    Parameters
+    ----------
+    initial_particles
+        Initial set of particles from which to start the optimization
+    kernel_parameters
+        Arguments to the kernel function
+    optimizer
+        Optax-compatible optimizer, which conforms to the `optax.GradientTransformation` protocol
+    """
+    opt_state = optimizer.init(initial_particles)
+    return SVGDState(initial_particles, kernel_parameters, opt_state)
+
+
+def build_kernel(optimizer: optax.GradientTransformation):
+    def kernel(
+        state: SVGDState,
+        grad_logdensity_fn: Callable,
+        kernel: Callable,
+        **grad_params,
+    ) -> SVGDState:
+        """
+        Performs one step of Stein Variational Gradient Descent.
+
+        See Algorithm 1 of :cite:p:`liu2016stein`.
+
+        Parameters
+        ----------
+        state
+            SVGDState object containing information about the previous iteration
+        grad_logdensity_fn
+            gradient, or an estimate, of the log-density of the target distribution we want to sample from approximately
+        kernel
+            positive semi-definite kernel
+        **grad_params
+            additional parameters for the `grad_logdensity_fn` function, for instance a minibatch parameter
+            for a gradient estimator.
+
+        Returns
+        -------
+        SVGDState containing the new particles, optimizer state and kernel parameters.
+        """
+        particles, kernel_params, opt_state = state
+        kernel = functools.partial(kernel, **kernel_params)
+
+        def phi_star_summand(particle, particle_):
+            gradient = grad_logdensity_fn(particle, **grad_params)
+            k, grad_k = jax.value_and_grad(kernel, argnums=0)(particle, particle_)
+            return jax.tree_util.tree_map(lambda g, gk: -(k * g) - gk, gradient, grad_k)
+
+        functional_gradient = jax.vmap(
+            lambda p_: jax.tree_util.tree_map(
+                lambda phi_star: phi_star.mean(axis=0),
+                jax.vmap(lambda p: phi_star_summand(p, p_))(particles),
+            )
+        )(particles)
+
+        updates, opt_state = optimizer.update(functional_gradient, opt_state, particles)
+        particles = optax.apply_updates(particles, updates)
+
+        return SVGDState(particles, kernel_params, opt_state)
+
+    return kernel
+
+
+def rbf_kernel(x, y, length_scale=1):
+    arg = ravel_pytree(jax.tree_util.tree_map(lambda x, y: (x - y) ** 2, x, y))[0]
+    return jnp.exp(-(1 / length_scale) * arg.sum())
+
+
+def median_heuristic(kernel_parameters, particles):
+    particle_array = jax.vmap(lambda p: ravel_pytree(p)[0])(particles)
+
+    def distance(x, y):
+        return jnp.linalg.norm(jnp.atleast_1d(x - y))
+
+    vmapped_distance = jax.vmap(jax.vmap(distance, (None, 0)), (0, None))
+    A = vmapped_distance(particle_array, particle_array)  # Calculate distance matrix
+    pairwise_distances = A[
+        jnp.tril_indices(A.shape[0], k=-1)
+    ]  # Take values below the main diagonal into a vector
+    median = jnp.median(pairwise_distances)
+    kernel_parameters["length_scale"] = (median**2) / jnp.log(particle_array.shape[0])
+    return kernel_parameters
+
+
+def update_median_heuristic(state: SVGDState) -> SVGDState:
+    """Median heuristic for setting the bandwidth of RBF kernels.
+
+    A reasonable middle ground for choosing the `length_scale` of the RBF kernel
+    is to set it to med**2 / log(n), where med is the empirical median of the
+    pairwise distances between particles and n is the number of particles.
+    This strategy is called the median heuristic.
+    """
+
+    position, kernel_parameters, opt_state = state
+    return SVGDState(position, median_heuristic(kernel_parameters, position), opt_state)
+
+
+class svgd:
+    """Implements the (basic) user interface for the svgd algorithm.
+
+    Parameters
+    ----------
+    grad_logdensity_fn
+        gradient, or an estimate, of the log-density of the target distribution we want to sample from approximately
+    optimizer
+        Optax-compatible optimizer, which conforms to the `optax.GradientTransformation` protocol
+    kernel
+        positive semi-definite kernel
+    update_kernel_parameters
+        function that updates the kernel parameters given the current state of the particles
+
+    Returns
+    -------
+    A ``MCMCSamplingAlgorithm``.
+    """
+
+    init = staticmethod(init)
+    build_kernel = staticmethod(build_kernel)
+
+    def __new__(
+        cls,
+        grad_logdensity_fn: Callable,
+        optimizer,
+        kernel: Callable = rbf_kernel,
+        update_kernel_parameters: Callable = update_median_heuristic,
+    ):
+        kernel_ = cls.build_kernel(optimizer)
+
+        def init_fn(
+            initial_position: PyTree,
+            kernel_parameters: Dict[str, Any] = {"length_scale": 1.0},
+        ):
+            return cls.init(initial_position, kernel_parameters, optimizer)
+
+        def step_fn(state, **grad_params):
+            state = kernel_(state, grad_logdensity_fn, kernel, **grad_params)
+            return update_kernel_parameters(state)
+
+        return MCMCSamplingAlgorithm(init_fn, step_fn)  # type: ignore[arg-type]
diff --git a/docs/refs.bib b/docs/refs.bib
index 8c61c8bb5..9f1564b22 100644
--- a/docs/refs.bib
+++ b/docs/refs.bib
@@ -356,3 +356,11 @@ @article{ma2015complete
   volume={28},
   year={2015}
 }
+
+@article{liu2016stein,
+  title={Stein variational gradient descent: A general purpose {B}ayesian inference algorithm},
+  author={Liu, Qiang and Wang, Dilin},
+  journal={Advances in Neural Information Processing Systems},
+  volume={29},
+  year={2016}
+}
diff --git a/tests/vi/test_svgd.py b/tests/vi/test_svgd.py
new file mode 100644
index 000000000..a0222a863
--- /dev/null
+++ b/tests/vi/test_svgd.py
@@ -0,0 +1,109 @@
+import functools
+
+import chex
+import jax
+import jax.numpy as jnp
+import jax.scipy.stats as stats
+from absl.testing import absltest
+from optax import adam
+
+import blackjax
+from blackjax.vi.svgd import SVGDState, rbf_kernel, update_median_heuristic
+
+
+def svgd_training_loop(
+    log_p,
+    initial_position,
+    initial_kernel_parameters,
+    kernel,
+    optimizer,
+    *,
+    num_iterations=500,
+) -> SVGDState:
+    svgd = blackjax.svgd(jax.grad(log_p), optimizer, kernel, update_median_heuristic)
+    state = svgd.init(initial_position, initial_kernel_parameters)
+    step = jax.jit(svgd.step)  # type: ignore[attr-defined]
+
+    for _ in range(num_iterations):
+        state = step(state)
+    return state
+
+
+class SvgdTest(chex.TestCase):
+    def setUp(self):
+        super().setUp()
+        self.key = jax.random.PRNGKey(1)
+
+    def test_recover_posterior(self):
+        # TODO improve testing
+        """Simple Normal mean test"""
+
+        ndim = 2
+
+        rng_key_chol, rng_key_observed, rng_key_init = jax.random.split(self.key, 3)
+        L = jnp.tril(jax.random.normal(rng_key_chol, (ndim, ndim)))
+        true_mu = jnp.arange(ndim)
+        true_cov = L @ L.T
+        true_prec = jnp.linalg.pinv(true_cov)
+
+        def logp_posterior_conjugate_normal_model(
+            observed, prior_mu, prior_prec, true_prec
+        ):
+            n = observed.shape[0]
+            posterior_cov = jnp.linalg.inv(prior_prec + n * true_prec)
+            posterior_mu = (
+                posterior_cov
+                @ (
+                    prior_prec @ prior_mu[:, None]
+                    + n * true_prec @ observed.mean(0)[:, None]
+                )
+            )[:, 0]
+            return posterior_mu
+
+        def logp_unnormalized_posterior(x, observed, prior_mu, prior_prec, true_cov):
+            logp = 0.0
+            logp += stats.multivariate_normal.logpdf(x, prior_mu, prior_prec)
+            logp += stats.multivariate_normal.logpdf(observed, x, true_cov).sum()
+            return logp
+
+        prior_mu = jnp.zeros(ndim)
+        prior_prec = jnp.eye(ndim)
+
+        # Simulate the data
+        observed = jax.random.multivariate_normal(
+            rng_key_observed, true_mu, true_cov, shape=(10_000,)
+        )
+
+        logp_model = functools.partial(
+            logp_unnormalized_posterior,
+            observed=observed,
+            prior_mu=prior_mu,
+            prior_prec=prior_prec,
+            true_cov=true_cov,
+        )
+
+        num_particles = 50
+        initial_particles = jax.random.multivariate_normal(
+            rng_key_init, prior_mu, prior_prec, shape=(num_particles,)
+        )
+
+        out = svgd_training_loop(
+            log_p=logp_model,
+            initial_position=initial_particles,
+            initial_kernel_parameters={"length_scale": 1.0},
+            kernel=rbf_kernel,
+            optimizer=adam(0.2),
+            num_iterations=500,
+        )
+
+        posterior_mu = logp_posterior_conjugate_normal_model(
+            observed, prior_mu, prior_prec, true_prec
+        )
+
+        self.assertAlmostEqual(
+            jnp.linalg.norm(posterior_mu - out.particles.mean(0)), 0.0, delta=1.0
+        )
+
+
+if __name__ == "__main__":
+    absltest.main()
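
Note for reviewers (not part of the diff): a minimal sketch of how the interface added here is meant to be used outside the test suite. The toy standard-normal target, particle count, optax.sgd learning rate and iteration count below are assumptions chosen purely for illustration.

    # Illustrative usage of the svgd interface added in this diff.
    import jax
    import jax.numpy as jnp
    import optax

    import blackjax

    def logdensity_fn(x):
        # Toy target: a 2D standard normal distribution.
        return -0.5 * jnp.sum(x**2)

    rng_key = jax.random.PRNGKey(0)
    initial_particles = jax.random.normal(rng_key, (100, 2))

    # Uses the default rbf_kernel and update_median_heuristic.
    svgd = blackjax.svgd(jax.grad(logdensity_fn), optax.sgd(0.1))
    state = svgd.init(initial_particles)

    step = jax.jit(svgd.step)
    for _ in range(200):
        state = step(state)

    approximate_samples = state.particles  # particles approximating the target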