-
Notifications
You must be signed in to change notification settings - Fork 108
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* inner kernel tuning, tests, and some common strategies * Adding imports * pre-commit * code review updates * Adding Chex tests * line alignment comment * Adding particles_as_rows test * Modifying implementation of particles_as_rows * pre-commit * change in inverse_mass_matrix from particles implementation * replacing particles_as_rows_test --------- Co-authored-by: Junpeng Lao <[email protected]>
- Loading branch information
1 parent
5fbc6f0
commit d1e7014
Showing
9 changed files
with
675 additions
and
28 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,3 @@ | ||
from . import adaptive_tempered, tempered | ||
from . import adaptive_tempered, inner_kernel_tuning, tempered | ||
|
||
__all__ = ["adaptive_tempered", "tempered"] | ||
__all__ = ["adaptive_tempered", "tempered", "inner_kernel_tuning"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,150 @@ | ||
from typing import Callable, Dict, NamedTuple, Tuple, Union | ||
|
||
from blackjax.base import SamplingAlgorithm | ||
from blackjax.smc.adaptive_tempered import adaptive_tempered_smc | ||
from blackjax.smc.base import SMCInfo, SMCState | ||
from blackjax.smc.tempered import tempered_smc | ||
from blackjax.types import ArrayTree, PRNGKey | ||
|
||
|
||
class StateWithParameterOverride(NamedTuple): | ||
sampler_state: ArrayTree | ||
parameter_override: ArrayTree | ||
|
||
|
||
def init(alg_init_fn, position, initial_parameter_value): | ||
return StateWithParameterOverride(alg_init_fn(position), initial_parameter_value) | ||
|
||
|
||
def build_kernel( | ||
smc_algorithm, | ||
logprior_fn: Callable, | ||
loglikelihood_fn: Callable, | ||
mcmc_factory: Callable, | ||
mcmc_init_fn: Callable, | ||
mcmc_parameters: Dict, | ||
resampling_fn: Callable, | ||
mcmc_parameter_update_fn: Callable[[SMCState, SMCInfo], ArrayTree], | ||
num_mcmc_steps: int = 10, | ||
**extra_parameters, | ||
) -> Callable: | ||
"""In the context of an SMC sampler (whose step_fn returning state has a .particles attribute), there's an inner | ||
MCMC that is used to perturbate/update each of the particles. This adaptation tunes some parameter of that MCMC, | ||
based on particles. The parameter type must be a valid JAX type. | ||
Parameters | ||
---------- | ||
smc_algorithm | ||
Either blackjax.adaptive_tempered_smc or blackjax.tempered_smc (or any other implementation of | ||
a sampling algorithm that returns an SMCState and SMCInfo pair). | ||
logprior_fn | ||
A function that computes the log density of the prior distribution | ||
loglikelihood_fn | ||
A function that returns the probability at a given position. | ||
mcmc_factory | ||
A callable that can construct an inner kernel out of the newly-computed parameter | ||
mcmc_init_fn | ||
A callable that initializes the inner kernel | ||
mcmc_parameters | ||
Other (fixed across SMC iterations) parameters for the inner kernel | ||
mcmc_parameter_update_fn | ||
A callable that takes the SMCState and SMCInfo at step i and constructs a parameter to be used by the inner kernel in i+1 iteration. | ||
extra_parameters: | ||
parameters to be used for the creation of the smc_algorithm. | ||
""" | ||
|
||
def kernel( | ||
rng_key: PRNGKey, state: StateWithParameterOverride, **extra_step_parameters | ||
) -> Tuple[StateWithParameterOverride, SMCInfo]: | ||
step_fn = smc_algorithm( | ||
logprior_fn=logprior_fn, | ||
loglikelihood_fn=loglikelihood_fn, | ||
mcmc_step_fn=mcmc_factory(state.parameter_override), | ||
mcmc_init_fn=mcmc_init_fn, | ||
mcmc_parameters=mcmc_parameters, | ||
resampling_fn=resampling_fn, | ||
num_mcmc_steps=num_mcmc_steps, | ||
**extra_parameters, | ||
).step | ||
new_state, info = step_fn(rng_key, state.sampler_state, **extra_step_parameters) | ||
new_parameter_override = mcmc_parameter_update_fn(new_state, info) | ||
return StateWithParameterOverride(new_state, new_parameter_override), info | ||
|
||
return kernel | ||
|
||
|
||
class inner_kernel_tuning: | ||
"""In the context of an SMC sampler (whose step_fn returning state | ||
has a .particles attribute), there's an inner MCMC that is used | ||
to perturbate/update each of the particles. This adaptation tunes some | ||
parameter of that MCMC, based on particles. | ||
The parameter type must be a valid JAX type. | ||
Parameters | ||
---------- | ||
smc_algorithm | ||
Either blackjax.adaptive_tempered_smc or blackjax.tempered_smc (or any other implementation of | ||
a sampling algorithm that returns an SMCState and SMCInfo pair). | ||
logprior_fn | ||
A function that computes the log density of the prior distribution | ||
loglikelihood_fn | ||
A function that returns the probability at a given position. | ||
mcmc_factory | ||
A callable that can construct an inner kernel out of the newly-computed parameter | ||
mcmc_init_fn | ||
A callable that initializes the inner kernel | ||
mcmc_parameters | ||
Other (fixed across SMC iterations) parameters for the inner kernel step | ||
mcmc_parameter_update_fn | ||
A callable that takes the SMCState and SMCInfo at step i and constructs a parameter to be used by the | ||
inner kernel in i+1 iteration. | ||
initial_parameter_value | ||
Paramter to be used by the mcmc_factory before the first iteration. | ||
extra_parameters: | ||
parameters to be used for the creation of the smc_algorithm. | ||
Returns | ||
------- | ||
A ``SamplingAlgorithm``. | ||
""" | ||
|
||
init = staticmethod(init) | ||
build_kernel = staticmethod(build_kernel) | ||
|
||
def __new__( # type: ignore[misc] | ||
cls, | ||
smc_algorithm: Union[adaptive_tempered_smc, tempered_smc], | ||
logprior_fn: Callable, | ||
loglikelihood_fn: Callable, | ||
mcmc_factory: Callable, | ||
mcmc_init_fn: Callable, | ||
mcmc_parameters: Dict, | ||
resampling_fn: Callable, | ||
mcmc_parameter_update_fn: Callable[[SMCState, SMCInfo], ArrayTree], | ||
initial_parameter_value, | ||
num_mcmc_steps: int = 10, | ||
**extra_parameters, | ||
) -> SamplingAlgorithm: | ||
kernel = cls.build_kernel( | ||
smc_algorithm, | ||
logprior_fn, | ||
loglikelihood_fn, | ||
mcmc_factory, | ||
mcmc_init_fn, | ||
mcmc_parameters, | ||
resampling_fn, | ||
mcmc_parameter_update_fn, | ||
num_mcmc_steps, | ||
**extra_parameters, | ||
) | ||
|
||
def init_fn(position): | ||
return cls.init(smc_algorithm.init, position, initial_parameter_value) | ||
|
||
def step_fn( | ||
rng_key: PRNGKey, state, **extra_step_parameters | ||
) -> Tuple[StateWithParameterOverride, SMCInfo]: | ||
return kernel(rng_key, state, **extra_step_parameters) | ||
|
||
return SamplingAlgorithm(init_fn, step_fn) |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
""" | ||
strategies to tune the parameters of mcmc kernels | ||
used within smc, based on MCMC states | ||
""" | ||
import jax | ||
import jax.numpy as jnp | ||
|
||
__all__ = ["update_scale_from_acceptance_rate"] | ||
|
||
|
||
def update_scale_from_acceptance_rate( | ||
scales: jax.Array, | ||
acceptance_rates: jax.Array, | ||
target_acceptance_rate: float = 0.234, | ||
) -> jax.Array: | ||
""" | ||
Given N chains from some MCMC algorithm like Random Walk Metropolis | ||
and N scale factors, each associated to a different chain. | ||
Updates the scale factors taking into account acceptance rates and | ||
the average acceptance rate. | ||
Under certain assumptions it is known that the optimal acceptance rate | ||
of Metropolis Hastings is 0.4 for 1 dimension and converges to | ||
0.234 in infinite dimensions. In practice, 0.234 is a reasonable | ||
assumption for 5 or more dimensions. | ||
If certain chain is below optimal acceptance rate, its scale will decrease | ||
and if its above, its scale will increase, | ||
------- | ||
Parameters | ||
---------- | ||
scales | ||
(n_chains) array consisting of N scale factors, associated to N markov chains | ||
acceptance_rates | ||
(n_chains) acceptance rate of the N markov chains | ||
target_acceptance_rate | ||
a float with a desirable acceptance rate for the chains. | ||
Returns | ||
------- | ||
(n_chains) new scales, with the aim of getting acceptance rates closer to target | ||
if the chains were to be run again. | ||
""" | ||
chain_scales = jnp.exp(jnp.log(scales) + acceptance_rates - target_acceptance_rate) | ||
return 0.5 * (chain_scales + chain_scales.mean()) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
""" | ||
strategies to tune the parameters of mcmc kernels | ||
used within SMC, based on particles. | ||
""" | ||
import jax | ||
import jax.numpy as jnp | ||
from jax._src.flatten_util import ravel_pytree | ||
|
||
from blackjax.types import Array | ||
|
||
__all__ = [ | ||
"particles_means", | ||
"particles_stds", | ||
"particles_covariance_matrix", | ||
"mass_matrix_from_particles", | ||
] | ||
|
||
|
||
def particles_stds(particles): | ||
return jnp.std(particles_as_rows(particles), axis=0) | ||
|
||
|
||
def particles_means(particles): | ||
return jnp.mean(particles_as_rows(particles), axis=0) | ||
|
||
|
||
def particles_covariance_matrix(particles): | ||
return jnp.cov(particles_as_rows(particles), ddof=0, rowvar=False) | ||
|
||
|
||
def mass_matrix_from_particles(particles) -> Array: | ||
""" | ||
Implements tuning from section 3.1 from https://arxiv.org/pdf/1808.07730.pdf | ||
Computing a mass matrix to be used in HMC from particles. | ||
Given the particles covariance matrix, set all non-diagonal elements as zero, | ||
take the inverse, and keep the diagonal. | ||
Returns | ||
------- | ||
A mass Matrix | ||
""" | ||
return jnp.diag(1.0 / jnp.var(particles_as_rows(particles), axis=0)) | ||
|
||
|
||
def particles_as_rows(particles): | ||
""" | ||
Adds end dimension for single-dimension variables, and then represents multivariables | ||
as a matrix where each column is a variable, each row a particle. | ||
""" | ||
return jax.vmap(lambda x: ravel_pytree(x)[0])(particles) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
import chex | ||
import jax.numpy as jnp | ||
import jax.scipy.stats as stats | ||
import numpy as np | ||
|
||
|
||
class SMCLinearRegressionTestCase(chex.TestCase): | ||
def logdensity_fn(self, log_scale, coefs, preds, x): | ||
"""Linear regression""" | ||
scale = jnp.exp(log_scale) | ||
y = jnp.dot(x, coefs) | ||
logpdf = stats.norm.logpdf(preds, y, scale) | ||
return jnp.sum(logpdf) | ||
|
||
def particles_prior_loglikelihood(self): | ||
num_particles = 100 | ||
|
||
x_data = np.random.normal(0, 1, size=(1000, 1)) | ||
y_data = 3 * x_data + np.random.normal(size=x_data.shape) | ||
observations = {"x": x_data, "preds": y_data} | ||
|
||
logprior_fn = lambda x: stats.norm.logpdf(x["log_scale"]) + stats.norm.logpdf( | ||
x["coefs"] | ||
) | ||
loglikelihood_fn = lambda x: self.logdensity_fn(**x, **observations) | ||
|
||
log_scale_init = np.random.randn(num_particles) | ||
coeffs_init = np.random.randn(num_particles) | ||
init_particles = {"log_scale": log_scale_init, "coefs": coeffs_init} | ||
|
||
return init_particles, logprior_fn, loglikelihood_fn | ||
|
||
def assert_linear_regression_test_case(self, result): | ||
np.testing.assert_allclose( | ||
np.mean(np.exp(result.particles["log_scale"])), 1.0, rtol=1e-1 | ||
) | ||
np.testing.assert_allclose(np.mean(result.particles["coefs"]), 3.0, rtol=1e-1) |
Oops, something went wrong.