diff --git a/blackjax/__init__.py b/blackjax/__init__.py
index 512c6d024..a0dcaecd9 100644
--- a/blackjax/__init__.py
+++ b/blackjax/__init__.py
@@ -1,6 +1,7 @@
 from blackjax._version import __version__
 
 from .adaptation.chees_adaptation import chees_adaptation
+from .adaptation.mclmc_adaptation import mclmc_find_L_and_step_size
 from .adaptation.meads_adaptation import meads_adaptation
 from .adaptation.pathfinder_adaptation import pathfinder_adaptation
 from .adaptation.window_adaptation import window_adaptation
@@ -12,6 +13,7 @@
 from .mcmc.hmc import dynamic_hmc, hmc
 from .mcmc.mala import mala
 from .mcmc.marginal_latent_gaussian import mgrad_gaussian
+from .mcmc.mclmc import mclmc
 from .mcmc.nuts import nuts
 from .mcmc.periodic_orbital import orbital_hmc
 from .mcmc.random_walk import additive_step_random_walk, irmh, rmh
@@ -40,6 +42,7 @@
     "additive_step_random_walk",
     "rmh",
     "irmh",
+    "mclmc",
     "elliptical_slice",
     "ghmc",
     "barker_proposal",
@@ -51,6 +54,7 @@
     "meads_adaptation",
     "chees_adaptation",
     "pathfinder_adaptation",
+    "mclmc_find_L_and_step_size",  # mclmc adaptation
     "adaptive_tempered_smc",  # smc
     "tempered_smc",
     "meanfield_vi",  # variational inference
diff --git a/blackjax/adaptation/__init__.py b/blackjax/adaptation/__init__.py
index 91a491ed0..53d5fe2b6 100644
--- a/blackjax/adaptation/__init__.py
+++ b/blackjax/adaptation/__init__.py
@@ -1,5 +1,6 @@
 from . import (
     chees_adaptation,
+    mclmc_adaptation,
     meads_adaptation,
     pathfinder_adaptation,
     window_adaptation,
@@ -10,4 +11,5 @@
     "meads_adaptation",
     "window_adaptation",
     "pathfinder_adaptation",
+    "mclmc_adaptation",
 ]
diff --git a/blackjax/adaptation/mclmc_adaptation.py b/blackjax/adaptation/mclmc_adaptation.py
new file mode 100644
index 000000000..44a2944fc
--- /dev/null
+++ b/blackjax/adaptation/mclmc_adaptation.py
@@ -0,0 +1,280 @@
+# Copyright 2020- The Blackjax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Algorithms to adapt the MCLMC kernel parameters, namely step size and L.
+
+"""
+
+from typing import NamedTuple
+
+import jax
+import jax.numpy as jnp
+from jax.flatten_util import ravel_pytree
+
+from blackjax.diagnostics import effective_sample_size  # type: ignore
+from blackjax.util import pytree_size
+
+
+class MCLMCAdaptationState(NamedTuple):
+    """Represents the tunable parameters for MCLMC adaptation.
+
+    Attributes:
+        L (float): The momentum decoherence rate for the MCLMC algorithm.
+        step_size (float): The step size used for the MCLMC algorithm.
+    """
+
+    L: float
+    step_size: float
+
+
+def mclmc_find_L_and_step_size(
+    mclmc_kernel,
+    num_steps,
+    state,
+    rng_key,
+    frac_tune1=0.1,
+    frac_tune2=0.1,
+    frac_tune3=0.1,
+    desired_energy_var=5e-4,
+    trust_in_estimate=1.5,
+    num_effective_samples=150,
+):
+    """
+    Finds the optimal value of the parameters for the MCLMC algorithm.
+
+    Args:
+        mclmc_kernel (callable): The kernel function used for the MCMC algorithm.
+        num_steps (int): The number of MCMC steps that will subsequently be run, after tuning.
+        state (MCMCState): The initial state of the MCMC algorithm.
+        rng_key (jax.random.PRNGKey): The random number generator key.
+        frac_tune1 (float): The fraction of tuning for the first step of the adaptation.
+        frac_tune2 (float): The fraction of tuning for the second step of the adaptation.
+        frac_tune3 (float): The fraction of tuning for the third step of the adaptation.
+        desired_energy_var (float): The desired energy variance for the MCMC algorithm.
+        trust_in_estimate (float): The trust in the estimate of optimal stepsize.
+        num_effective_samples (int): The number of effective samples for the MCMC algorithm.
+
+    Returns:
+        tuple: A tuple containing the final state of the MCMC algorithm and the final hyperparameters.
+
+    Raises:
+        None
+
+    Examples:
+        # Generate a random number generator key
+        rng_key = jax.random.PRNGKey(0)
+        init_key, tune_key = jax.random.split(rng_key)
+
+        # Build the kernel (assuming `logdensity_fn` and an initial `position` are defined)
+        kernel = blackjax.mcmc.mclmc.build_kernel(
+            logdensity_fn=logdensity_fn,
+            integrator=blackjax.mcmc.integrators.noneuclidean_mclachlan,
+            transform=lambda x: x,
+        )
+
+        # Define the initial state
+        initial_state = blackjax.mcmc.mclmc.init(
+            x_initial=position, logdensity_fn=logdensity_fn, rng_key=init_key
+        )
+
+        # Find the optimal parameters for the MCLMC algorithm
+        final_state, final_params = mclmc_find_L_and_step_size(
+            mclmc_kernel=kernel,
+            num_steps=1000,
+            state=initial_state,
+            rng_key=tune_key,
+            frac_tune1=0.2,
+            frac_tune2=0.3,
+            frac_tune3=0.1,
+            desired_energy_var=1e-4,
+            trust_in_estimate=2.0,
+            num_effective_samples=200,
+        )
+    """
+    dim = pytree_size(state.position)
+    params = MCLMCAdaptationState(jnp.sqrt(dim), jnp.sqrt(dim) * 0.25)
+    part1_key, part2_key = jax.random.split(rng_key, 2)
+
+    state, params = make_L_step_size_adaptation(
+        kernel=mclmc_kernel,
+        dim=dim,
+        frac_tune1=frac_tune1,
+        frac_tune2=frac_tune2,
+        desired_energy_var=desired_energy_var,
+        trust_in_estimate=trust_in_estimate,
+        num_effective_samples=num_effective_samples,
+    )(state, params, num_steps, part1_key)
+
+    if frac_tune3 != 0:
+        state, params = make_adaptation_L(mclmc_kernel, frac=frac_tune3, Lfactor=0.4)(
+            state, params, num_steps, part2_key
+        )
+
+    return state, params
+
+
+def make_L_step_size_adaptation(
+    kernel,
+    dim,
+    frac_tune1,
+    frac_tune2,
+    desired_energy_var=1e-3,
+    trust_in_estimate=1.5,
+    num_effective_samples=150,
+):
+    """Adapts the stepsize and L of the MCLMC kernel. Designed for the unadjusted MCLMC."""
+
+    decay_rate = (num_effective_samples - 1.0) / (num_effective_samples + 1.0)
+
+    def predictor(previous_state, params, adaptive_state, rng_key):
+        """Does one step with the dynamics and updates the prediction for the optimal stepsize.
+        Designed for the unadjusted MCLMC."""
+
+        time, x_average, step_size_max = adaptive_state
+
+        # dynamics
+        next_state, info = kernel(
+            rng_key=rng_key,
+            state=previous_state,
+            L=params.L,
+            step_size=params.step_size,
+        )
+        # step updating
+        success, state, step_size_max, energy_change = handle_nans(
+            previous_state,
+            next_state,
+            params.step_size,
+            step_size_max,
+            info.energy_change,
+        )
+
+        # Warning: var = 0 if there were nans, but we will give it a very small weight
+        xi = (
+            jnp.square(energy_change) / (dim * desired_energy_var)
+        ) + 1e-8  # 1e-8 is added to avoid divergences in log xi
+        weight = jnp.exp(
+            -0.5 * jnp.square(jnp.log(xi) / (6.0 * trust_in_estimate))
+        )  # the weight reduces the impact of stepsizes which are much larger or much smaller than the desired one
+
+        x_average = decay_rate * x_average + weight * (
+            xi / jnp.power(params.step_size, 6.0)
+        )
+        time = decay_rate * time + weight
+        step_size = jnp.power(
+            x_average / time, -1.0 / 6.0
+        )  # We use the Var[E] = O(eps^6) relation here.
+        step_size = (step_size < step_size_max) * step_size + (
+            step_size > step_size_max
+        ) * step_size_max  # if the proposed stepsize is above the largest stepsize for which we have not seen divergences, cap it there
+        params_new = params._replace(step_size=step_size)
+
+        return state, params_new, params_new, (time, x_average, step_size_max), success
+
+    def update_kalman(x, state, outer_weight, success, step_size):
+        """Kalman filter to estimate the size of the posterior."""
+        time, x_average, x_squared_average = state
+        weight = outer_weight * step_size * success
+        zero_prevention = 1 - outer_weight
+        x_average = (time * x_average + weight * x) / (
+            time + weight + zero_prevention
+        )  # Update with a Kalman filter
+        x_squared_average = (time * x_squared_average + weight * jnp.square(x)) / (
+            time + weight + zero_prevention
+        )  # Update with a Kalman filter
+        time += weight
+        return (time, x_average, x_squared_average)
+
+    adap0 = (0.0, 0.0, jnp.inf)
+
+    def step(iteration_state, weight_and_key):
+        """Does one step of the dynamics and updates the estimate of the posterior size and optimal stepsize."""
+
+        outer_weight, rng_key = weight_and_key
+        state, params, adaptive_state, kalman_state = iteration_state
+        state, params, params_final, adaptive_state, success = predictor(
+            state, params, adaptive_state, rng_key
+        )
+        position, _ = ravel_pytree(state.position)
+        kalman_state = update_kalman(
+            position, kalman_state, outer_weight, success, params.step_size
+        )
+
+        return (state, params_final, adaptive_state, kalman_state), None
+
+    def L_step_size_adaptation(state, params, num_steps, rng_key):
+        num_steps1, num_steps2 = int(num_steps * frac_tune1), int(
+            num_steps * frac_tune2
+        )
+        L_step_size_adaptation_keys = jax.random.split(rng_key, num_steps1 + num_steps2)
+
+        # we use the last num_steps2 steps to compute the diagonal preconditioner
+        outer_weights = jnp.concatenate((jnp.zeros(num_steps1), jnp.ones(num_steps2)))
+
+        # initial state of the kalman filter
+        kalman_state = (0.0, jnp.zeros(dim), jnp.zeros(dim))
+
+        # run the steps
+        kalman_state = jax.lax.scan(
+            step,
+            init=(state, params, adap0, kalman_state),
+            xs=(outer_weights, L_step_size_adaptation_keys),
+            length=num_steps1 + num_steps2,
+        )[0]
+        state, params, _, kalman_state_output = kalman_state
+
+        L = params.L
+        # determine L
+        if num_steps2 != 0.0:
+            _, F1, F2 = kalman_state_output
+            variances = F2 - jnp.square(F1)
+            L = jnp.sqrt(jnp.sum(variances))
+
+        return state, MCLMCAdaptationState(L, params.step_size)
+
+    return L_step_size_adaptation
+
+
+def make_adaptation_L(kernel, frac, Lfactor):
+    """Determine L from the autocorrelations (around 10 effective samples are needed for this to be accurate)."""
+
+    def adaptation_L(state, params, num_steps, key):
+        num_steps = int(num_steps * frac)
+        adaptation_L_keys = jax.random.split(key, num_steps)
+
+        # run kernel in the normal way
+        state, info = jax.lax.scan(
+            f=lambda s, k: (
+                kernel(rng_key=k, state=s, L=params.L, step_size=params.step_size)
+            ),
+            init=state,
+            xs=adaptation_L_keys,
+        )
+        samples = info.transformed_position  # transform is the identity here
+        flat_samples = jax.vmap(lambda x: ravel_pytree(x)[0])(samples)
+        flat_samples = flat_samples.reshape(2, num_steps // 2, -1)
+        ESS = effective_sample_size(flat_samples)
+
+        return state, params._replace(
+            L=Lfactor * params.step_size * jnp.mean(num_steps / ESS)
+        )
+
+    return adaptation_L
+
+
+def handle_nans(previous_state, next_state, step_size, step_size_max, kinetic_change):
+    """If there are nans, let's reduce the stepsize, and not update the state.
+    The function returns the old state in this case."""
+
+    reduced_step_size = 0.8
+    p, unravel_fn = ravel_pytree(next_state.position)
+    nonans = jnp.all(jnp.isfinite(p))
+    state, step_size, kinetic_change = jax.tree_util.tree_map(
+        lambda new, old: jax.lax.select(nonans, jnp.nan_to_num(new), old),
+        (next_state, step_size_max, kinetic_change),
+        (previous_state, step_size * reduced_step_size, 0.0),
+    )
+
+    return nonans, state, step_size, kinetic_change
diff --git a/blackjax/mcmc/__init__.py b/blackjax/mcmc/__init__.py
index a1e1a42e0..f27b199c6 100644
--- a/blackjax/mcmc/__init__.py
+++ b/blackjax/mcmc/__init__.py
@@ -5,6 +5,7 @@
     hmc,
     mala,
     marginal_latent_gaussian,
+    mclmc,
     nuts,
     periodic_orbital,
     random_walk,
@@ -20,4 +21,5 @@
     "periodic_orbital",
     "marginal_latent_gaussian",
     "random_walk",
+    "mclmc",
 ]
diff --git a/blackjax/mcmc/integrators.py b/blackjax/mcmc/integrators.py
index e871b6211..840693f81 100644
--- a/blackjax/mcmc/integrators.py
+++ b/blackjax/mcmc/integrators.py
@@ -365,5 +365,5 @@ def noneuclidean_integrator(
 
 
 noneuclidean_leapfrog = generate_noneuclidean_integrator(velocity_verlet_cofficients)
-noneuclidean_mclachlan = generate_noneuclidean_integrator(mclachlan_cofficients)
 noneuclidean_yoshida = generate_noneuclidean_integrator(yoshida_cofficients)
+noneuclidean_mclachlan = generate_noneuclidean_integrator(mclachlan_cofficients)
diff --git a/blackjax/mcmc/mclmc.py b/blackjax/mcmc/mclmc.py
new file mode 100644
index 000000000..a84bcaa44
--- /dev/null
+++ b/blackjax/mcmc/mclmc.py
@@ -0,0 +1,205 @@
+# Copyright 2020- The Blackjax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Public API for the MCLMC Kernel"""
+from typing import Callable, NamedTuple
+
+import jax
+import jax.numpy as jnp
+from jax.flatten_util import ravel_pytree
+from jax.random import normal
+
+from blackjax.base import SamplingAlgorithm
+from blackjax.mcmc.integrators import IntegratorState, noneuclidean_mclachlan
+from blackjax.types import Array, ArrayLike, PRNGKey
+from blackjax.util import generate_unit_vector, pytree_size
+
+__all__ = ["MCLMCInfo", "init", "build_kernel", "mclmc"]
+
+
+class MCLMCInfo(NamedTuple):
+    """
+    Additional information on the MCLMC transition.
+
+    Attributes
+    ----------
+    transformed_position :
+        The value of the samples after a transformation. This is typically a projection onto a lower dimensional subspace.
+    logdensity :
+        The log-density of the distribution at the current step of the MCLMC chain.
+    kinetic_change :
+        The difference in kinetic energy between the current and previous step.
+    energy_change :
+        The difference in energy between the current and previous step.
+    """
+
+    transformed_position: Array
+    logdensity: float
+    kinetic_change: float
+    energy_change: float
+
+
+def init(x_initial: ArrayLike, logdensity_fn, rng_key):
+    l, g = jax.value_and_grad(logdensity_fn)(x_initial)
+
+    return IntegratorState(
+        position=x_initial,
+        momentum=generate_unit_vector(rng_key, x_initial),
+        logdensity=l,
+        logdensity_grad=g,
+    )
+
+
+def build_kernel(logdensity_fn, integrator, transform):
+    """Build an MCLMC kernel.
+
+    Parameters
+    ----------
+    logdensity_fn
+        The log-density function we wish to draw samples from.
+    integrator
+        The symplectic integrator to use to integrate the Hamiltonian dynamics.
+    transform
+        A function applied to the position before it is stored in the sampling info, typically a projection onto a lower dimensional subspace.
+    L
+        the momentum decoherence rate
+    step_size
+        step size of the integrator
+
+    Returns
+    -------
+    A kernel that takes a rng_key and a Pytree that contains the current state
+    of the chain and that returns a new state of the chain along with
+    information about the transition.
+
+    """
+    step = integrator(logdensity_fn)
+
+    def kernel(
+        rng_key: PRNGKey, state: IntegratorState, L: float, step_size: float
+    ) -> tuple[IntegratorState, MCLMCInfo]:
+        (position, momentum, logdensity, logdensitygrad), kinetic_change = step(
+            state, step_size
+        )
+
+        dim = pytree_size(position)
+
+        # Langevin-like noise
+        momentum, dim = partially_refresh_momentum(
+            momentum=momentum, rng_key=rng_key, L=L, step_size=step_size
+        )
+
+        return IntegratorState(
+            position, momentum, logdensity, logdensitygrad
+        ), MCLMCInfo(
+            transformed_position=transform(position),
+            logdensity=logdensity,
+            energy_change=kinetic_change - logdensity + state.logdensity,
+            kinetic_change=kinetic_change * (dim - 1),
+        )
+
+    return kernel
+
+
+class mclmc:
+    """The general mclmc kernel builder (:meth:`blackjax.mcmc.mclmc.build_kernel`, alias `blackjax.mclmc.build_kernel`) can be
+    cumbersome to manipulate. Since most users only need to specify the kernel
+    parameters at initialization time, we provide a helper function that
+    specializes the general kernel.
+
+    We also add the general kernel and state generator as an attribute to this class so
+    users only need to pass `blackjax.mclmc` to SMC, adaptation, etc. algorithms.
+
+    Examples
+    --------
+
+    A new mclmc kernel can be initialized and used with the following code:
+
+    .. code::
+
+        mclmc = blackjax.mcmc.mclmc.mclmc(
+            logdensity_fn=logdensity_fn,
+            transform=lambda x: x,
+            L=L,
+            step_size=step_size
+        )
+        state = mclmc.init(position)
+        new_state, info = mclmc.step(rng_key, state)
+
+    Kernels are not jit-compiled by default so you will need to do it manually:
+
+    .. code::
+
+        step = jax.jit(mclmc.step)
+        new_state, info = step(rng_key, state)
+
+    Parameters
+    ----------
+    logdensity_fn
+        The log-density function we wish to draw samples from.
+    transform
+        A function to perform on the samples drawn from the target distribution
+    L
+        the momentum decoherence rate
+    step_size
+        step size of the integrator
+    integrator
+        an integrator. We recommend using the default here.
+
+    Returns
+    -------
+    A ``SamplingAlgorithm``.
+    """
+
+    init = staticmethod(init)
+    build_kernel = staticmethod(build_kernel)
+
+    def __new__(  # type: ignore[misc]
+        cls,
+        logdensity_fn: Callable,
+        L,
+        step_size,
+        transform: Callable = (lambda x: x),
+        integrator=noneuclidean_mclachlan,
+        seed=1,
+    ) -> SamplingAlgorithm:
+        kernel = cls.build_kernel(logdensity_fn, integrator, transform)
+
+        def update_fn(rng_key, state):
+            return kernel(rng_key, state, L, step_size)
+
+        def init_fn(position: ArrayLike):
+            return cls.init(position, logdensity_fn, jax.random.PRNGKey(seed))
+
+        return SamplingAlgorithm(init_fn, update_fn)
+
+
+def partially_refresh_momentum(momentum, rng_key, step_size, L):
+    """Adds a small noise to momentum and normalizes.
+
+    Parameters
+    ----------
+    rng_key:
+        The pseudo-random number generator key used to generate random numbers.
+    momentum:
+        PyTree whose structure the output should match.
+    step_size:
+        Step size
+    L:
+        controls rate of momentum change
+
+    Returns
+    -------
+    The partially refreshed momentum (normalized to unit norm) and the dimension
+    of the flattened momentum.
+    """
+    m, unravel_fn = ravel_pytree(momentum)
+    dim = m.shape[0]
+    nu = jnp.sqrt((jnp.exp(2 * step_size / L) - 1.0) / dim)
+    z = nu * normal(rng_key, shape=m.shape, dtype=m.dtype)
+    return unravel_fn((m + z) / jnp.linalg.norm(m + z)), dim
diff --git a/docs/refs.bib b/docs/refs.bib
index 7fcd081e7..eee65c7ea 100644
--- a/docs/refs.bib
+++ b/docs/refs.bib
@@ -379,6 +379,13 @@ @misc{robnik2023microcanonical
   primaryClass={stat.CO}
 }
 
+@misc{robnik2023microcanonical2,
+      title={Microcanonical Langevin Monte Carlo},
+      author={Robnik, Jakob and Seljak, Uro{\v{s}}},
+      journal={arXiv preprint arXiv:2303.18221},
+      year={2023}
+}
+
 @article{Livingstone2022Barker,
 author = {Livingstone, Samuel and Zanella, Giacomo},
 title = "{The Barker Proposal: Combining Robustness and Efficiency in Gradient-Based MCMC}",
diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py
index ebffcffc7..db9ef9944 100644
--- a/tests/mcmc/test_sampling.py
+++ b/tests/mcmc/test_sampling.py
@@ -84,6 +84,48 @@ def regression_logprob(self, log_scale, coefs, preds, x):
         # reduce sum otherwise broacasting will make the logprob biased.
         return sum(x.sum() for x in [scale_prior, coefs_prior, logpdf])
 
+    def run_mclmc(self, logdensity_fn, num_steps, initial_position, key):
+        init_key, tune_key, run_key = jax.random.split(key, 3)
+
+        initial_state = blackjax.mcmc.mclmc.init(
+            x_initial=initial_position, logdensity_fn=logdensity_fn, rng_key=init_key
+        )
+
+        kernel = blackjax.mcmc.mclmc.build_kernel(
+            logdensity_fn=logdensity_fn,
+            integrator=blackjax.mcmc.integrators.noneuclidean_mclachlan,
+            transform=lambda x: x,
+        )
+
+        (
+            blackjax_state_after_tuning,
+            blackjax_mclmc_sampler_params,
+        ) = blackjax.mclmc_find_L_and_step_size(
+            mclmc_kernel=kernel,
+            num_steps=num_steps,
+            state=initial_state,
+            rng_key=tune_key,
+        )
+
+        keys = jax.random.split(run_key, num_steps)
+
+        sampling_alg = blackjax.mclmc(
+            logdensity_fn,
+            L=blackjax_mclmc_sampler_params.L,
+            step_size=blackjax_mclmc_sampler_params.step_size,
+        )
+
+        _, blackjax_mclmc_result = jax.lax.scan(
+            f=lambda state, k: sampling_alg.step(
+                rng_key=k,
+                state=state,
+            ),
+            xs=keys,
+            init=blackjax_state_after_tuning,
+        )
+
+        return blackjax_mclmc_result.transformed_position
+
     @parameterized.parameters(itertools.product(regression_test_cases, [True, False]))
     def test_window_adaptation(self, case, is_mass_matrix_diagonal):
         """Test the HMC kernel and the Stan warmup."""
@@ -143,6 +185,30 @@ def test_mala(self):
         np.testing.assert_allclose(np.mean(scale_samples), 1.0, atol=1e-1)
         np.testing.assert_allclose(np.mean(coefs_samples), 3.0, atol=1e-1)
 
+    def test_mclmc(self):
+        """Test the MCLMC kernel."""
+        init_key0, init_key1, inference_key = jax.random.split(self.key, 3)
+        x_data = jax.random.normal(init_key0, shape=(1000, 1))
+        y_data = 3 * x_data + jax.random.normal(init_key1, shape=x_data.shape)
+
+        logposterior_fn_ = functools.partial(
+            self.regression_logprob, x=x_data, preds=y_data
+        )
+        logdensity_fn = lambda x: logposterior_fn_(**x)
+
+        states = self.run_mclmc(
+            initial_position={"coefs": 1.0, "log_scale": 1.0},
+            logdensity_fn=logdensity_fn,
+            key=inference_key,
+            num_steps=10000,
+        )
+
+        coefs_samples = states["coefs"][3000:]
+        scale_samples = np.exp(states["log_scale"][3000:])
+
+        np.testing.assert_allclose(np.mean(scale_samples), 1.0, atol=1e-2)
+        np.testing.assert_allclose(np.mean(coefs_samples), 3.0, atol=1e-2)
+
     @parameterized.parameters(regression_test_cases)
     def test_pathfinder_adaptation(
         self,