Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SMC-MCMC integration test, plus fixes. #522

Merged
merged 4 commits into from
Apr 20, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions blackjax/kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -1023,7 +1023,7 @@ def init_fn(position: PyTree):
return cls.init(position, logdensity_fn)

def step_fn(rng_key: PRNGKey, state):
return kernel(rng_key, random_step, state, logdensity_fn)
return kernel(rng_key, state, logdensity_fn, random_step)

return MCMCSamplingAlgorithm(init_fn, step_fn)

Expand Down Expand Up @@ -1082,9 +1082,9 @@ def init_fn(position: PyTree):
def step_fn(rng_key: PRNGKey, state):
return kernel(
rng_key,
state,
logdensity_fn,
proposal_generator,
state,
proposal_logdensity_fn,
)

Expand Down Expand Up @@ -1140,7 +1140,7 @@ def init_fn(position: PyTree):
return cls.init(position, logdensity_fn)

def step_fn(rng_key: PRNGKey, state):
return kernel(rng_key, proposal_distribution, state, logdensity_fn)
return kernel(rng_key, state, logdensity_fn, proposal_distribution)

return MCMCSamplingAlgorithm(init_fn, step_fn)

Expand Down
10 changes: 5 additions & 5 deletions blackjax/mcmc/random_walk.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,15 +159,15 @@ def build_additive_step():
"""

def kernel(
rng_key: PRNGKey, random_step: Callable, state: RWState, logdensity_fn: Callable
rng_key: PRNGKey, state: RWState, logdensity_fn: Callable, random_step: Callable
) -> Tuple[RWState, RWInfo]:
def proposal_generator(key_proposal, position):
move_proposal = random_step(key_proposal, position)
new_position = jax.tree_util.tree_map(jnp.add, position, move_proposal)
return new_position

inner_kernel = build_rmh()
return inner_kernel(rng_key, logdensity_fn, proposal_generator, state)
return inner_kernel(rng_key, state, logdensity_fn, proposal_generator)

return kernel

Expand All @@ -187,9 +187,9 @@ def build_irmh() -> Callable:

def kernel(
rng_key: PRNGKey,
proposal_distribution: Callable,
state: RWState,
logdensity_fn: Callable,
proposal_distribution: Callable,
) -> Tuple[RWState, RWInfo]:
"""
Parameters
Expand All @@ -203,7 +203,7 @@ def proposal_generator(rng_key: PRNGKey, position: PyTree):
return proposal_distribution(rng_key)

inner_kernel = build_rmh()
return inner_kernel(rng_key, logdensity_fn, proposal_generator, state)
return inner_kernel(rng_key, state, logdensity_fn, proposal_generator)

return kernel

Expand All @@ -220,9 +220,9 @@ def build_rmh():

def kernel(
rng_key: PRNGKey,
state: RWState,
logdensity_fn: Callable,
transition_generator: Callable,
state: RWState,
proposal_logdensity_fn: Optional[Callable] = None,
) -> Tuple[RWState, RWInfo]:
"""Move the chain by one step using the Rosenbluth Metropolis Hastings
Expand Down
4 changes: 2 additions & 2 deletions tests/mcmc/test_random_walk_without_chex.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@ def test_logdensity_accepts(position):

new_state, _ = step(
rng_key,
random_step,
RWState(position=initial_position, logdensity=1.0),
test_logdensity_accepts,
random_step,
)

np.testing.assert_allclose(new_state.position, jnp.array([60.0]))
Expand Down Expand Up @@ -69,9 +69,9 @@ def test_logdensity_accepts(position):
for previous_position in [initial_position, other_position]:
new_state, _ = step(
rng_key,
proposal_distribution,
RWState(position=previous_position, logdensity=1.0),
test_logdensity_accepts,
proposal_distribution,
)
np.testing.assert_allclose(new_state.position, jnp.array([10.0]))

Expand Down
96 changes: 96 additions & 0 deletions tests/smc/test_kernel_compatibility.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
import unittest

import jax
from jax import numpy as jnp
from jax.scipy.stats import multivariate_normal

import blackjax
from blackjax import adaptive_tempered_smc
from blackjax.mcmc.random_walk import normal


class SMCAndMCMCIntegrationTest(unittest.TestCase):
"""
An integration test that verifies which MCMC can be used as
SMC mutation step kernels.
"""

def setUp(self):
super().setUp()
self.key = jax.random.PRNGKey(42)
self.initial_particles = jax.random.multivariate_normal(
self.key, jnp.zeros(2), jnp.eye(2), (3,)
)

def check_compatible(self, mcmc_step_fn, mcmc_init_fn, mcmc_parameters):
"""
Runs one SMC step
"""
init, kernel = adaptive_tempered_smc(
self.prior_log_prob,
self.loglikelihood,
mcmc_step_fn,
mcmc_init_fn,
mcmc_parameters=mcmc_parameters,
resampling_fn=self.resampling_fn,
target_ess=0.5,
root_solver=self.root_solver,
num_mcmc_steps=1,
)
kernel(self.key, init(self.initial_particles))

def test_compatible_with_rwm(self):
self.check_compatible(
blackjax.additive_step_random_walk.build_kernel(),
blackjax.additive_step_random_walk.init,
{"random_step": normal(1.0)},
)

def test_compatible_with_rmh(self):
self.check_compatible(
blackjax.rmh.build_kernel(),
blackjax.rmh.init,
{
"transition_generator": lambda a, b: blackjax.mcmc.random_walk.normal(
1.0
)(a, b)
},
)

def test_compatible_with_hmc(self):
self.check_compatible(
blackjax.hmc.kernel(),
blackjax.hmc.init,
{
"step_size": 0.3,
"inverse_mass_matrix": jnp.array([1]),
"num_integration_steps": 1,
},
)

def test_compatible_with_irmh(self):
self.check_compatible(
blackjax.irmh.build_kernel(),
blackjax.irmh.init,
{
"proposal_distribution": lambda key: jnp.array([1.0, 1.0])
+ jax.random.normal(key)
},
)

@staticmethod
def prior_log_prob(x):
d = x.shape[0]
return multivariate_normal.logpdf(x, jnp.zeros((d,)), jnp.eye(d))

@staticmethod
def loglikelihood(x):
return -5 * jnp.sum(jnp.square(x**2 - 1))

@staticmethod
def root_solver(fun, _delta0, min_delta, max_delta, eps=1e-4, max_iter=100):
return 0.8

@staticmethod
def resampling_fn(rng_key, weights: jax.Array, num_samples: int):
return jnp.array([0, 1, 2])