From b9f93fc9587627af6c626d4b9c8c1b3ff3c3b6fe Mon Sep 17 00:00:00 2001
From: Paul Scemama
Date: Wed, 6 Dec 2023 12:53:19 -0500
Subject: [PATCH] Resolve 'Functions to run kernels' (#598)

* Add function, modify to account for change

* Revert back quickstart.md

* Add run_inference wrapper to tests/mcmc/sampling; Get rid of arg types for run_inference wrapper

* Get rid of unused imports

* Add run_inference wrapper to tests/benchmark

* Change import style

* Import style for benchmarks test; add wrapper to adaptation test

* Replace 'kernel' variable name with 'inference algorithm' when using wrapper run_inference_algorithm

---------

Co-authored-by: Paul Scemama
---
 blackjax/util.py                    |  48 ++++++++++++-
 tests/adaptation/test_adaptation.py |  14 ++--
 tests/mcmc/test_sampling.py         | 107 +++++++++++++++++-----------
 tests/test_benchmarks.py            |  19 ++---
 4 files changed, 122 insertions(+), 66 deletions(-)

diff --git a/blackjax/util.py b/blackjax/util.py
index a3a7226a6..f667d147a 100644
--- a/blackjax/util.py
+++ b/blackjax/util.py
@@ -5,9 +5,10 @@
 import jax.numpy as jnp
 from jax import jit, lax
 from jax.flatten_util import ravel_pytree
-from jax.random import normal
+from jax.random import normal, split
 from jax.tree_util import tree_leaves
 
+from blackjax.base import Info, State
 from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey
 
 
@@ -136,3 +137,48 @@ def index_pytree(input_pytree: ArrayLikeTree) -> ArrayTree:
     (dim_input,) = flat_input.shape
     array = jnp.arange(dim_input, dtype=flat_input.dtype)
     return unravel_fn(array)
+
+
+def run_inference_algorithm(
+    rng_key,
+    initial_state_or_position,
+    inference_algorithm,
+    num_steps,
+) -> tuple[State, State, Info]:
+    """Wrapper to run an inference algorithm.
+
+    Parameters
+    ----------
+    rng_key : PRNGKey
+        The random state used by JAX's random number generator.
+    initial_state_or_position : ArrayLikeTree
+        The initial state OR the initial position of the inference algorithm.
+        If an initial position is passed in, the function will automatically
+        convert it into an initial state.
+    inference_algorithm : Union[SamplingAlgorithm, VIAlgorithm]
+        One of blackjax's sampling algorithms or variational inference algorithms.
+    num_steps : int
+        Number of steps to run the inference algorithm for.
+
+    Returns
+    -------
+    Tuple[State, State, Info]
+        1. The final state of the inference algorithm.
+        2. The history of states of the inference algorithm.
+        3. The history of the info of the inference algorithm.
+    """
+    try:
+        initial_state = inference_algorithm.init(initial_state_or_position)
+    except TypeError:
+        # We assume initial_state_or_position is already in the right format.
+        initial_state = initial_state_or_position
+
+    keys = split(rng_key, num_steps)
+
+    @jit
+    def one_step(state, rng_key):
+        state, info = inference_algorithm.step(rng_key, state)
+        return state, (state, info)
+
+    final_state, (state_history, info_history) = lax.scan(one_step, initial_state, keys)
+    return final_state, state_history, info_history

diff --git a/tests/adaptation/test_adaptation.py b/tests/adaptation/test_adaptation.py
index 1b95b0115..93bf418d2 100644
--- a/tests/adaptation/test_adaptation.py
+++ b/tests/adaptation/test_adaptation.py
@@ -6,6 +6,7 @@
 
 import blackjax
 from blackjax.adaptation import window_adaptation
+from blackjax.util import run_inference_algorithm
 
 
 @pytest.mark.parametrize(
@@ -57,15 +58,12 @@ def test_chees_adaptation():
         optim=optax.adamw(learning_rate=0.5),
         num_steps=num_burnin_steps,
     )
-    kernel = blackjax.dynamic_hmc(logprob_fn, **parameters).step
+    algorithm = blackjax.dynamic_hmc(logprob_fn, **parameters)
 
-    def one_step(states, rng_key):
-        keys = jax.random.split(rng_key, num_chains)
-        states, infos = jax.vmap(kernel)(keys, states)
-        return states, infos
-
-    keys = jax.random.split(inference_key, num_results)
-    _, infos = jax.lax.scan(one_step, last_states, keys)
+    chain_keys = jax.random.split(inference_key, num_chains)
+    _, _, infos = jax.vmap(
+        lambda key, state: run_inference_algorithm(key, state, algorithm, num_results)
+    )(chain_keys, last_states)
 
     harmonic_mean = 1.0 / jnp.mean(1.0 / infos.acceptance_rate)
     np.testing.assert_allclose(harmonic_mean, 0.75, rtol=1e-1)

diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py
index db9ef9944..ec47f1180 100644
--- a/tests/mcmc/test_sampling.py
+++ b/tests/mcmc/test_sampling.py
@@ -13,17 +13,7 @@
 import blackjax
 import blackjax.diagnostics as diagnostics
 import blackjax.mcmc.random_walk
-
-
-def inference_loop(kernel, num_samples, rng_key, initial_state):
-    def one_step(state, rng_key):
-        state, _ = kernel(rng_key, state)
-        return state, state
-
-    keys = jax.random.split(rng_key, num_samples)
-    _, states = jax.lax.scan(one_step, initial_state, keys)
-
-    return states
+from blackjax.util import run_inference_algorithm
 
 
 def orbit_samples(orbits, weights, rng_key):
@@ -152,10 +142,10 @@ def test_window_adaptation(self, case, is_mass_matrix_diagonal):
             case["initial_position"],
             case["num_warmup_steps"],
         )
-        algorithm = case["algorithm"](logposterior_fn, **parameters)
+        inference_algorithm = case["algorithm"](logposterior_fn, **parameters)
 
-        states = inference_loop(
-            algorithm.step, case["num_sampling_steps"], inference_key, state
+        _, states, _ = run_inference_algorithm(
+            inference_key, state, inference_algorithm, case["num_sampling_steps"]
         )
 
         coefs_samples = states.position["coefs"]
@@ -177,7 +167,7 @@ def test_mala(self):
         mala = blackjax.mala(logposterior_fn, 1e-5)
         state = mala.init({"coefs": 1.0, "log_scale": 1.0})
 
-        states = inference_loop(mala.step, 10_000, inference_key, state)
+        _, states, _ = run_inference_algorithm(inference_key, state, mala, 10_000)
 
         coefs_samples = states.position["coefs"][3000:]
         scale_samples = np.exp(states.position["log_scale"][3000:])
@@ -240,9 +230,11 @@ def test_pathfinder_adaptation(
             initial_position,
             num_warmup_steps,
         )
-        kernel = algorithm(logposterior_fn, **parameters).step
+        inference_algorithm = algorithm(logposterior_fn, **parameters)
 
-        states = inference_loop(kernel, num_sampling_steps, inference_key, state)
+        _, states, _ = run_inference_algorithm(
+            inference_key, state, inference_algorithm, num_sampling_steps
+        )
 
         coefs_samples = states.position["coefs"]
         scale_samples = np.exp(states.position["log_scale"])
@@ -277,12 +269,14 @@ def test_meads(self):
             initial_positions,
             num_steps=1000,
         )
-        kernel = blackjax.ghmc(logposterior_fn, **parameters).step
+        inference_algorithm = blackjax.ghmc(logposterior_fn, **parameters)
 
         chain_keys = jax.random.split(inference_key, num_chains)
-        states = jax.vmap(lambda key, state: inference_loop(kernel, 100, key, state))(
-            chain_keys, last_states
-        )
+        _, states, _ = jax.vmap(
+            lambda key, state: run_inference_algorithm(
+                key, state, inference_algorithm, 100
+            )
+        )(chain_keys, last_states)
 
         coefs_samples = states.position["coefs"]
         scale_samples = np.exp(states.position["log_scale"])
@@ -319,12 +313,14 @@ def test_chees(self, jitter_generator):
             optim=optax.adam(learning_rate=0.1),
             num_steps=1000,
         )
-        kernel = blackjax.dynamic_hmc(logposterior_fn, **parameters).step
+        inference_algorithm = blackjax.dynamic_hmc(logposterior_fn, **parameters)
 
         chain_keys = jax.random.split(inference_key, num_chains)
-        states = jax.vmap(lambda key, state: inference_loop(kernel, 100, key, state))(
-            chain_keys, last_states
-        )
+        _, states, _ = jax.vmap(
+            lambda key, state: run_inference_algorithm(
+                key, state, inference_algorithm, 100
+            )
+        )(chain_keys, last_states)
 
         coefs_samples = states.position["coefs"]
         scale_samples = np.exp(states.position["log_scale"])
@@ -345,7 +341,8 @@ def test_barker(self):
         barker = blackjax.barker_proposal(logposterior_fn, 1e-1)
         state = barker.init({"coefs": 1.0, "log_scale": 1.0})
 
-        states = inference_loop(barker.step, 10_000, inference_key, state)
+
+        _, states, _ = run_inference_algorithm(inference_key, state, barker, 10_000)
 
         coefs_samples = states.position["coefs"][3000:]
         scale_samples = np.exp(states.position["log_scale"][3000:])
@@ -519,13 +516,24 @@ def setUp(self):
     def test_latent_gaussian(self):
         from blackjax import mgrad_gaussian
 
-        init, step = mgrad_gaussian(lambda x: -0.5 * jnp.sum((x - 1.0) ** 2), self.C)
+        inference_algorithm = mgrad_gaussian(
+            lambda x: -0.5 * jnp.sum((x - 1.0) ** 2), self.C
+        )
+        inference_algorithm = inference_algorithm._replace(
+            step=functools.partial(
+                inference_algorithm.step,
+                delta=self.delta,
+            )
+        )
 
-        kernel = lambda key, x: step(key, x, self.delta)
-        initial_state = init(jnp.zeros((1,)))
+        initial_state = inference_algorithm.init(jnp.zeros((1,)))
 
-        states = self.variant(
-            functools.partial(inference_loop, kernel, self.sampling_steps),
+        _, states, _ = self.variant(
+            functools.partial(
+                run_inference_algorithm,
+                inference_algorithm=inference_algorithm,
+                num_steps=self.sampling_steps,
+            ),
         )(self.key, initial_state)
 
         np.testing.assert_allclose(
@@ -647,20 +655,25 @@ def test_univariate_normal(
         if algorithm == blackjax.rmh:
             parameters["proposal_generator"] = rmh_proposal_distribution
 
-        algo = algorithm(self.normal_logprob, **parameters)
+        inference_algorithm = algorithm(self.normal_logprob, **parameters)
         rng_key = self.key
         if algorithm == blackjax.elliptical_slice:
-            algo = algorithm(lambda _: 1.0, **parameters)
+            inference_algorithm = algorithm(lambda x: jnp.ones_like(x), **parameters)
         if algorithm == blackjax.ghmc:
             rng_key, initial_state_key = jax.random.split(rng_key)
-            initial_state = algo.init(initial_position, initial_state_key)
+            initial_state = inference_algorithm.init(
+                initial_position, initial_state_key
+            )
         else:
-            initial_state = algo.init(initial_position)
+            initial_state = inference_algorithm.init(initial_position)
 
         inference_key, orbit_key = jax.random.split(rng_key)
-        kernel = algo.step
-        states = self.variant(
-            functools.partial(inference_loop, kernel, num_sampling_steps)
+        _, states, _ = self.variant(
+            functools.partial(
+                run_inference_algorithm,
+                inference_algorithm=inference_algorithm,
+                num_steps=num_sampling_steps,
+            )
         )(inference_key, initial_state)
 
         if algorithm == blackjax.orbital_hmc:
@@ -763,23 +776,31 @@ def test_mcse(self, algorithm, parameters, is_mass_matrix_diagonal):
                 inverse_mass_matrix = true_scale**2
             else:
                 inverse_mass_matrix = true_cov
-            kernel = algorithm(
+            inference_algorithm = algorithm(
                 logdensity_fn,
                 inverse_mass_matrix=inverse_mass_matrix,
                 **parameters,
             )
         else:
-            kernel = algorithm(logdensity_fn, **parameters)
+            inference_algorithm = algorithm(logdensity_fn, **parameters)
 
         num_chains = 10
         initial_positions = jax.random.normal(pos_init_key, [num_chains, 2])
-        initial_states = jax.vmap(kernel.init, in_axes=(0,))(initial_positions)
+        initial_states = jax.vmap(inference_algorithm.init, in_axes=(0,))(
+            initial_positions
+        )
         multi_chain_sample_key = jax.random.split(sample_key, num_chains)
 
         inference_loop_multiple_chains = jax.vmap(
-            functools.partial(inference_loop, kernel.step, 2_000)
+            functools.partial(
+                run_inference_algorithm,
+                inference_algorithm=inference_algorithm,
+                num_steps=2_000,
+            )
+        )
+        _, states, _ = inference_loop_multiple_chains(
+            multi_chain_sample_key, initial_states
         )
-        states = inference_loop_multiple_chains(multi_chain_sample_key, initial_states)
 
         posterior_samples = states.position[:, -1000:]
         posterior_delta = posterior_samples - true_loc

diff --git a/tests/test_benchmarks.py b/tests/test_benchmarks.py
index d64efa4cd..d8f09cea0 100644
--- a/tests/test_benchmarks.py
+++ b/tests/test_benchmarks.py
@@ -13,6 +13,7 @@
 import pytest
 
 import blackjax
+from blackjax.util import run_inference_algorithm
 
 
 def regression_logprob(log_scale, coefs, preds, x):
@@ -25,18 +26,6 @@ def regression_logprob(log_scale, coefs, preds, x):
     return sum(x.sum() for x in [scale_prior, coefs_prior, logpdf])
 
 
-def inference_loop(kernel, num_samples, rng_key, initial_state):
-    @jax.jit
-    def one_step(state, rng_key):
-        state, _ = kernel(rng_key, state)
-        return state, state
-
-    keys = jax.random.split(rng_key, num_samples)
-    _, states = jax.lax.scan(one_step, initial_state, keys)
-
-    return states
-
-
 def run_regression(algorithm, **parameters):
     key = jax.random.key(0)
     rng_key, init_key0, init_key1 = jax.random.split(key, 3)
@@ -57,9 +46,11 @@ def run_regression(algorithm, **parameters):
     (state, parameters), _ = warmup.run(
         warmup_key, {"log_scale": 0.0, "coefs": 2.0}, 1000
    )
-    kernel = algorithm(logdensity_fn, **parameters).step
+    inference_algorithm = algorithm(logdensity_fn, **parameters)
 
-    states = inference_loop(kernel, 10_000, inference_key, state)
+    _, states, _ = run_inference_algorithm(
+        inference_key, state, inference_algorithm, 10_000
+    )
 
     return states
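
Usage note: the sketch below shows how the new wrapper is meant to be called end to end. It is illustrative only; the standard-normal target and the MALA step size are placeholder choices, not values taken from this patch.

    import jax
    import jax.numpy as jnp

    import blackjax
    from blackjax.util import run_inference_algorithm

    # Illustrative target density: a standard normal in two dimensions.
    def logdensity_fn(x):
        return -0.5 * jnp.sum(x**2)

    # Any of blackjax's sampling algorithms works; MALA is used as an example,
    # mirroring the patch's own tests.
    inference_algorithm = blackjax.mala(logdensity_fn, 1e-2)

    rng_key = jax.random.key(0)

    # A raw position can be passed: the wrapper first tries
    # `inference_algorithm.init` on it and falls back to treating the
    # argument as a ready-made state.
    final_state, state_history, info_history = run_inference_algorithm(
        rng_key, jnp.zeros(2), inference_algorithm, 1_000
    )

    # `lax.scan` stacks the per-step states along the leading axis.
    samples = state_history.position  # shape (1000, 2)

For multiple chains, the wrapper composes with jax.vmap over per-chain keys and initial states, exactly as the updated adaptation and MEADS/ChEES tests do.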