From b107f9fd60cfc1261a5ce35690b1d0f141041c07 Mon Sep 17 00:00:00 2001
From: Ismael Mendoza <11745764+ismael-mendoza@users.noreply.github.com>
Date: Tue, 8 Oct 2024 12:51:38 -0500
Subject: [PATCH] Add pre-conditioning matrix to Barker proposal (#731)

* Draft pre-conditioning matrix in Barker proposal.

This is a first draft of adding the pre-conditioning matrix to the Barker
proposal. It follows Algorithms 4 and 5 in Appendix G of the original Barker
proposal paper. It is somewhat unclear from the paper, but the separate step
size that was already implemented serves as a global scale for the normal
distribution of the proposal. The function `_compute_acceptance_probability`
now takes in the transposed square-root mass matrix and its inverse, and it
has been flattened to accommodate the corresponding matrix multiplications.

* Fix typing of inverse_mass_matrix argument

* Fix docstrings.

The original docstring of step_size was incorrect: there is no symplectic
integrator in the Barker proposal.

* Make test for Barker in test_sampling run again

We make this possible by adding an identity pre-conditioning matrix, which
should make the test run in the same way as before.

* Add test to ensure correctness of precond matrix

We add a new test to test_barker.py to ensure that our implementation of the
pre-conditioning matrix is correct. We follow Appendix G in the paper, which
mentions that Algorithms 4 and 5 (which we implemented) should be equivalent
to rescaling the parameters and the log-density in a specific way. We
implement both approaches when using the Barker proposal to infer the mean
and sigma of a normal distribution, and check that with two different random
seeds the output chains are equivalent up to some tolerance. We also patch
the original test in this file by adding an identity mass matrix.

* Fix dimensionality of identity matrix

* Add missing mass matrix in remaining tests.

* added option to transpose the matrix when scaling

The option to transpose the mass_matrix_sqrt or inv_mass_matrix_sqrt was
necessary for the Barker algorithm as far as I can tell. This has not been
propagated to the Riemannian metric.

* use the metric scaling function in barker

Here we use the new metric.scale function to perform the operations required
by the Barker proposal algorithm, instead of passing around the
mass_matrix_sqrt and inv_mass_matrix_sqrt directly. We also make the
`inverse_mass_matrix` argument optional to avoid breaking the API.

* update test_sampling with barker api

The mass matrix is now an optional argument in barker.
* update test_barker so it works with metric.scale * fix tests add trans to scale * add trans argument to riemannian scaling * no default * Update barker.py Make acceptance function metric agnostic * Update test_barker.py Add invariance test * simplify logic to remove _barker_sample_nd * fix bug so now everything is tree_mapped in barker * fix test to not use _barker_sample_nd * Update blackjax/mcmc/metrics.py make inv and trans required kwarg with type bool in metric.scale Co-authored-by: Junpeng Lao * Update blackjax/mcmc/metrics.py lax.cond might not be needed in metric.scale as inv and trans are static kwarg Co-authored-by: Junpeng Lao * propagate changes of inv, trans as required kwarg * fix test metrics --------- Co-authored-by: Adrien Corenflos Co-authored-by: Junpeng Lao --- blackjax/mcmc/barker.py | 146 +++++++++++++++++++----------------- blackjax/mcmc/metrics.py | 55 ++++++++++---- tests/mcmc/test_barker.py | 128 ++++++++++++++++++++++++++++++- tests/mcmc/test_metrics.py | 32 ++++++-- tests/mcmc/test_sampling.py | 1 + 5 files changed, 269 insertions(+), 93 deletions(-) diff --git a/blackjax/mcmc/barker.py b/blackjax/mcmc/barker.py index 9923bd5f3..7ae7d2463 100644 --- a/blackjax/mcmc/barker.py +++ b/blackjax/mcmc/barker.py @@ -18,11 +18,13 @@ import jax.numpy as jnp from jax.flatten_util import ravel_pytree from jax.scipy import stats -from jax.tree_util import tree_leaves, tree_map +import blackjax.mcmc.metrics as metrics from blackjax.base import SamplingAlgorithm +from blackjax.mcmc.metrics import Metric from blackjax.mcmc.proposal import static_binomial_sampling -from blackjax.types import ArrayLikeTree, ArrayTree, PRNGKey +from blackjax.types import ArrayLikeTree, ArrayTree, Numeric, PRNGKey +from blackjax.util import generate_gaussian_noise __all__ = ["BarkerState", "BarkerInfo", "init", "build_kernel", "as_top_level_api"] @@ -81,44 +83,70 @@ def build_kernel(): """ def _compute_acceptance_probability( - state: BarkerState, - proposal: BarkerState, - ) -> float: + state: BarkerState, proposal: BarkerState, metric: Metric + ) -> Numeric: """Compute the acceptance probability of the Barker's proposal kernel.""" - def ratio_proposal_nd(y, x, log_y, log_x): - num = -_log1pexp(-log_y * (x - y)) - den = -_log1pexp(-log_x * (y - x)) + x = state.position + y = proposal.position + log_x = state.logdensity_grad + log_y = proposal.logdensity_grad - return jnp.sum(num - den) + y_minus_x = jax.tree_util.tree_map(lambda a, b: a - b, y, x) + x_minus_y = jax.tree_util.tree_map(lambda a: -a, y_minus_x) + z_tilde_x_to_y = metric.scale(x, y_minus_x, inv=True, trans=True) + z_tilde_y_to_x = metric.scale(y, x_minus_y, inv=True, trans=True) - ratios_proposals = tree_map( - ratio_proposal_nd, - proposal.position, - state.position, - proposal.logdensity_grad, - state.logdensity_grad, + c_x_to_y = metric.scale(x, log_x, inv=False, trans=True) + c_y_to_x = metric.scale(y, log_y, inv=False, trans=True) + + z_tilde_x_to_y_flat, _ = ravel_pytree(z_tilde_x_to_y) + z_tilde_y_to_x_flat, _ = ravel_pytree(z_tilde_y_to_x) + + c_x_to_y_flat, _ = ravel_pytree(c_x_to_y) + c_y_to_x_flat, _ = ravel_pytree(c_y_to_x) + + num = metric.kinetic_energy(x_minus_y, y) - _log1pexp( + -z_tilde_y_to_x_flat * c_y_to_x_flat ) - ratio_proposal = sum(tree_leaves(ratios_proposals)) + denom = metric.kinetic_energy(y_minus_x, x) - _log1pexp( + -z_tilde_x_to_y_flat * c_x_to_y_flat + ) + + ratio_proposal = jnp.sum(num - denom) + return proposal.logdensity - state.logdensity + ratio_proposal def kernel( - rng_key: PRNGKey, state: 
BarkerState, logdensity_fn: Callable, step_size: float
+        rng_key: PRNGKey,
+        state: BarkerState,
+        logdensity_fn: Callable,
+        step_size: float,
+        inverse_mass_matrix: metrics.MetricTypes | None = None,
     ) -> tuple[BarkerState, BarkerInfo]:
-        """Generate a new sample with the MALA kernel."""
+        """Generate a new sample with the Barker kernel."""
+        if inverse_mass_matrix is None:
+            p, _ = ravel_pytree(state.position)
+            (m,) = p.shape
+            inverse_mass_matrix = jnp.ones((m,))
+        metric = metrics.default_metric(inverse_mass_matrix)
         grad_fn = jax.value_and_grad(logdensity_fn)
-
         key_sample, key_rmh = jax.random.split(rng_key)
 
         proposed_pos = _barker_sample(
-            key_sample, state.position, state.logdensity_grad, step_size
+            key_sample,
+            state.position,
+            state.logdensity_grad,
+            step_size,
+            metric,
         )
+
         proposed_logdensity, proposed_logdensity_grad = grad_fn(proposed_pos)
         proposed_state = BarkerState(
             proposed_pos, proposed_logdensity, proposed_logdensity_grad
         )
 
-        log_p_accept = _compute_acceptance_probability(state, proposed_state)
+        log_p_accept = _compute_acceptance_probability(state, proposed_state, metric)
         accepted_state, info = static_binomial_sampling(
             key_rmh, log_p_accept, state, proposed_state
         )
@@ -131,6 +159,7 @@ def kernel(
 def as_top_level_api(
     logdensity_fn: Callable,
     step_size: float,
+    inverse_mass_matrix: metrics.MetricTypes | None = None,
 ) -> SamplingAlgorithm:
     """Implements the (basic) user interface for the Barker's proposal :cite:p:`Livingstone2022Barker` kernel with a Gaussian base kernel.
@@ -174,7 +203,9 @@ def as_top_level_api(
     logdensity_fn
         The log-density function we wish to draw samples from.
     step_size
-        The value to use for the step size in the symplectic integrator.
+        The value of the step size corresponding to the global scale of the proposal distribution.
+    inverse_mass_matrix
+        The inverse mass matrix to use for pre-conditioning (see Appendix G of :cite:p:`Livingstone2022Barker`).
 
     Returns
     -------
@@ -189,74 +220,55 @@ def init_fn(position: ArrayLikeTree, rng_key=None):
         return init(position, logdensity_fn)
 
     def step_fn(rng_key: PRNGKey, state):
-        return kernel(rng_key, state, logdensity_fn, step_size)
+        return kernel(rng_key, state, logdensity_fn, step_size, inverse_mass_matrix)
 
     return SamplingAlgorithm(init_fn, step_fn)
 
 
-def _barker_sample_nd(key, mean, a, scale):
-    """
-    Sample from a multivariate Barker's proposal distribution. In 1D, this has the following probability density function:
-
-    .. math::
-        p(x; \\mu, a, \\sigma) = 2 \frac{N(x; \\mu, \\sigma^2)}{1 + \\exp(-a (x - \\mu)}
+def _generate_bernoulli(
+    rng_key: PRNGKey, position: ArrayLikeTree, p: ArrayLikeTree
+) -> ArrayTree:
+    pos, unravel_fn = ravel_pytree(position)
+    p_flat, _ = ravel_pytree(p)
+    sample = jax.random.bernoulli(rng_key, p=p_flat, shape=pos.shape)
+    return unravel_fn(sample)
 
-    where :math:`N(x; \\mu, \\sigma^2)` is the normal distribution with mean :math:`\\mu` and standard deviation :math:`\\sigma`.
-    The multivariate Barker's proposal distribution is the product of one-dimensional Barker's proposal distributions.
 
+def _barker_sample(key, mean, a, scale, metric):
+    r"""
+    Sample from a multivariate Barker's proposal distribution for PyTrees.
 
     Parameters
     ----------
     key
         A PRNG key.
     mean
-        The mean of the normal distribution, an Array. This corresponds to :math:`\\mu` in the equation above.
+        The mean of the normal distribution, a PyTree. This corresponds to :math:`\mu` in the Barker proposal density.
     a
-        The parameter :math:`a` in the equation above, an Array. This is a skewness parameter.
+        The parameter :math:`a` in the Barker proposal density, the same PyTree as `mean`. This is a skewness parameter.
     scale
-        The standard deviation of the normal distribution, a scalar. This corresponds to :math:`\\sigma` in the equation above.
+        The global scale, a scalar. This corresponds to :math:`\sigma` in the Barker proposal density.
         It encodes the step size of the proposal.
-
-    Returns
-    -------
-    A sample from the Barker's multidimensional proposal distribution.
-
+    metric
+        A `Metric` object encoding the mass matrix information.
     """
     key1, key2 = jax.random.split(key)
-    z = scale * jax.random.normal(key1, shape=mean.shape)
+
+    z = generate_gaussian_noise(key1, mean, sigma=scale)
+    c = metric.scale(mean, a, inv=False, trans=True)
     # Sample b=1 with probability p and 0 with probability 1 - p where
     # p = 1 / (1 + exp(-a * (z - mean)))
-    log_p = -_log1pexp(-a * z)
-    b = jax.random.bernoulli(key2, p=jnp.exp(log_p), shape=mean.shape)
-
-    # return mean + z if b == 1 else mean - z
-    return mean + b * z - (1 - b) * z
-
+    log_p = jax.tree_util.tree_map(lambda x, y: -_log1pexp(-x * y), c, z)
+    p = jax.tree_util.tree_map(lambda x: jnp.exp(x), log_p)
+    b = _generate_bernoulli(key2, mean, p=p)
 
-def _barker_sample(key, mean, a, scale):
-    r"""
-    Sample from a multivariate Barker's proposal distribution for PyTrees.
-
-    Parameters
-    ----------
-    key
-        A PRNG key.
-    mean
-        The mean of the normal distribution, a PyTree. This corresponds to :math:`\mu` in the equation above.
-    a
-        The parameter :math:`a` in the equation above, the same PyTree as `mean`. This is a skewness parameter.
-    scale
-        The standard deviation of the normal distribution, a scalar. This corresponds to :math:`\sigma` in the equation above.
-        It encodes the step size of the proposal.
-
-    """
+    bz = jax.tree_util.tree_map(lambda x, y: x * y - (1 - x) * y, b, z)
 
-    flat_mean, unravel_fn = ravel_pytree(mean)
-    flat_a, _ = ravel_pytree(a)
-    flat_sample = _barker_sample_nd(key, flat_mean, flat_a, scale)
-    return unravel_fn(flat_sample)
+    return jax.tree_util.tree_map(
+        lambda a, b: a + b, mean, metric.scale(mean, bz, inv=False, trans=False)
+    )
 
 
 def _log1pexp(a):
diff --git a/blackjax/mcmc/metrics.py b/blackjax/mcmc/metrics.py
index 4e079714b..f0720acf4 100644
--- a/blackjax/mcmc/metrics.py
+++ b/blackjax/mcmc/metrics.py
@@ -30,7 +30,6 @@
 """
 from typing import Callable, NamedTuple, Optional, Protocol, Union
 
-import jax
 import jax.numpy as jnp
 import jax.scipy as jscipy
 from jax.flatten_util import ravel_pytree
@@ -62,7 +61,12 @@ def __call__(
 
 class Scale(Protocol):
     def __call__(
-        self, position: ArrayLikeTree, element: ArrayLikeTree, inv: ArrayLikeTree
+        self,
+        position: ArrayLikeTree,
+        element: ArrayLikeTree,
+        *,
+        inv: bool,
+        trans: bool,
     ) -> ArrayLikeTree:
         ...
 
@@ -187,7 +191,11 @@ def is_turning(
         return turning_at_left | turning_at_right
 
     def scale(
-        position: ArrayLikeTree, element: ArrayLikeTree, inv: ArrayLikeTree
+        position: ArrayLikeTree,
+        element: ArrayLikeTree,
+        *,
+        inv: bool,
+        trans: bool,
     ) -> ArrayLikeTree:
         """Scale elements by the mass matrix.
 
@@ -197,10 +205,11 @@
             The current position. Not used in this metric.
         elements
             Elements to scale
-        invs
+        inv
             Whether to scale the elements by the inverse mass matrix or the
             mass matrix. If True, the element is scaled by the inverse square root
             mass matrix, i.e., elem <- (M^{1/2})^{-1} elem.
-            Same pytree structure as `elements`.
+        trans
+            Whether to transpose the mass matrix when scaling.
 
         Returns
         -------
         """
         ravelled_element, unravel_fn = ravel_pytree(element)
-        scaled = jax.lax.cond(
-            inv,
-            lambda: linear_map(inv_mass_matrix_sqrt, ravelled_element),
-            lambda: linear_map(mass_matrix_sqrt, ravelled_element),
-        )
+
+        if inv:
+            left_hand_side_matrix = inv_mass_matrix_sqrt
+        else:
+            left_hand_side_matrix = mass_matrix_sqrt
+        if trans:
+            left_hand_side_matrix = left_hand_side_matrix.T
+
+        scaled = linear_map(left_hand_side_matrix, ravelled_element)
+
         return unravel_fn(scaled)
 
     return Metric(momentum_generator, kinetic_energy, is_turning, scale)
 
@@ -279,7 +293,11 @@ def is_turning(
         # return turning_at_left | turning_at_right
 
     def scale(
-        position: ArrayLikeTree, element: ArrayLikeTree, inv: ArrayLikeTree
+        position: ArrayLikeTree,
+        element: ArrayLikeTree,
+        *,
+        inv: bool,
+        trans: bool,
     ) -> ArrayLikeTree:
         """Scale elements by the mass matrix.
 
@@ -298,11 +316,16 @@
             mass_matrix, is_inv=False
         )
         ravelled_element, unravel_fn = ravel_pytree(element)
-        scaled = jax.lax.cond(
-            inv,
-            lambda: linear_map(inv_mass_matrix_sqrt, ravelled_element),
-            lambda: linear_map(mass_matrix_sqrt, ravelled_element),
-        )
+
+        if inv:
+            left_hand_side_matrix = inv_mass_matrix_sqrt
+        else:
+            left_hand_side_matrix = mass_matrix_sqrt
+        if trans:
+            left_hand_side_matrix = left_hand_side_matrix.T
+
+        scaled = linear_map(left_hand_side_matrix, ravelled_element)
+
         return unravel_fn(scaled)
 
     return Metric(momentum_generator, kinetic_energy, is_turning, scale)
diff --git a/tests/mcmc/test_barker.py b/tests/mcmc/test_barker.py
index 5c227c4cb..04a86d1d4 100644
--- a/tests/mcmc/test_barker.py
+++ b/tests/mcmc/test_barker.py
@@ -1,9 +1,16 @@
+import functools
+import itertools
+
 import chex
 import jax
 import jax.numpy as jnp
+import jax.scipy.stats as stats
 from absl.testing import absltest, parameterized
 
-from blackjax.mcmc.barker import _barker_pdf, _barker_sample_nd
+import blackjax
+from blackjax.mcmc import metrics
+from blackjax.mcmc.barker import _barker_pdf, _barker_sample
+from blackjax.util import run_inference_algorithm
 
 
 class BarkerSamplingTest(chex.TestCase):
@@ -18,8 +25,9 @@ def test_nd(self, seed):
             0.5,
         )
 
+        metric = metrics.default_metric(jnp.eye(4))
         keys = jax.random.split(key, n_samples)
-        samples = jax.vmap(lambda k: _barker_sample_nd(k, m, a, scale))(keys)
+        samples = jax.vmap(lambda k: _barker_sample(k, m, a, scale, metric))(keys)
 
         # Check that the emprical mean and the mean computed as sum(x * p(x) dx) are close
         _test_samples_vs_pdf(samples, lambda x: _barker_pdf(x, m, a, scale))
@@ -51,5 +59,121 @@ def _test_samples_vs_pdf(samples, pdf):
     )
 
 
+class BarkerPreconditioningTest(chex.TestCase):
+    @parameterized.parameters([1234, 5678])
+    def test_preconditioning_matrix(self, seed):
+        """Test that two ways of using the pre-conditioning matrix have exactly the same effect.
+
+        We follow the discussion in Appendix G of the Barker 2020 paper.
+ """ + + key = jax.random.key(seed) + init_key, inference_key = jax.random.split(key, 2) + + # setup some 2D multivariate normal model + # setup sampling mean and cov + true_x = jnp.array([0.0, 1.0]) + data = jax.random.normal(init_key, shape=(1000,)) * true_x[1] + true_x[0] + assert data.shape == (1000,) + + # some non-diagonal positive-defininte matrix for pre-conditioning + inv_mass_matrix = jnp.array([[1, 0.1], [0.1, 1]]) + metric = metrics.default_metric(inv_mass_matrix) + + # define barker kernel two ways + # non-scaled, use pre-conditioning + def logdensity(x, data): + mu_prior = stats.norm.logpdf(x[0], loc=0, scale=1) + sigma_prior = stats.uniform.logpdf(x[1], 0.0, 3.0) + return mu_prior + sigma_prior + jnp.sum(stats.norm.logcdf(data, x[0], x[1])) + + logposterior_fn1 = functools.partial(logdensity, data=data) + barker1 = blackjax.barker_proposal(logposterior_fn1, 1e-1, inv_mass_matrix) + state1 = barker1.init(true_x) + + # scaled, trivial pre-conditioning + def scaled_logdensity(x_scaled, data, metric): + x = metric.scale(x_scaled, x_scaled, inv=False, trans=False) + return logdensity(x, data) + + logposterior_fn2 = functools.partial( + scaled_logdensity, data=data, metric=metric + ) + barker2 = blackjax.barker_proposal(logposterior_fn2, 1e-1, jnp.eye(2)) + + true_x_trans = metric.scale(true_x, true_x, inv=True, trans=True) + state2 = barker2.init(true_x_trans) + + n_steps = 10 + _, states1 = run_inference_algorithm( + rng_key=inference_key, + initial_state=state1, + inference_algorithm=barker1, + transform=lambda state, info: state.position, + num_steps=n_steps, + ) + + _, states2 = run_inference_algorithm( + rng_key=inference_key, + initial_state=state2, + inference_algorithm=barker2, + transform=lambda state, info: state.position, + num_steps=n_steps, + ) + + # states should be the exact same with same random key after transforming + states2_trans = [] + for ii in range(n_steps): + s = states2[ii] + states2_trans.append(metric.scale(s, s, inv=False, trans=False)) + states2_trans = jnp.array(states2_trans) + assert jnp.allclose(states1, states2_trans) + + @parameterized.parameters( + itertools.product([1234, 5678], ["gaussian", "riemannian"]) + ) + def test_invariance(self, seed, metric): + logpdf = lambda x: -0.5 * jnp.sum(x**2) + + n_samples, m_steps = 10_000, 50 + + key = jax.random.key(seed) + init_key, inference_key = jax.random.split(key, 2) + inference_keys = jax.random.split(inference_key, n_samples) + if metric == "gaussian": + inv_mass_matrix = jnp.ones((2,)) + metric = metrics.default_metric(inv_mass_matrix) + else: + # bit of a random metric but we are testing invariance, not efficiency + metric = metrics.gaussian_riemannian( + lambda x: 1 / jnp.sum(1 + jnp.sum(x**2)) * jnp.eye(2) + ) + + barker = blackjax.barker_proposal(logpdf, 0.5, metric) + init_samples = jax.random.normal(init_key, shape=(n_samples, 2)) + + def loop(carry, key_): + state, accepted = carry + state, info = barker.step(key_, state) + accepted += info.is_accepted + return (state, accepted), None + + def get_samples(init_sample, key_): + init = (barker.init(init_sample), 0) + (out, n_accepted), _ = jax.lax.scan( + loop, init, jax.random.split(key_, m_steps) + ) + return out.position, n_accepted / m_steps + + samples, total_accepted = jax.vmap(get_samples)(init_samples, inference_keys) + # now we test the distance versus a Gaussian + chex.assert_trees_all_close( + jnp.mean(samples, 0), jnp.zeros((2,)), atol=1e-1, rtol=1e-1 + ) + chex.assert_trees_all_close( + jnp.cov(samples.T), jnp.eye(2), 
atol=1e-1, rtol=1e-1 + ) + + if __name__ == "__main__": absltest.main() diff --git a/tests/mcmc/test_metrics.py b/tests/mcmc/test_metrics.py index 0791f3cb1..e6aa5879f 100644 --- a/tests/mcmc/test_metrics.py +++ b/tests/mcmc/test_metrics.py @@ -131,8 +131,12 @@ def test_gaussian_euclidean_dim_1(self): assert momentum_val == expected_momentum_val assert kinetic_energy_val == expected_kinetic_energy_val - inv_scaled_momentum = scale(arbitrary_position, momentum_val, True) - scaled_momentum = scale(arbitrary_position, momentum_val, False) + inv_scaled_momentum = scale( + arbitrary_position, momentum_val, inv=True, trans=False + ) + scaled_momentum = scale( + arbitrary_position, momentum_val, inv=False, trans=False + ) expected_scaled_momentum = momentum_val / jnp.sqrt(inverse_mass_matrix) expected_inv_scaled_momentum = momentum_val * jnp.sqrt(inverse_mass_matrix) @@ -164,8 +168,12 @@ def test_gaussian_euclidean_dim_2(self): np.testing.assert_allclose(expected_momentum_val, momentum_val) np.testing.assert_allclose(kinetic_energy_val, expected_kinetic_energy_val) - inv_scaled_momentum = scale(arbitrary_position, momentum_val, True) - scaled_momentum = scale(arbitrary_position, momentum_val, False) + inv_scaled_momentum = scale( + arbitrary_position, momentum_val, inv=True, trans=False + ) + scaled_momentum = scale( + arbitrary_position, momentum_val, inv=False, trans=False + ) expected_inv_scaled_momentum = jnp.linalg.inv(L_inv).T @ momentum_val expected_scaled_momentum = L_inv @ momentum_val @@ -226,8 +234,12 @@ def test_gaussian_riemannian_dim_1(self): np.testing.assert_allclose(expected_momentum_val, momentum_val) np.testing.assert_allclose(kinetic_energy_val, expected_kinetic_energy_val) - inv_scaled_momentum = scale(arbitrary_position, momentum_val, True) - scaled_momentum = scale(arbitrary_position, momentum_val, False) + inv_scaled_momentum = scale( + arbitrary_position, momentum_val, inv=True, trans=False + ) + scaled_momentum = scale( + arbitrary_position, momentum_val, inv=False, trans=False + ) expected_scaled_momentum = momentum_val / jnp.sqrt(inverse_mass_matrix) expected_inv_scaled_momentum = momentum_val * jnp.sqrt(inverse_mass_matrix) @@ -265,8 +277,12 @@ def test_gaussian_riemannian_dim_2(self): np.testing.assert_allclose(expected_momentum_val, momentum_val) np.testing.assert_allclose(kinetic_energy_val, expected_kinetic_energy_val) - inv_scaled_momentum = scale(arbitrary_position, momentum_val, True) - scaled_momentum = scale(arbitrary_position, momentum_val, False) + inv_scaled_momentum = scale( + arbitrary_position, momentum_val, inv=True, trans=False + ) + scaled_momentum = scale( + arbitrary_position, momentum_val, inv=False, trans=False + ) expected_inv_scaled_momentum = jnp.linalg.inv(L_inv).T @ momentum_val expected_scaled_momentum = L_inv @ momentum_val diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py index c399929da..98572cabc 100644 --- a/tests/mcmc/test_sampling.py +++ b/tests/mcmc/test_sampling.py @@ -1,4 +1,5 @@ """Test the accuracy of the MCMC kernels.""" + import functools import itertools
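
For reference, here is a minimal usage sketch of the API this patch introduces. The target density, step size, and mass-matrix values below are illustrative assumptions rather than part of the patch; only `blackjax.barker_proposal` with its new optional `inverse_mass_matrix` argument, and `metric.scale` with the required `inv`/`trans` keywords, come from the diff above.

    import jax
    import jax.numpy as jnp

    import blackjax
    from blackjax.mcmc import metrics

    # Illustrative target: a correlated 2D Gaussian (assumed, not from the patch).
    cov = jnp.array([[1.0, 0.5], [0.5, 2.0]])
    prec = jnp.linalg.inv(cov)
    logdensity_fn = lambda x: -0.5 * x @ prec @ x

    # A common heuristic is to set the inverse mass matrix to an estimate of
    # the target covariance. Omitting `inverse_mass_matrix` falls back to an
    # identity pre-conditioning matrix, which preserves the pre-patch behaviour.
    barker = blackjax.barker_proposal(
        logdensity_fn, step_size=0.5, inverse_mass_matrix=cov
    )

    state = barker.init(jnp.zeros(2))
    step = jax.jit(barker.step)

    rng_key = jax.random.key(0)
    for _ in range(1_000):
        rng_key, step_key = jax.random.split(rng_key)
        state, info = step(step_key, state)

    # The required `inv`/`trans` keywords select which square-root factor of
    # the mass matrix M multiplies the input, optionally transposed.
    metric = metrics.default_metric(cov)
    e = jnp.ones(2)
    scaled = metric.scale(state.position, e, inv=False, trans=False)  # M^{1/2} e
    unscaled = metric.scale(state.position, e, inv=True, trans=True)  # ((M^{1/2})^{-1})^T e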