-
Notifications
You must be signed in to change notification settings - Fork 107
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
SMC-MCMC integration test, plus fixes. (#522)
* fixing interfaces, including test * improving test speed, fixing other tests * adding mala and nuts to SMC-MCMC integration test
- Loading branch information
1 parent
d81d999
commit 9b2ddaf
Showing
4 changed files
with
118 additions
and
10 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
import unittest | ||
|
||
import jax | ||
from jax import numpy as jnp | ||
from jax.scipy.stats import multivariate_normal | ||
|
||
import blackjax | ||
from blackjax import adaptive_tempered_smc | ||
from blackjax.mcmc.random_walk import normal | ||
|
||
|
||
class SMCAndMCMCIntegrationTest(unittest.TestCase): | ||
""" | ||
An integration test that verifies which MCMC can be used as | ||
SMC mutation step kernels. | ||
""" | ||
|
||
def setUp(self): | ||
super().setUp() | ||
self.key = jax.random.PRNGKey(42) | ||
self.initial_particles = jax.random.multivariate_normal( | ||
self.key, jnp.zeros(2), jnp.eye(2), (3,) | ||
) | ||
|
||
def check_compatible(self, mcmc_step_fn, mcmc_init_fn, mcmc_parameters): | ||
""" | ||
Runs one SMC step | ||
""" | ||
init, kernel = adaptive_tempered_smc( | ||
self.prior_log_prob, | ||
self.loglikelihood, | ||
mcmc_step_fn, | ||
mcmc_init_fn, | ||
mcmc_parameters=mcmc_parameters, | ||
resampling_fn=self.resampling_fn, | ||
target_ess=0.5, | ||
root_solver=self.root_solver, | ||
num_mcmc_steps=1, | ||
) | ||
kernel(self.key, init(self.initial_particles)) | ||
|
||
def test_compatible_with_rwm(self): | ||
self.check_compatible( | ||
blackjax.additive_step_random_walk.build_kernel(), | ||
blackjax.additive_step_random_walk.init, | ||
{"random_step": normal(1.0)}, | ||
) | ||
|
||
def test_compatible_with_rmh(self): | ||
self.check_compatible( | ||
blackjax.rmh.build_kernel(), | ||
blackjax.rmh.init, | ||
{ | ||
"transition_generator": lambda a, b: blackjax.mcmc.random_walk.normal( | ||
1.0 | ||
)(a, b) | ||
}, | ||
) | ||
|
||
def test_compatible_with_hmc(self): | ||
self.check_compatible( | ||
blackjax.hmc.kernel(), | ||
blackjax.hmc.init, | ||
{ | ||
"step_size": 0.3, | ||
"inverse_mass_matrix": jnp.array([1]), | ||
"num_integration_steps": 1, | ||
}, | ||
) | ||
|
||
def test_compatible_with_irmh(self): | ||
self.check_compatible( | ||
blackjax.irmh.build_kernel(), | ||
blackjax.irmh.init, | ||
{ | ||
"proposal_distribution": lambda key: jnp.array([1.0, 1.0]) | ||
+ jax.random.normal(key) | ||
}, | ||
) | ||
|
||
def test_compatible_with_nuts(self): | ||
self.check_compatible( | ||
blackjax.nuts.kernel(), | ||
blackjax.nuts.init, | ||
{"step_size": 1e-10, "inverse_mass_matrix": jnp.eye(2)}, | ||
) | ||
|
||
def test_compatible_with_mala(self): | ||
self.check_compatible( | ||
blackjax.mala.kernel(), blackjax.mala.init, {"step_size": 1e-10} | ||
) | ||
|
||
@staticmethod | ||
def prior_log_prob(x): | ||
d = x.shape[0] | ||
return multivariate_normal.logpdf(x, jnp.zeros((d,)), jnp.eye(d)) | ||
|
||
@staticmethod | ||
def loglikelihood(x): | ||
return -5 * jnp.sum(jnp.square(x**2 - 1)) | ||
|
||
@staticmethod | ||
def root_solver(fun, _delta0, min_delta, max_delta, eps=1e-4, max_iter=100): | ||
return 0.8 | ||
|
||
@staticmethod | ||
def resampling_fn(rng_key, weights: jax.Array, num_samples: int): | ||
return jnp.array([0, 1, 2]) |