Showing 5 changed files with 288 additions and 2 deletions.
blackjax/vi/__init__.py
@@ -1,3 +1,3 @@
-from . import meanfield_vi, pathfinder
+from . import meanfield_vi, pathfinder, svgd
 
-__all__ = ["pathfinder", "meanfield_vi"]
+__all__ = ["pathfinder", "meanfield_vi", "svgd"]
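With this export in place, both import paths used elsewhere in the commit work; a quick check (illustrative, not part of the diff):

    from blackjax.vi import svgd   # via the export added above
    import blackjax                # the tests below use the blackjax.svgd alias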
blackjax/vi/svgd.py (new file)
@@ -0,0 +1,167 @@
import functools
from typing import Any, Callable, Dict, NamedTuple

import jax
import jax.numpy as jnp
import optax
from jax.flatten_util import ravel_pytree

from blackjax.base import MCMCSamplingAlgorithm
from blackjax.types import PyTree

__all__ = ["svgd", "rbf_kernel", "update_median_heuristic"]


class SVGDState(NamedTuple):
    particles: PyTree
    kernel_parameters: Dict[str, PyTree]
    opt_state: Any

def init(
    initial_particles: PyTree,
    kernel_parameters: Dict[str, Any],
    optimizer: optax.GradientTransformation,
) -> SVGDState:
    """Initialize the Stein Variational Gradient Descent algorithm.

    Parameters
    ----------
    initial_particles
        Initial set of particles from which to start the optimization.
    kernel_parameters
        Arguments to the kernel function.
    optimizer
        Optax compatible optimizer, which conforms to the
        `optax.GradientTransformation` protocol.
    """
    opt_state = optimizer.init(initial_particles)
    return SVGDState(initial_particles, kernel_parameters, opt_state)

def build_kernel(optimizer: optax.GradientTransformation):
    def kernel(
        state: SVGDState,
        grad_logdensity_fn: Callable,
        kernel: Callable,
        **grad_params,
    ) -> SVGDState:
        """Perform one step of Stein Variational Gradient Descent.

        See Algorithm 1 of :cite:p:`liu2016stein`.

        Parameters
        ----------
        state
            SVGDState object containing information about the previous iteration.
        grad_logdensity_fn
            Gradient, or an estimate of the gradient, of the log density of the
            target distribution to approximately sample from.
        kernel
            Positive semi-definite kernel.
        **grad_params
            Additional parameters for the `grad_logdensity_fn` function, for
            instance a minibatch parameter for a gradient estimator.

        Returns
        -------
        SVGDState containing the new particles, kernel parameters and optimizer state.
        """
        particles, kernel_params, opt_state = state
        kernel = functools.partial(kernel, **kernel_params)

        def phi_star_summand(particle, particle_):
            # One summand of the Stein variational direction, negated so that
            # the optax optimizer, which minimizes, ascends the log density.
            gradient = grad_logdensity_fn(particle, **grad_params)
            k, grad_k = jax.value_and_grad(kernel, argnums=0)(particle, particle_)
            return jax.tree_util.tree_map(lambda g, gk: -(k * g) - gk, gradient, grad_k)

        functional_gradient = jax.vmap(
            lambda p_: jax.tree_util.tree_map(
                lambda phi_star: phi_star.mean(axis=0),
                jax.vmap(lambda p: phi_star_summand(p, p_))(particles),
            )
        )(particles)

        updates, opt_state = optimizer.update(functional_gradient, opt_state, particles)
        particles = optax.apply_updates(particles, updates)

        return SVGDState(particles, kernel_params, opt_state)

    return kernel

def rbf_kernel(x, y, length_scale=1):
    # RBF kernel k(x, y) = exp(-||x - y||^2 / length_scale), computed on the
    # flattened squared differences of the two pytrees.
    arg = ravel_pytree(jax.tree_util.tree_map(lambda x, y: (x - y) ** 2, x, y))[0]
    return jnp.exp(-(1 / length_scale) * arg.sum())

def median_heuristic(kernel_parameters, particles):
    particle_array = jax.vmap(lambda p: ravel_pytree(p)[0])(particles)

    def distance(x, y):
        return jnp.linalg.norm(jnp.atleast_1d(x - y))

    vmapped_distance = jax.vmap(jax.vmap(distance, (None, 0)), (0, None))
    A = vmapped_distance(particle_array, particle_array)  # pairwise distance matrix
    # Take the values below the main diagonal into a vector
    pairwise_distances = A[jnp.tril_indices(A.shape[0], k=-1)]
    median = jnp.median(pairwise_distances)
    kernel_parameters["length_scale"] = (median**2) / jnp.log(particle_array.shape[0])
    return kernel_parameters

def update_median_heuristic(state: SVGDState) -> SVGDState:
    """Median heuristic for setting the bandwidth of RBF kernels.

    A reasonable middle ground for choosing the `length_scale` of the RBF kernel
    is `med**2 / log(n)`, where `med` is the empirical median of the pairwise
    distances between the `n` particles. This strategy is called the median
    heuristic: with this bandwidth, the kernel weight between two particles at
    the median distance is `exp(-log n) = 1/n`, so the total kernel weight each
    particle receives from the others stays roughly constant in `n`.
    """
    position, kernel_parameters, opt_state = state
    return SVGDState(position, median_heuristic(kernel_parameters, position), opt_state)

class svgd:
    """Implements the (basic) user interface for the SVGD algorithm.

    Parameters
    ----------
    grad_logdensity_fn
        Gradient, or an estimate of the gradient, of the log density of the
        target distribution to approximately sample from.
    optimizer
        Optax compatible optimizer, which conforms to the
        `optax.GradientTransformation` protocol.
    kernel
        Positive semi-definite kernel.
    update_kernel_parameters
        Function that updates the kernel parameters given the current state
        of the particles.

    Returns
    -------
    A ``MCMCSamplingAlgorithm``.
    """

    init = staticmethod(init)
    build_kernel = staticmethod(build_kernel)

    def __new__(
        cls,
        grad_logdensity_fn: Callable,
        optimizer,
        kernel: Callable = rbf_kernel,
        update_kernel_parameters: Callable = update_median_heuristic,
    ):
        kernel_ = cls.build_kernel(optimizer)

        def init_fn(
            initial_position: PyTree,
            kernel_parameters: Dict[str, Any] = {"length_scale": 1.0},
        ):
            return cls.init(initial_position, kernel_parameters, optimizer)

        def step_fn(state, **grad_params):
            state = kernel_(state, grad_logdensity_fn, kernel, **grad_params)
            return update_kernel_parameters(state)

        return MCMCSamplingAlgorithm(init_fn, step_fn)  # type: ignore[arg-type]
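The pieces above compose into a simple workflow. A minimal usage sketch (the standard-Normal target and every name below are illustrative, not part of this commit): build the algorithm from a gradient function and an optax optimizer, then iterate the jitted step, which alternates the SVGD particle update with the median-heuristic bandwidth update.

    import jax
    import jax.numpy as jnp
    import optax
    import blackjax

    # Illustrative target: a standard Normal in two dimensions.
    def logdensity(x):
        return -0.5 * jnp.sum(x**2)

    svgd = blackjax.svgd(jax.grad(logdensity), optax.adam(0.1))
    particles = jax.random.normal(jax.random.PRNGKey(0), (50, 2))
    state = svgd.init(particles)  # default RBF kernel with length_scale=1.0
    step = jax.jit(svgd.step)
    for _ in range(200):
        state = step(state)
    # state.particles now approximates a sample from the target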
SVGD test file (new)
@@ -0,0 +1,109 @@
import functools

import chex
import jax
import jax.numpy as jnp
import jax.scipy.stats as stats
from absl.testing import absltest
from optax import adam

import blackjax
from blackjax.vi.svgd import SVGDState, rbf_kernel, update_median_heuristic

def svgd_training_loop(
    log_p,
    initial_position,
    initial_kernel_parameters,
    kernel,
    optimizer,
    *,
    num_iterations=500,
) -> SVGDState:
    svgd = blackjax.svgd(jax.grad(log_p), optimizer, kernel, update_median_heuristic)
    state = svgd.init(initial_position, initial_kernel_parameters)
    step = jax.jit(svgd.step)  # type: ignore[attr-defined]

    for _ in range(num_iterations):
        state = step(state)
    return state

class SvgdTest(chex.TestCase):
    def setUp(self):
        super().setUp()
        self.key = jax.random.PRNGKey(1)

    def test_recover_posterior(self):
        """Check that SVGD recovers the mean of a simple Normal posterior."""
        # TODO improve testing

        ndim = 2

        rng_key_chol, rng_key_observed, rng_key_init = jax.random.split(self.key, 3)
        L = jnp.tril(jax.random.normal(rng_key_chol, (ndim, ndim)))
        true_mu = jnp.arange(ndim)
        true_cov = L @ L.T
        true_prec = jnp.linalg.pinv(true_cov)

        def logp_posterior_conjugate_normal_model(
            observed, prior_mu, prior_prec, true_prec
        ):
            # Closed-form posterior mean of the conjugate Normal model,
            # used below as the ground truth for the particle mean.
            n = observed.shape[0]
            posterior_cov = jnp.linalg.inv(prior_prec + n * true_prec)
            posterior_mu = (
                posterior_cov
                @ (
                    prior_prec @ prior_mu[:, None]
                    + n * true_prec @ observed.mean(0)[:, None]
                )
            )[:, 0]
            return posterior_mu

        def logp_unnormalized_posterior(x, observed, prior_mu, prior_prec, true_cov):
            logp = 0.0
            logp += stats.multivariate_normal.logpdf(x, prior_mu, prior_prec)
            logp += stats.multivariate_normal.logpdf(observed, x, true_cov).sum()
            return logp

        prior_mu = jnp.zeros(ndim)
        prior_prec = jnp.eye(ndim)

        # Simulate the data
        observed = jax.random.multivariate_normal(
            rng_key_observed, true_mu, true_cov, shape=(10_000,)
        )

        logp_model = functools.partial(
            logp_unnormalized_posterior,
            observed=observed,
            prior_mu=prior_mu,
            prior_prec=prior_prec,
            true_cov=true_cov,
        )

        num_particles = 50
        initial_particles = jax.random.multivariate_normal(
            rng_key_init, prior_mu, prior_prec, shape=(num_particles,)
        )

        out = svgd_training_loop(
            log_p=logp_model,
            initial_position=initial_particles,
            initial_kernel_parameters={"length_scale": 1.0},
            kernel=rbf_kernel,
            optimizer=adam(0.2),
            num_iterations=500,
        )

        posterior_mu = logp_posterior_conjugate_normal_model(
            observed, prior_mu, prior_prec, true_prec
        )

        self.assertAlmostEqual(
            jnp.linalg.norm(posterior_mu - out.particles.mean(0)), 0.0, delta=1.0
        )


if __name__ == "__main__":
    absltest.main()
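The `**grad_params` hook documented in the step kernel is not exercised by this test. A hedged sketch of how it could drive a minibatch gradient estimator; the dataset, model, and every name here are hypothetical, not part of the commit:

    import jax
    import jax.numpy as jnp
    import optax
    import blackjax

    data = jax.random.normal(jax.random.PRNGKey(1), (1_000, 2))  # hypothetical dataset

    def logdensity(theta, minibatch):
        # Unnormalized log density estimated on a minibatch, rescaled to the full dataset.
        scale = data.shape[0] / minibatch.shape[0]
        return -0.5 * scale * jnp.sum((minibatch - theta) ** 2)

    svgd = blackjax.svgd(jax.grad(logdensity), optax.adam(1e-2))
    state = svgd.init(jax.random.normal(jax.random.PRNGKey(0), (20, 2)))

    key = jax.random.PRNGKey(42)
    for _ in range(100):
        key, subkey = jax.random.split(key)
        idx = jax.random.choice(subkey, data.shape[0], (64,), replace=False)
        state = svgd.step(state, minibatch=data[idx])  # kwargs reach grad_logdensity_fn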