SMC Inner kernel tuning (#595)
* inner kernel tuning, tests, and some common strategies

* Adding imports

* pre-commit

* code review updates

* Adding Chex tests

* line alignment comment

* Adding particles_as_rows test

* Modifying implementation of particles_as_rows

* pre-commit

* change in inverse_mass_matrix from particles implementation

* replacing particles_as_rows_test

---------

Co-authored-by: Junpeng Lao <[email protected]>
ciguaran and junpenglao committed Mar 12, 2024
1 parent 5fbc6f0 commit d1e7014
Showing 9 changed files with 675 additions and 28 deletions.
2 changes: 2 additions & 0 deletions blackjax/__init__.py
@@ -23,6 +23,7 @@
 from .sgmcmc.sgld import sgld
 from .sgmcmc.sgnht import sgnht
 from .smc.adaptive_tempered import adaptive_tempered_smc
+from .smc.inner_kernel_tuning import inner_kernel_tuning
 from .smc.tempered import tempered_smc
 from .vi.meanfield_vi import meanfield_vi
 from .vi.pathfinder import pathfinder
@@ -57,6 +58,7 @@
     "mclmc_find_L_and_step_size",  # mclmc adaptation
     "adaptive_tempered_smc",  # smc
     "tempered_smc",
+    "inner_kernel_tuning",
     "meanfield_vi",  # variational inference
     "pathfinder",
     "schrodinger_follmer",
4 changes: 2 additions & 2 deletions blackjax/smc/__init__.py
@@ -1,3 +1,3 @@
-from . import adaptive_tempered, tempered
+from . import adaptive_tempered, inner_kernel_tuning, tempered
 
-__all__ = ["adaptive_tempered", "tempered"]
+__all__ = ["adaptive_tempered", "tempered", "inner_kernel_tuning"]
150 changes: 150 additions & 0 deletions blackjax/smc/inner_kernel_tuning.py
@@ -0,0 +1,150 @@
from typing import Callable, Dict, NamedTuple, Tuple, Union

from blackjax.base import SamplingAlgorithm
from blackjax.smc.adaptive_tempered import adaptive_tempered_smc
from blackjax.smc.base import SMCInfo, SMCState
from blackjax.smc.tempered import tempered_smc
from blackjax.types import ArrayTree, PRNGKey


class StateWithParameterOverride(NamedTuple):
    sampler_state: ArrayTree
    parameter_override: ArrayTree


def init(alg_init_fn, position, initial_parameter_value):
    return StateWithParameterOverride(alg_init_fn(position), initial_parameter_value)


def build_kernel(
    smc_algorithm,
    logprior_fn: Callable,
    loglikelihood_fn: Callable,
    mcmc_factory: Callable,
    mcmc_init_fn: Callable,
    mcmc_parameters: Dict,
    resampling_fn: Callable,
    mcmc_parameter_update_fn: Callable[[SMCState, SMCInfo], ArrayTree],
    num_mcmc_steps: int = 10,
    **extra_parameters,
) -> Callable:
"""In the context of an SMC sampler (whose step_fn returning state has a .particles attribute), there's an inner
MCMC that is used to perturbate/update each of the particles. This adaptation tunes some parameter of that MCMC,
based on particles. The parameter type must be a valid JAX type.
Parameters
----------
smc_algorithm
Either blackjax.adaptive_tempered_smc or blackjax.tempered_smc (or any other implementation of
a sampling algorithm that returns an SMCState and SMCInfo pair).
logprior_fn
A function that computes the log density of the prior distribution
loglikelihood_fn
A function that returns the probability at a given position.
mcmc_factory
A callable that can construct an inner kernel out of the newly-computed parameter
mcmc_init_fn
A callable that initializes the inner kernel
mcmc_parameters
Other (fixed across SMC iterations) parameters for the inner kernel
mcmc_parameter_update_fn
A callable that takes the SMCState and SMCInfo at step i and constructs a parameter to be used by the inner kernel in i+1 iteration.
extra_parameters:
parameters to be used for the creation of the smc_algorithm.
"""

    def kernel(
        rng_key: PRNGKey, state: StateWithParameterOverride, **extra_step_parameters
    ) -> Tuple[StateWithParameterOverride, SMCInfo]:
        step_fn = smc_algorithm(
            logprior_fn=logprior_fn,
            loglikelihood_fn=loglikelihood_fn,
            mcmc_step_fn=mcmc_factory(state.parameter_override),
            mcmc_init_fn=mcmc_init_fn,
            mcmc_parameters=mcmc_parameters,
            resampling_fn=resampling_fn,
            num_mcmc_steps=num_mcmc_steps,
            **extra_parameters,
        ).step
        new_state, info = step_fn(
            rng_key, state.sampler_state, **extra_step_parameters
        )
        new_parameter_override = mcmc_parameter_update_fn(new_state, info)
        return StateWithParameterOverride(new_state, new_parameter_override), info

    return kernel


class inner_kernel_tuning:
    """In the context of an SMC sampler (whose step_fn returns a state with a
    .particles attribute), an inner MCMC kernel is used to perturb/update each
    of the particles. This adaptation tunes some parameter of that MCMC kernel
    based on the particles. The parameter must be a valid JAX type.

    Parameters
    ----------
    smc_algorithm
        Either blackjax.adaptive_tempered_smc or blackjax.tempered_smc (or any
        other implementation of a sampling algorithm that returns an SMCState
        and SMCInfo pair).
    logprior_fn
        A function that computes the log density of the prior distribution.
    loglikelihood_fn
        A function that computes the log-likelihood at a given position.
    mcmc_factory
        A callable that constructs an inner kernel from the newly computed
        parameter.
    mcmc_init_fn
        A callable that initializes the inner kernel.
    mcmc_parameters
        Other (fixed across SMC iterations) parameters for the inner kernel step.
    mcmc_parameter_update_fn
        A callable that takes the SMCState and SMCInfo at step i and constructs
        the parameter to be used by the inner kernel at iteration i+1.
    initial_parameter_value
        Parameter to be used by the mcmc_factory before the first iteration.
    extra_parameters
        Additional parameters used to construct the smc_algorithm.

    Returns
    -------
    A ``SamplingAlgorithm``.
    """

    init = staticmethod(init)
    build_kernel = staticmethod(build_kernel)

    def __new__(  # type: ignore[misc]
        cls,
        smc_algorithm: Union[adaptive_tempered_smc, tempered_smc],
        logprior_fn: Callable,
        loglikelihood_fn: Callable,
        mcmc_factory: Callable,
        mcmc_init_fn: Callable,
        mcmc_parameters: Dict,
        resampling_fn: Callable,
        mcmc_parameter_update_fn: Callable[[SMCState, SMCInfo], ArrayTree],
        initial_parameter_value,
        num_mcmc_steps: int = 10,
        **extra_parameters,
    ) -> SamplingAlgorithm:
        kernel = cls.build_kernel(
            smc_algorithm,
            logprior_fn,
            loglikelihood_fn,
            mcmc_factory,
            mcmc_init_fn,
            mcmc_parameters,
            resampling_fn,
            mcmc_parameter_update_fn,
            num_mcmc_steps,
            **extra_parameters,
        )

        def init_fn(position):
            return cls.init(smc_algorithm.init, position, initial_parameter_value)

        def step_fn(
            rng_key: PRNGKey, state, **extra_step_parameters
        ) -> Tuple[StateWithParameterOverride, SMCInfo]:
            return kernel(rng_key, state, **extra_step_parameters)

        return SamplingAlgorithm(init_fn, step_fn)
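
To make the wiring above concrete, here is a minimal usage sketch. It is not part of the commit; the toy densities, HMC parameter values, and target_ess are illustrative assumptions, and only the keyword wiring mirrors the new API. It wraps adaptive_tempered_smc around an HMC inner kernel whose inverse mass matrix is recomputed from the particles after every SMC step, using this commit's tuning helpers.

# Minimal usage sketch, not from this commit. Densities and parameter
# values below are toy assumptions; only the wiring mirrors the new API.
import jax
import jax.numpy as jnp
import jax.scipy.stats as stats

import blackjax
import blackjax.smc.resampling as resampling
from blackjax.smc.tuning.from_particles import mass_matrix_from_particles

logprior_fn = lambda x: stats.norm.logpdf(x).sum()  # toy prior
loglikelihood_fn = lambda x: stats.norm.logpdf(x).sum()  # toy likelihood
initial_particles = jax.random.normal(jax.random.PRNGKey(1), (1000, 2))


def mcmc_factory(inverse_mass_matrix):
    # Rebuild the HMC step function around the freshly tuned matrix.
    hmc_kernel = blackjax.hmc.build_kernel()

    def step_fn(rng_key, state, logdensity_fn, **mcmc_parameters):
        return hmc_kernel(
            rng_key,
            state,
            logdensity_fn,
            inverse_mass_matrix=inverse_mass_matrix,
            **mcmc_parameters,
        )

    return step_fn


smc_with_tuning = blackjax.inner_kernel_tuning(
    smc_algorithm=blackjax.adaptive_tempered_smc,
    logprior_fn=logprior_fn,
    loglikelihood_fn=loglikelihood_fn,
    mcmc_factory=mcmc_factory,
    mcmc_init_fn=blackjax.hmc.init,
    mcmc_parameters={"step_size": 1e-2, "num_integration_steps": 50},
    resampling_fn=resampling.systematic,
    mcmc_parameter_update_fn=lambda state, info: mass_matrix_from_particles(
        state.particles
    ),
    initial_parameter_value=jnp.eye(2),  # dimension of the flattened particles
    num_mcmc_steps=10,
    target_ess=0.5,  # extra parameter forwarded to adaptive_tempered_smc
)

key = jax.random.PRNGKey(0)
state = smc_with_tuning.init(initial_particles)
while state.sampler_state.lmbda < 1:
    key, subkey = jax.random.split(key)
    state, info = smc_with_tuning.step(subkey, state)

Note that mcmc_factory rebuilds the step function on every SMC iteration, which is what lets the tuned parameter change between iterations.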
Empty file added blackjax/smc/tuning/__init__.py
Empty file.
46 changes: 46 additions & 0 deletions blackjax/smc/tuning/from_kernel_info.py
@@ -0,0 +1,46 @@
"""
Strategies to tune the parameters of MCMC kernels
used within SMC, based on MCMC states.
"""
import jax
import jax.numpy as jnp

__all__ = ["update_scale_from_acceptance_rate"]


def update_scale_from_acceptance_rate(
    scales: jax.Array,
    acceptance_rates: jax.Array,
    target_acceptance_rate: float = 0.234,
) -> jax.Array:
    """
    Given N chains from some MCMC algorithm, such as Random Walk Metropolis,
    and N scale factors, each associated with a different chain, update the
    scale factors based on each chain's acceptance rate and the average
    acceptance rate.

    Under certain assumptions the optimal acceptance rate of
    Metropolis-Hastings is known to be 0.4 for 1 dimension, converging to
    0.234 as the dimension goes to infinity. In practice, 0.234 is a
    reasonable target for 5 or more dimensions. If a chain is below the
    target acceptance rate, its scale will decrease; if it is above, its
    scale will increase.

    Parameters
    ----------
    scales
        (n_chains,) array of N scale factors, each associated with one Markov chain.
    acceptance_rates
        (n_chains,) acceptance rates of the N Markov chains.
    target_acceptance_rate
        The desired acceptance rate for the chains.

    Returns
    -------
    (n_chains,) new scales, aimed at bringing the acceptance rates closer to
    the target if the chains were run again.
    """
    chain_scales = jnp.exp(jnp.log(scales) + acceptance_rates - target_acceptance_rate)
    return 0.5 * (chain_scales + chain_scales.mean())
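
As a worked toy example (not part of the commit): a chain sitting exactly at the 0.234 target keeps its scale, a chain below it shrinks, a chain above it grows, and the final averaging step pulls every scale halfway toward the population mean.

import jax.numpy as jnp

from blackjax.smc.tuning.from_kernel_info import update_scale_from_acceptance_rate

scales = jnp.array([0.5, 0.5, 0.5])
acceptance_rates = jnp.array([0.10, 0.234, 0.40])

new_scales = update_scale_from_acceptance_rate(scales, acceptance_rates)
# exp(log 0.5 + 0.10 - 0.234) ~= 0.437, exp(log 0.5) = 0.5,
# exp(log 0.5 + 0.40 - 0.234) ~= 0.590; each is then averaged with their
# mean (~0.509), giving roughly [0.47, 0.50, 0.55].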
49 changes: 49 additions & 0 deletions blackjax/smc/tuning/from_particles.py
@@ -0,0 +1,49 @@
"""
Strategies to tune the parameters of MCMC kernels
used within SMC, based on particles.
"""
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

from blackjax.types import Array

__all__ = [
    "particles_means",
    "particles_stds",
    "particles_covariance_matrix",
    "mass_matrix_from_particles",
]


def particles_stds(particles):
    return jnp.std(particles_as_rows(particles), axis=0)


def particles_means(particles):
    return jnp.mean(particles_as_rows(particles), axis=0)


def particles_covariance_matrix(particles):
    return jnp.cov(particles_as_rows(particles), ddof=0, rowvar=False)


def mass_matrix_from_particles(particles) -> Array:
    """
    Implements the tuning of section 3.1 of https://arxiv.org/pdf/1808.07730.pdf:
    compute a mass matrix to be used in HMC from the particles.
    Given the particles' covariance matrix, set all non-diagonal elements to
    zero, take the inverse, and keep the diagonal.

    Returns
    -------
    A mass matrix.
    """
    return jnp.diag(1.0 / jnp.var(particles_as_rows(particles), axis=0))


def particles_as_rows(particles):
    """
    Adds a trailing dimension to single-dimension variables, then represents
    multiple variables as a matrix where each column is a variable and each
    row is a particle.
    """
    return jax.vmap(lambda x: ravel_pytree(x)[0])(particles)
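
A short sketch of these helpers in action (not part of the commit): a pytree of particles is flattened into an (n_particles, n_dimensions) matrix, from which the diagonal mass matrix follows.

import jax
import jax.numpy as jnp

from blackjax.smc.tuning.from_particles import (
    mass_matrix_from_particles,
    particles_as_rows,
)

key = jax.random.PRNGKey(0)
k1, k2 = jax.random.split(key)
# 100 particles over two variables: "coefs" (2-dimensional) and a scalar "log_scale".
particles = {
    "coefs": jax.random.normal(k1, (100, 2)),
    "log_scale": jax.random.normal(k2, (100,)),
}

flat = particles_as_rows(particles)  # shape (100, 3): one row per particle
mass_matrix = mass_matrix_from_particles(particles)  # (3, 3) diagonal matrix
assert mass_matrix.shape == (3, 3)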
37 changes: 37 additions & 0 deletions tests/smc/__init__.py
@@ -0,0 +1,37 @@
import chex
import jax.numpy as jnp
import jax.scipy.stats as stats
import numpy as np


class SMCLinearRegressionTestCase(chex.TestCase):
    def logdensity_fn(self, log_scale, coefs, preds, x):
        """Linear regression"""
        scale = jnp.exp(log_scale)
        y = jnp.dot(x, coefs)
        logpdf = stats.norm.logpdf(preds, y, scale)
        return jnp.sum(logpdf)

    def particles_prior_loglikelihood(self):
        num_particles = 100

        x_data = np.random.normal(0, 1, size=(1000, 1))
        y_data = 3 * x_data + np.random.normal(size=x_data.shape)
        observations = {"x": x_data, "preds": y_data}

        logprior_fn = lambda x: stats.norm.logpdf(x["log_scale"]) + stats.norm.logpdf(
            x["coefs"]
        )
        loglikelihood_fn = lambda x: self.logdensity_fn(**x, **observations)

        log_scale_init = np.random.randn(num_particles)
        coeffs_init = np.random.randn(num_particles)
        init_particles = {"log_scale": log_scale_init, "coefs": coeffs_init}

        return init_particles, logprior_fn, loglikelihood_fn

    def assert_linear_regression_test_case(self, result):
        np.testing.assert_allclose(
            np.mean(np.exp(result.particles["log_scale"])), 1.0, rtol=1e-1
        )
        np.testing.assert_allclose(np.mean(result.particles["coefs"]), 3.0, rtol=1e-1)
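
As a hedged sketch of how this fixture might be consumed (the commit's actual test modules are not shown in this excerpt), a subclass could run tempered SMC over a fixed temperature ladder and finish with the assertion helper; the HMC parameters and the ladder are illustrative assumptions.

# Hypothetical subclass, for illustration only; the commit's real tests
# live in the files not shown above.
import jax
import jax.numpy as jnp

import blackjax
import blackjax.smc.resampling as resampling


class TemperedSMCRegressionTest(SMCLinearRegressionTestCase):
    def test_linear_regression(self):
        key = jax.random.PRNGKey(42)
        init_particles, logprior_fn, loglikelihood_fn = (
            self.particles_prior_loglikelihood()
        )

        hmc_parameters = {
            "step_size": 1e-2,
            "inverse_mass_matrix": jnp.eye(2),
            "num_integration_steps": 50,
        }
        sampler = blackjax.tempered_smc(
            logprior_fn=logprior_fn,
            loglikelihood_fn=loglikelihood_fn,
            mcmc_step_fn=blackjax.hmc.build_kernel(),
            mcmc_init_fn=blackjax.hmc.init,
            mcmc_parameters=hmc_parameters,
            resampling_fn=resampling.systematic,
            num_mcmc_steps=10,
        )

        state = sampler.init(init_particles)
        for lmbda in jnp.linspace(0.1, 1.0, 10):  # fixed temperature ladder
            key, subkey = jax.random.split(key)
            state, _ = sampler.step(subkey, state, lmbda)

        self.assert_linear_regression_test_case(state)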