Skip to content

Commit

Permalink
SMC-MCMC integration test, plus fixes. (#522)
Browse files Browse the repository at this point in the history
* fixing interfaces, including test

* improving test speed, fixing other tests

* adding mala and nuts to SMC-MCMC integration test
  • Loading branch information
ciguaran authored Apr 20, 2023
1 parent a650f9b commit 9bb5f11
Show file tree
Hide file tree
Showing 4 changed files with 118 additions and 10 deletions.
6 changes: 3 additions & 3 deletions blackjax/kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -1026,7 +1026,7 @@ def init_fn(position: PyTree):
return cls.init(position, logdensity_fn)

def step_fn(rng_key: PRNGKey, state):
return kernel(rng_key, random_step, state, logdensity_fn)
return kernel(rng_key, state, logdensity_fn, random_step)

return MCMCSamplingAlgorithm(init_fn, step_fn)

Expand Down Expand Up @@ -1085,9 +1085,9 @@ def init_fn(position: PyTree):
def step_fn(rng_key: PRNGKey, state):
return kernel(
rng_key,
state,
logdensity_fn,
proposal_generator,
state,
proposal_logdensity_fn,
)

Expand Down Expand Up @@ -1143,7 +1143,7 @@ def init_fn(position: PyTree):
return cls.init(position, logdensity_fn)

def step_fn(rng_key: PRNGKey, state):
return kernel(rng_key, proposal_distribution, state, logdensity_fn)
return kernel(rng_key, state, logdensity_fn, proposal_distribution)

return MCMCSamplingAlgorithm(init_fn, step_fn)

Expand Down
10 changes: 5 additions & 5 deletions blackjax/mcmc/random_walk.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,15 +159,15 @@ def build_additive_step():
"""

def kernel(
rng_key: PRNGKey, random_step: Callable, state: RWState, logdensity_fn: Callable
rng_key: PRNGKey, state: RWState, logdensity_fn: Callable, random_step: Callable
) -> Tuple[RWState, RWInfo]:
def proposal_generator(key_proposal, position):
move_proposal = random_step(key_proposal, position)
new_position = jax.tree_util.tree_map(jnp.add, position, move_proposal)
return new_position

inner_kernel = build_rmh()
return inner_kernel(rng_key, logdensity_fn, proposal_generator, state)
return inner_kernel(rng_key, state, logdensity_fn, proposal_generator)

return kernel

Expand All @@ -187,9 +187,9 @@ def build_irmh() -> Callable:

def kernel(
rng_key: PRNGKey,
proposal_distribution: Callable,
state: RWState,
logdensity_fn: Callable,
proposal_distribution: Callable,
) -> Tuple[RWState, RWInfo]:
"""
Parameters
Expand All @@ -203,7 +203,7 @@ def proposal_generator(rng_key: PRNGKey, position: PyTree):
return proposal_distribution(rng_key)

inner_kernel = build_rmh()
return inner_kernel(rng_key, logdensity_fn, proposal_generator, state)
return inner_kernel(rng_key, state, logdensity_fn, proposal_generator)

return kernel

Expand All @@ -220,9 +220,9 @@ def build_rmh():

def kernel(
rng_key: PRNGKey,
state: RWState,
logdensity_fn: Callable,
transition_generator: Callable,
state: RWState,
proposal_logdensity_fn: Optional[Callable] = None,
) -> Tuple[RWState, RWInfo]:
"""Move the chain by one step using the Rosenbluth Metropolis Hastings
Expand Down
4 changes: 2 additions & 2 deletions tests/mcmc/test_random_walk_without_chex.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@ def test_logdensity_accepts(position):

new_state, _ = step(
rng_key,
random_step,
RWState(position=initial_position, logdensity=1.0),
test_logdensity_accepts,
random_step,
)

np.testing.assert_allclose(new_state.position, jnp.array([60.0]))
Expand Down Expand Up @@ -69,9 +69,9 @@ def test_logdensity_accepts(position):
for previous_position in [initial_position, other_position]:
new_state, _ = step(
rng_key,
proposal_distribution,
RWState(position=previous_position, logdensity=1.0),
test_logdensity_accepts,
proposal_distribution,
)
np.testing.assert_allclose(new_state.position, jnp.array([10.0]))

Expand Down
108 changes: 108 additions & 0 deletions tests/smc/test_kernel_compatibility.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
import unittest

import jax
from jax import numpy as jnp
from jax.scipy.stats import multivariate_normal

import blackjax
from blackjax import adaptive_tempered_smc
from blackjax.mcmc.random_walk import normal


class SMCAndMCMCIntegrationTest(unittest.TestCase):
"""
An integration test that verifies which MCMC can be used as
SMC mutation step kernels.
"""

def setUp(self):
super().setUp()
self.key = jax.random.PRNGKey(42)
self.initial_particles = jax.random.multivariate_normal(
self.key, jnp.zeros(2), jnp.eye(2), (3,)
)

def check_compatible(self, mcmc_step_fn, mcmc_init_fn, mcmc_parameters):
"""
Runs one SMC step
"""
init, kernel = adaptive_tempered_smc(
self.prior_log_prob,
self.loglikelihood,
mcmc_step_fn,
mcmc_init_fn,
mcmc_parameters=mcmc_parameters,
resampling_fn=self.resampling_fn,
target_ess=0.5,
root_solver=self.root_solver,
num_mcmc_steps=1,
)
kernel(self.key, init(self.initial_particles))

def test_compatible_with_rwm(self):
self.check_compatible(
blackjax.additive_step_random_walk.build_kernel(),
blackjax.additive_step_random_walk.init,
{"random_step": normal(1.0)},
)

def test_compatible_with_rmh(self):
self.check_compatible(
blackjax.rmh.build_kernel(),
blackjax.rmh.init,
{
"transition_generator": lambda a, b: blackjax.mcmc.random_walk.normal(
1.0
)(a, b)
},
)

def test_compatible_with_hmc(self):
self.check_compatible(
blackjax.hmc.kernel(),
blackjax.hmc.init,
{
"step_size": 0.3,
"inverse_mass_matrix": jnp.array([1]),
"num_integration_steps": 1,
},
)

def test_compatible_with_irmh(self):
self.check_compatible(
blackjax.irmh.build_kernel(),
blackjax.irmh.init,
{
"proposal_distribution": lambda key: jnp.array([1.0, 1.0])
+ jax.random.normal(key)
},
)

def test_compatible_with_nuts(self):
self.check_compatible(
blackjax.nuts.kernel(),
blackjax.nuts.init,
{"step_size": 1e-10, "inverse_mass_matrix": jnp.eye(2)},
)

def test_compatible_with_mala(self):
self.check_compatible(
blackjax.mala.kernel(), blackjax.mala.init, {"step_size": 1e-10}
)

@staticmethod
def prior_log_prob(x):
d = x.shape[0]
return multivariate_normal.logpdf(x, jnp.zeros((d,)), jnp.eye(d))

@staticmethod
def loglikelihood(x):
return -5 * jnp.sum(jnp.square(x**2 - 1))

@staticmethod
def root_solver(fun, _delta0, min_delta, max_delta, eps=1e-4, max_iter=100):
return 0.8

@staticmethod
def resampling_fn(rng_key, weights: jax.Array, num_samples: int):
return jnp.array([0, 1, 2])

0 comments on commit 9bb5f11

Please sign in to comment.