Merge branch 'main' into rmhmc
junpenglao authored Jun 1, 2023
2 parents ed3241b + 5004f9f commit b2cdc71
Showing 5 changed files with 288 additions and 2 deletions.
2 changes: 2 additions & 0 deletions blackjax/__init__.py
@@ -23,6 +23,7 @@
from .smc.tempered import tempered_smc
from .vi.meanfield_vi import meanfield_vi
from .vi.pathfinder import pathfinder
from .vi.svgd import svgd

__all__ = [
    "__version__",
@@ -50,6 +51,7 @@
"tempered_smc",
"meanfield_vi", # variational inference
"pathfinder",
"svgd",
"ess", # diagnostics
"rhat",
]
4 changes: 2 additions & 2 deletions blackjax/vi/__init__.py
@@ -1,3 +1,3 @@
-from . import meanfield_vi, pathfinder
+from . import meanfield_vi, pathfinder, svgd

-__all__ = ["pathfinder", "meanfield_vi"]
+__all__ = ["pathfinder", "meanfield_vi", "svgd"]
167 changes: 167 additions & 0 deletions blackjax/vi/svgd.py
@@ -0,0 +1,167 @@
import functools
from typing import Any, Callable, Dict, NamedTuple

import jax
import jax.numpy as jnp
import optax
from jax.flatten_util import ravel_pytree

from blackjax.base import MCMCSamplingAlgorithm
from blackjax.types import PyTree

__all__ = ["svgd", "rbf_kernel", "update_median_heuristic"]


class SVGDState(NamedTuple):
    particles: PyTree
    kernel_parameters: Dict[str, PyTree]
    opt_state: Any


def init(
    initial_particles: PyTree,
    kernel_parameters: Dict[str, Any],
    optimizer: optax.GradientTransformation,
) -> SVGDState:
    """Initializes the Stein Variational Gradient Descent algorithm.

    Parameters
    ----------
    initial_particles
        Initial set of particles from which to start the optimization.
    kernel_parameters
        Arguments to the kernel function.
    optimizer
        Optax-compatible optimizer, which conforms to the
        `optax.GradientTransformation` protocol.
    """
    opt_state = optimizer.init(initial_particles)
    return SVGDState(initial_particles, kernel_parameters, opt_state)


def build_kernel(optimizer: optax.GradientTransformation):
    def kernel(
        state: SVGDState,
        grad_logdensity_fn: Callable,
        kernel: Callable,
        **grad_params,
    ) -> SVGDState:
        """Performs one step of Stein Variational Gradient Descent.

        See Algorithm 1 of :cite:p:`liu2016stein`.

        Parameters
        ----------
        state
            SVGDState object containing information about the previous iteration.
        grad_logdensity_fn
            Gradient, or an estimate thereof, of the log density of the target
            distribution to sample approximately from.
        kernel
            Positive semi-definite kernel.
        **grad_params
            Additional parameters for the `grad_logdensity_fn` function, for
            instance a minibatch argument for a gradient estimator.

        Returns
        -------
        SVGDState containing the new particles, kernel parameters and optimizer state.
        """
        particles, kernel_params, opt_state = state
        kernel = functools.partial(kernel, **kernel_params)

        def phi_star_summand(particle, particle_):
            # Summand of the Stein variational direction, negated so that a
            # minimizing optimizer performs ascent on the target log density.
            gradient = grad_logdensity_fn(particle, **grad_params)
            k, grad_k = jax.value_and_grad(kernel, argnums=0)(particle, particle_)
            return jax.tree_util.tree_map(lambda g, gk: -(k * g) - gk, gradient, grad_k)

        functional_gradient = jax.vmap(
            lambda p_: jax.tree_util.tree_map(
                lambda phi_star: phi_star.mean(axis=0),
                jax.vmap(lambda p: phi_star_summand(p, p_))(particles),
            )
        )(particles)

        updates, opt_state = optimizer.update(functional_gradient, opt_state, particles)
        particles = optax.apply_updates(particles, updates)

        return SVGDState(particles, kernel_params, opt_state)

    return kernel


def rbf_kernel(x, y, length_scale=1):
    arg = ravel_pytree(jax.tree_util.tree_map(lambda x, y: (x - y) ** 2, x, y))[0]
    return jnp.exp(-(1 / length_scale) * arg.sum())


def median_heuristic(kernel_parameters, particles):
    particle_array = jax.vmap(lambda p: ravel_pytree(p)[0])(particles)

    def distance(x, y):
        return jnp.linalg.norm(jnp.atleast_1d(x - y))

    vmapped_distance = jax.vmap(jax.vmap(distance, (None, 0)), (0, None))
    A = vmapped_distance(particle_array, particle_array)  # pairwise distance matrix
    # Take the values below the main diagonal into a vector
    pairwise_distances = A[jnp.tril_indices(A.shape[0], k=-1)]
    median = jnp.median(pairwise_distances)
    # length_scale = med**2 / log(n), with med the median pairwise distance
    kernel_parameters["length_scale"] = (median**2) / jnp.log(particle_array.shape[0])
    return kernel_parameters


def update_median_heuristic(state: SVGDState) -> SVGDState:
    """Median heuristic for setting the bandwidth of RBF kernels.

    A reasonable middle ground for choosing the `length_scale` of the RBF kernel
    is to derive it from the empirical median of the pairwise distances between
    particles. This strategy is called the median heuristic.
    """
    position, kernel_parameters, opt_state = state
    return SVGDState(position, median_heuristic(kernel_parameters, position), opt_state)


class svgd:
    """Implements the (basic) user interface for the SVGD algorithm.

    Parameters
    ----------
    grad_logdensity_fn
        Gradient, or an estimate thereof, of the log density of the target
        distribution to sample approximately from.
    optimizer
        Optax-compatible optimizer, which conforms to the
        `optax.GradientTransformation` protocol.
    kernel
        Positive semi-definite kernel.
    update_kernel_parameters
        Function that updates the kernel parameters given the current state
        of the particles.

    Returns
    -------
    A ``MCMCSamplingAlgorithm``.
    """

    init = staticmethod(init)
    build_kernel = staticmethod(build_kernel)

    def __new__(
        cls,
        grad_logdensity_fn: Callable,
        optimizer,
        kernel: Callable = rbf_kernel,
        update_kernel_parameters: Callable = update_median_heuristic,
    ):
        kernel_ = cls.build_kernel(optimizer)

        def init_fn(
            initial_position: PyTree,
            kernel_parameters: Dict[str, Any] = {"length_scale": 1.0},
        ):
            return cls.init(initial_position, kernel_parameters, optimizer)

        def step_fn(state, **grad_params):
            state = kernel_(state, grad_logdensity_fn, kernel, **grad_params)
            return update_kernel_parameters(state)

        return MCMCSamplingAlgorithm(init_fn, step_fn)  # type: ignore[arg-type]
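
For orientation: the kernel step above implements the negated Stein variational direction -phi*(x') = -(1/n) sum_j [k(x_j, x') grad log p(x_j) + grad_{x_j} k(x_j, x')] of Algorithm 1 in :cite:p:`liu2016stein`, so that a minimizing optax optimizer moves the particles up the target density. Below is a minimal end-to-end sketch of the new interface; the standard-normal target, the optimizer settings, and the particle shapes are illustrative assumptions, not part of this commit:

import jax
import jax.numpy as jnp
import optax

import blackjax

def logdensity_fn(x):
    # Standard normal target; stands in for any differentiable log density.
    return -0.5 * jnp.sum(x**2)

svgd = blackjax.svgd(jax.grad(logdensity_fn), optax.adam(1e-1))

initial_particles = jax.random.normal(jax.random.PRNGKey(0), (200, 2))
state = svgd.init(initial_particles)  # kernel parameters default to {"length_scale": 1.0}

step = jax.jit(svgd.step)
for _ in range(300):
    state = step(state)

# The particles now approximate draws from the target.
print(state.particles.mean(axis=0))  # close to zeros(2)

Note that keyword arguments passed to `step` are forwarded to `grad_logdensity_fn`, e.g. `step(state, batch=minibatch)` when the gradient is estimated on minibatches.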
8 changes: 8 additions & 0 deletions docs/refs.bib
@@ -356,3 +356,11 @@ @article{ma2015complete
  volume={28},
  year={2015}
}

@article{liu2016stein,
  title={Stein variational gradient descent: A general purpose Bayesian inference algorithm},
  author={Liu, Qiang and Wang, Dilin},
  journal={Advances in Neural Information Processing Systems},
  volume={29},
  year={2016}
}
109 changes: 109 additions & 0 deletions tests/vi/test_svgd.py
@@ -0,0 +1,109 @@
import functools

import chex
import jax
import jax.numpy as jnp
import jax.scipy.stats as stats
from absl.testing import absltest
from optax import adam

import blackjax
from blackjax.vi.svgd import SVGDState, rbf_kernel, update_median_heuristic


def svgd_training_loop(
    log_p,
    initial_position,
    initial_kernel_parameters,
    kernel,
    optimizer,
    *,
    num_iterations=500,
) -> SVGDState:
    svgd = blackjax.svgd(jax.grad(log_p), optimizer, kernel, update_median_heuristic)
    state = svgd.init(initial_position, initial_kernel_parameters)
    step = jax.jit(svgd.step)  # type: ignore[attr-defined]

    for _ in range(num_iterations):
        state = step(state)
    return state


class SvgdTest(chex.TestCase):
    def setUp(self):
        super().setUp()
        self.key = jax.random.PRNGKey(1)

    def test_recover_posterior(self):
        # TODO improve testing
        """Simple Normal mean test"""

        ndim = 2

        rng_key_chol, rng_key_observed, rng_key_init = jax.random.split(self.key, 3)
        L = jnp.tril(jax.random.normal(rng_key_chol, (ndim, ndim)))
        true_mu = jnp.arange(ndim)
        true_cov = L @ L.T
        true_prec = jnp.linalg.pinv(true_cov)

        def logp_posterior_conjugate_normal_model(
            observed, prior_mu, prior_prec, true_prec
        ):
            n = observed.shape[0]
            posterior_cov = jnp.linalg.inv(prior_prec + n * true_prec)
            posterior_mu = (
                posterior_cov
                @ (
                    prior_prec @ prior_mu[:, None]
                    + n * true_prec @ observed.mean(0)[:, None]
                )
            )[:, 0]
            return posterior_mu

        def logp_unnormalized_posterior(x, observed, prior_mu, prior_prec, true_cov):
            logp = 0.0
            logp += stats.multivariate_normal.logpdf(x, prior_mu, prior_prec)
            logp += stats.multivariate_normal.logpdf(observed, x, true_cov).sum()
            return logp

        prior_mu = jnp.zeros(ndim)
        prior_prec = jnp.eye(ndim)

        # Simulate the data
        observed = jax.random.multivariate_normal(
            rng_key_observed, true_mu, true_cov, shape=(10_000,)
        )

        logp_model = functools.partial(
            logp_unnormalized_posterior,
            observed=observed,
            prior_mu=prior_mu,
            prior_prec=prior_prec,
            true_cov=true_cov,
        )

        num_particles = 50
        initial_particles = jax.random.multivariate_normal(
            rng_key_init, prior_mu, prior_prec, shape=(num_particles,)
        )

        out = svgd_training_loop(
            log_p=logp_model,
            initial_position=initial_particles,
            initial_kernel_parameters={"length_scale": 1.0},
            kernel=rbf_kernel,
            optimizer=adam(0.2),
            num_iterations=500,
        )

        posterior_mu = logp_posterior_conjugate_normal_model(
            observed, prior_mu, prior_prec, true_prec
        )

        self.assertAlmostEqual(
            jnp.linalg.norm(posterior_mu - out.particles.mean(0)), 0.0, delta=1.0
        )


if __name__ == "__main__":
absltest.main()
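
As a quick standalone check of the kernel utilities exercised by this test (the shapes and seed are arbitrary, and `median_heuristic` is a module-level helper rather than part of the module's public `__all__`):

import jax

from blackjax.vi.svgd import median_heuristic, rbf_kernel

particles = jax.random.normal(jax.random.PRNGKey(0), (100, 2))
params = median_heuristic({"length_scale": 1.0}, particles)
print(params["length_scale"])  # med**2 / log(100), med = median pairwise distance

# The RBF kernel returns values in (0, 1], reaching 1 when x == y.
print(rbf_kernel(particles[0], particles[1], **params))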
