Resolve 'Functions to run kernels' #598

Merged · 9 commits · Dec 6, 2023
48 changes: 47 additions & 1 deletion blackjax/util.py
@@ -5,9 +5,10 @@
import jax.numpy as jnp
from jax import jit, lax
from jax.flatten_util import ravel_pytree
from jax.random import normal
from jax.random import normal, split
from jax.tree_util import tree_leaves

from blackjax.base import Info, State
from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey


@@ -136,3 +137,48 @@ def index_pytree(input_pytree: ArrayLikeTree) -> ArrayTree:
    (dim_input,) = flat_input.shape
    array = jnp.arange(dim_input, dtype=flat_input.dtype)
    return unravel_fn(array)


def run_inference_algorithm(
    rng_key,
    initial_state_or_position,
    inference_algorithm,
    num_steps,
) -> tuple[State, State, Info]:
    """Wrapper to run an inference algorithm.

    Parameters
    ----------
    rng_key : PRNGKey
        The random state used by JAX's random number generator.
    initial_state_or_position : ArrayLikeTree
        The initial state OR the initial position of the inference algorithm.
        If an initial position is passed, the function automatically converts
        it into an initial state.
    inference_algorithm : Union[SamplingAlgorithm, VIAlgorithm]
        One of blackjax's sampling algorithms or variational inference
        algorithms.
    num_steps : int
        The number of steps to run the inference algorithm for.

    Returns
    -------
    Tuple[State, State, Info]
        1. The final state of the inference algorithm.
        2. The history of states of the inference algorithm.
        3. The history of the info of the inference algorithm.
    """
    try:
        initial_state = inference_algorithm.init(initial_state_or_position)
    except TypeError:
        # We assume initial_state_or_position is already a state.
        initial_state = initial_state_or_position

    keys = split(rng_key, num_steps)

    @jit
    def one_step(state, rng_key):
        state, info = inference_algorithm.step(rng_key, state)
        return state, (state, info)

    final_state, (state_history, info_history) = lax.scan(one_step, initial_state, keys)
    return final_state, state_history, info_history
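
For orientation, a minimal usage sketch of the new helper (not part of the diff; the target density and the NUTS tuning values below are illustrative placeholders):

import jax
import jax.numpy as jnp

import blackjax
from blackjax.util import run_inference_algorithm

# Illustrative target: a standard normal in two dimensions.
logdensity_fn = lambda x: -0.5 * jnp.sum(x**2)

# Any of blackjax's sampling algorithms works; NUTS with hand-picked tuning here.
nuts = blackjax.nuts(logdensity_fn, step_size=1e-1, inverse_mass_matrix=jnp.ones(2))

# Passing a raw position also works: run_inference_algorithm calls .init() on it.
final_state, state_history, info_history = run_inference_algorithm(
    jax.random.PRNGKey(0), jnp.zeros(2), nuts, num_steps=1_000
)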
14 changes: 6 additions & 8 deletions tests/adaptation/test_adaptation.py
@@ -6,6 +6,7 @@

import blackjax
from blackjax.adaptation import window_adaptation
from blackjax.util import run_inference_algorithm


@pytest.mark.parametrize(
@@ -57,15 +58,12 @@ def test_chees_adaptation():
        optim=optax.adamw(learning_rate=0.5),
        num_steps=num_burnin_steps,
    )
    kernel = blackjax.dynamic_hmc(logprob_fn, **parameters).step
    algorithm = blackjax.dynamic_hmc(logprob_fn, **parameters)

    def one_step(states, rng_key):
        keys = jax.random.split(rng_key, num_chains)
        states, infos = jax.vmap(kernel)(keys, states)
        return states, infos

    keys = jax.random.split(inference_key, num_results)
    _, infos = jax.lax.scan(one_step, last_states, keys)
    chain_keys = jax.random.split(inference_key, num_chains)
    _, _, infos = jax.vmap(
        lambda key, state: run_inference_algorithm(key, state, algorithm, num_results)
    )(chain_keys, last_states)

    harmonic_mean = 1.0 / jnp.mean(1.0 / infos.acceptance_rate)
    np.testing.assert_allclose(harmonic_mean, 0.75, rtol=1e-1)
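
The change above swaps the hand-rolled one_step/lax.scan loop for one PRNG key per chain and a jax.vmap over run_inference_algorithm. If end-to-end compilation of the multi-chain run is wanted, the same pattern can be wrapped in jax.jit; a sketch reusing the test's names (run_chains is a hypothetical convenience wrapper, not part of the PR):

run_chains = jax.jit(
    jax.vmap(
        lambda key, state: run_inference_algorithm(key, state, algorithm, num_results)
    )
)
_, _, infos = run_chains(jax.random.split(inference_key, num_chains), last_states)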
107 changes: 64 additions & 43 deletions tests/mcmc/test_sampling.py
@@ -13,17 +13,7 @@
import blackjax
import blackjax.diagnostics as diagnostics
import blackjax.mcmc.random_walk


def inference_loop(kernel, num_samples, rng_key, initial_state):
    def one_step(state, rng_key):
        state, _ = kernel(rng_key, state)
        return state, state

    keys = jax.random.split(rng_key, num_samples)
    _, states = jax.lax.scan(one_step, initial_state, keys)

    return states
from blackjax.util import run_inference_algorithm


def orbit_samples(orbits, weights, rng_key):
@@ -152,10 +142,10 @@ def test_window_adaptation(self, case, is_mass_matrix_diagonal):
case["initial_position"],
case["num_warmup_steps"],
)
algorithm = case["algorithm"](logposterior_fn, **parameters)
inference_algorithm = case["algorithm"](logposterior_fn, **parameters)

states = inference_loop(
algorithm.step, case["num_sampling_steps"], inference_key, state
_, states, _ = run_inference_algorithm(
inference_key, state, inference_algorithm, case["num_sampling_steps"]
)

coefs_samples = states.position["coefs"]
@@ -177,7 +167,7 @@ def test_mala(self):

        mala = blackjax.mala(logposterior_fn, 1e-5)
        state = mala.init({"coefs": 1.0, "log_scale": 1.0})
        states = inference_loop(mala.step, 10_000, inference_key, state)
        _, states, _ = run_inference_algorithm(inference_key, state, mala, 10_000)

        coefs_samples = states.position["coefs"][3000:]
        scale_samples = np.exp(states.position["log_scale"][3000:])
@@ -240,9 +230,11 @@ def test_pathfinder_adaptation(
            initial_position,
            num_warmup_steps,
        )
        kernel = algorithm(logposterior_fn, **parameters).step
        inference_algorithm = algorithm(logposterior_fn, **parameters)

        states = inference_loop(kernel, num_sampling_steps, inference_key, state)
        _, states, _ = run_inference_algorithm(
            inference_key, state, inference_algorithm, num_sampling_steps
        )

        coefs_samples = states.position["coefs"]
        scale_samples = np.exp(states.position["log_scale"])
@@ -277,12 +269,14 @@ def test_meads(self):
            initial_positions,
            num_steps=1000,
        )
        kernel = blackjax.ghmc(logposterior_fn, **parameters).step
        inference_algorithm = blackjax.ghmc(logposterior_fn, **parameters)

        chain_keys = jax.random.split(inference_key, num_chains)
        states = jax.vmap(lambda key, state: inference_loop(kernel, 100, key, state))(
            chain_keys, last_states
        )
        _, states, _ = jax.vmap(
            lambda key, state: run_inference_algorithm(
                key, state, inference_algorithm, 100
            )
        )(chain_keys, last_states)

        coefs_samples = states.position["coefs"]
        scale_samples = np.exp(states.position["log_scale"])
@@ -319,12 +313,14 @@ def test_chees(self, jitter_generator):
            optim=optax.adam(learning_rate=0.1),
            num_steps=1000,
        )
        kernel = blackjax.dynamic_hmc(logposterior_fn, **parameters).step
        inference_algorithm = blackjax.dynamic_hmc(logposterior_fn, **parameters)

        chain_keys = jax.random.split(inference_key, num_chains)
        states = jax.vmap(lambda key, state: inference_loop(kernel, 100, key, state))(
            chain_keys, last_states
        )
        _, states, _ = jax.vmap(
            lambda key, state: run_inference_algorithm(
                key, state, inference_algorithm, 100
            )
        )(chain_keys, last_states)

        coefs_samples = states.position["coefs"]
        scale_samples = np.exp(states.position["log_scale"])
@@ -345,7 +341,8 @@ def test_barker(self):

        barker = blackjax.barker_proposal(logposterior_fn, 1e-1)
        state = barker.init({"coefs": 1.0, "log_scale": 1.0})
        states = inference_loop(barker.step, 10_000, inference_key, state)

        _, states, _ = run_inference_algorithm(inference_key, state, barker, 10_000)

        coefs_samples = states.position["coefs"][3000:]
        scale_samples = np.exp(states.position["log_scale"][3000:])
@@ -519,13 +516,24 @@ def setUp(self):
    def test_latent_gaussian(self):
        from blackjax import mgrad_gaussian

        init, step = mgrad_gaussian(lambda x: -0.5 * jnp.sum((x - 1.0) ** 2), self.C)
        inference_algorithm = mgrad_gaussian(
            lambda x: -0.5 * jnp.sum((x - 1.0) ** 2), self.C
        )
        inference_algorithm = inference_algorithm._replace(
            step=functools.partial(
                inference_algorithm.step,
                delta=self.delta,
            )
        )

        kernel = lambda key, x: step(key, x, self.delta)
        initial_state = init(jnp.zeros((1,)))
        initial_state = inference_algorithm.init(jnp.zeros((1,)))

        states = self.variant(
            functools.partial(inference_loop, kernel, self.sampling_steps),
        _, states, _ = self.variant(
            functools.partial(
                run_inference_algorithm,
                inference_algorithm=inference_algorithm,
                num_steps=self.sampling_steps,
            ),
        )(self.key, initial_state)
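
Aside on the _replace call above: mgrad_gaussian returns a SamplingAlgorithm NamedTuple whose step takes an extra delta argument (the deleted kernel lambda passed self.delta explicitly). Pre-binding delta with functools.partial gives step the two-argument (rng_key, state) signature that run_inference_algorithm expects.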

        np.testing.assert_allclose(
@@ -647,20 +655,25 @@ def test_univariate_normal(
        if algorithm == blackjax.rmh:
            parameters["proposal_generator"] = rmh_proposal_distribution

        algo = algorithm(self.normal_logprob, **parameters)
        inference_algorithm = algorithm(self.normal_logprob, **parameters)
        rng_key = self.key
        if algorithm == blackjax.elliptical_slice:
            algo = algorithm(lambda _: 1.0, **parameters)
            inference_algorithm = algorithm(lambda x: jnp.ones_like(x), **parameters)
        if algorithm == blackjax.ghmc:
            rng_key, initial_state_key = jax.random.split(rng_key)
            initial_state = algo.init(initial_position, initial_state_key)
            initial_state = inference_algorithm.init(
                initial_position, initial_state_key
            )
        else:
            initial_state = algo.init(initial_position)
            initial_state = inference_algorithm.init(initial_position)

        inference_key, orbit_key = jax.random.split(rng_key)
        kernel = algo.step
        states = self.variant(
            functools.partial(inference_loop, kernel, num_sampling_steps)
        _, states, _ = self.variant(
            functools.partial(
                run_inference_algorithm,
                inference_algorithm=inference_algorithm,
                num_steps=num_sampling_steps,
            )
        )(inference_key, initial_state)

        if algorithm == blackjax.orbital_hmc:
@@ -763,23 +776,31 @@ def test_mcse(self, algorithm, parameters, is_mass_matrix_diagonal):
                inverse_mass_matrix = true_scale**2
            else:
                inverse_mass_matrix = true_cov
            kernel = algorithm(
            inference_algorithm = algorithm(
                logdensity_fn,
                inverse_mass_matrix=inverse_mass_matrix,
                **parameters,
            )
        else:
            kernel = algorithm(logdensity_fn, **parameters)
            inference_algorithm = algorithm(logdensity_fn, **parameters)

        num_chains = 10
        initial_positions = jax.random.normal(pos_init_key, [num_chains, 2])
        initial_states = jax.vmap(kernel.init, in_axes=(0,))(initial_positions)
        initial_states = jax.vmap(inference_algorithm.init, in_axes=(0,))(
            initial_positions
        )
        multi_chain_sample_key = jax.random.split(sample_key, num_chains)

        inference_loop_multiple_chains = jax.vmap(
            functools.partial(inference_loop, kernel.step, 2_000)
            functools.partial(
                run_inference_algorithm,
                inference_algorithm=inference_algorithm,
                num_steps=2_000,
            )
        )
        _, states, _ = inference_loop_multiple_chains(
            multi_chain_sample_key, initial_states
        )
        states = inference_loop_multiple_chains(multi_chain_sample_key, initial_states)

        posterior_samples = states.position[:, -1000:]
        posterior_delta = posterior_samples - true_loc
19 changes: 5 additions & 14 deletions tests/test_benchmarks.py
@@ -13,6 +13,7 @@
import pytest

import blackjax
from blackjax.util import run_inference_algorithm


def regression_logprob(log_scale, coefs, preds, x):
@@ -25,18 +26,6 @@ def regression_logprob(log_scale, coefs, preds, x):
    return sum(x.sum() for x in [scale_prior, coefs_prior, logpdf])


def inference_loop(kernel, num_samples, rng_key, initial_state):
    @jax.jit
    def one_step(state, rng_key):
        state, _ = kernel(rng_key, state)
        return state, state

    keys = jax.random.split(rng_key, num_samples)
    _, states = jax.lax.scan(one_step, initial_state, keys)

    return states


def run_regression(algorithm, **parameters):
    key = jax.random.key(0)
    rng_key, init_key0, init_key1 = jax.random.split(key, 3)
@@ -57,9 +46,11 @@ def run_regression(algorithm, **parameters):
    (state, parameters), _ = warmup.run(
        warmup_key, {"log_scale": 0.0, "coefs": 2.0}, 1000
    )
    kernel = algorithm(logdensity_fn, **parameters).step
    inference_algorithm = algorithm(logdensity_fn, **parameters)

    states = inference_loop(kernel, 10_000, inference_key, state)
    _, states, _ = run_inference_algorithm(
        inference_key, state, inference_algorithm, 10_000
    )

    return states
