Skip to content

Commit

Permalink
Resolve 'Functions to run kernels' (#598)
Browse files Browse the repository at this point in the history
* Add  function, modify  to account for change

* Revert back quickstart.md

* Add run_inference wrapper to tests/mcmc/sampling; Get rid of arg types for run_inference wrapper

* Get rid of unused imports

* Add run_inference wrapper to tests/benchmark

* Change import style

* Import style for benchmarks test; add wrapper to adaptation test

* Replace 'kernel' variable name with 'inference algorithm' when using wrapper run_inference_algorithm

---------

Co-authored-by: Paul Scemama <[email protected]>
  • Loading branch information
2 people authored and junpenglao committed Mar 12, 2024
1 parent d1e7014 commit b9f93fc
Show file tree
Hide file tree
Showing 4 changed files with 122 additions and 66 deletions.
48 changes: 47 additions & 1 deletion blackjax/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
import jax.numpy as jnp
from jax import jit, lax
from jax.flatten_util import ravel_pytree
from jax.random import normal
from jax.random import normal, split
from jax.tree_util import tree_leaves

from blackjax.base import Info, State
from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey


Expand Down Expand Up @@ -136,3 +137,48 @@ def index_pytree(input_pytree: ArrayLikeTree) -> ArrayTree:
(dim_input,) = flat_input.shape
array = jnp.arange(dim_input, dtype=flat_input.dtype)
return unravel_fn(array)


def run_inference_algorithm(
rng_key,
initial_state_or_position,
inference_algorithm,
num_steps,
) -> tuple[State, State, Info]:
"""Wrapper to run an inference algorithm.
Parameters
----------
rng_key : PRNGKey
The random state used by JAX's random numbers generator.
initial_state_or_position: ArrayLikeTree
The initial state OR the initial position of the inference algorithm. If an initial position
is passed in, the function will automatically convert it into an initial state.
inference_algorithm : Union[SamplingAlgorithm, VIAlgorithm]
One of blackjax's sampling algorithms or variational inference algorithms.
num_steps : int
Number of learning steps.
Returns
-------
Tuple[State, State, Info]
1. The final state of the inference algorithm.
2. The history of states of the inference algorithm.
3. The history of the info of the inference algorithm.
"""
try:
initial_state = inference_algorithm.init(initial_state_or_position)
except TypeError:
# We assume initial_state is already in the right format.
initial_state = initial_state_or_position
initial_state = initial_state_or_position

keys = split(rng_key, num_steps)

@jit
def one_step(state, rng_key):
state, info = inference_algorithm.step(rng_key, state)
return state, (state, info)

final_state, (state_history, info_history) = lax.scan(one_step, initial_state, keys)
return final_state, state_history, info_history
14 changes: 6 additions & 8 deletions tests/adaptation/test_adaptation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import blackjax
from blackjax.adaptation import window_adaptation
from blackjax.util import run_inference_algorithm


@pytest.mark.parametrize(
Expand Down Expand Up @@ -57,15 +58,12 @@ def test_chees_adaptation():
optim=optax.adamw(learning_rate=0.5),
num_steps=num_burnin_steps,
)
kernel = blackjax.dynamic_hmc(logprob_fn, **parameters).step
algorithm = blackjax.dynamic_hmc(logprob_fn, **parameters)

def one_step(states, rng_key):
keys = jax.random.split(rng_key, num_chains)
states, infos = jax.vmap(kernel)(keys, states)
return states, infos

keys = jax.random.split(inference_key, num_results)
_, infos = jax.lax.scan(one_step, last_states, keys)
chain_keys = jax.random.split(inference_key, num_chains)
_, _, infos = jax.vmap(
lambda key, state: run_inference_algorithm(key, state, algorithm, num_results)
)(chain_keys, last_states)

harmonic_mean = 1.0 / jnp.mean(1.0 / infos.acceptance_rate)
np.testing.assert_allclose(harmonic_mean, 0.75, rtol=1e-1)
Expand Down
107 changes: 64 additions & 43 deletions tests/mcmc/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,7 @@
import blackjax
import blackjax.diagnostics as diagnostics
import blackjax.mcmc.random_walk


def inference_loop(kernel, num_samples, rng_key, initial_state):
def one_step(state, rng_key):
state, _ = kernel(rng_key, state)
return state, state

keys = jax.random.split(rng_key, num_samples)
_, states = jax.lax.scan(one_step, initial_state, keys)

return states
from blackjax.util import run_inference_algorithm


def orbit_samples(orbits, weights, rng_key):
Expand Down Expand Up @@ -152,10 +142,10 @@ def test_window_adaptation(self, case, is_mass_matrix_diagonal):
case["initial_position"],
case["num_warmup_steps"],
)
algorithm = case["algorithm"](logposterior_fn, **parameters)
inference_algorithm = case["algorithm"](logposterior_fn, **parameters)

states = inference_loop(
algorithm.step, case["num_sampling_steps"], inference_key, state
_, states, _ = run_inference_algorithm(
inference_key, state, inference_algorithm, case["num_sampling_steps"]
)

coefs_samples = states.position["coefs"]
Expand All @@ -177,7 +167,7 @@ def test_mala(self):

mala = blackjax.mala(logposterior_fn, 1e-5)
state = mala.init({"coefs": 1.0, "log_scale": 1.0})
states = inference_loop(mala.step, 10_000, inference_key, state)
_, states, _ = run_inference_algorithm(inference_key, state, mala, 10_000)

coefs_samples = states.position["coefs"][3000:]
scale_samples = np.exp(states.position["log_scale"][3000:])
Expand Down Expand Up @@ -240,9 +230,11 @@ def test_pathfinder_adaptation(
initial_position,
num_warmup_steps,
)
kernel = algorithm(logposterior_fn, **parameters).step
inference_algorithm = algorithm(logposterior_fn, **parameters)

states = inference_loop(kernel, num_sampling_steps, inference_key, state)
_, states, _ = run_inference_algorithm(
inference_key, state, inference_algorithm, num_sampling_steps
)

coefs_samples = states.position["coefs"]
scale_samples = np.exp(states.position["log_scale"])
Expand Down Expand Up @@ -277,12 +269,14 @@ def test_meads(self):
initial_positions,
num_steps=1000,
)
kernel = blackjax.ghmc(logposterior_fn, **parameters).step
inference_algorithm = blackjax.ghmc(logposterior_fn, **parameters)

chain_keys = jax.random.split(inference_key, num_chains)
states = jax.vmap(lambda key, state: inference_loop(kernel, 100, key, state))(
chain_keys, last_states
)
_, states, _ = jax.vmap(
lambda key, state: run_inference_algorithm(
key, state, inference_algorithm, 100
)
)(chain_keys, last_states)

coefs_samples = states.position["coefs"]
scale_samples = np.exp(states.position["log_scale"])
Expand Down Expand Up @@ -319,12 +313,14 @@ def test_chees(self, jitter_generator):
optim=optax.adam(learning_rate=0.1),
num_steps=1000,
)
kernel = blackjax.dynamic_hmc(logposterior_fn, **parameters).step
inference_algorithm = blackjax.dynamic_hmc(logposterior_fn, **parameters)

chain_keys = jax.random.split(inference_key, num_chains)
states = jax.vmap(lambda key, state: inference_loop(kernel, 100, key, state))(
chain_keys, last_states
)
_, states, _ = jax.vmap(
lambda key, state: run_inference_algorithm(
key, state, inference_algorithm, 100
)
)(chain_keys, last_states)

coefs_samples = states.position["coefs"]
scale_samples = np.exp(states.position["log_scale"])
Expand All @@ -345,7 +341,8 @@ def test_barker(self):

barker = blackjax.barker_proposal(logposterior_fn, 1e-1)
state = barker.init({"coefs": 1.0, "log_scale": 1.0})
states = inference_loop(barker.step, 10_000, inference_key, state)

_, states, _ = run_inference_algorithm(inference_key, state, barker, 10_000)

coefs_samples = states.position["coefs"][3000:]
scale_samples = np.exp(states.position["log_scale"][3000:])
Expand Down Expand Up @@ -519,13 +516,24 @@ def setUp(self):
def test_latent_gaussian(self):
from blackjax import mgrad_gaussian

init, step = mgrad_gaussian(lambda x: -0.5 * jnp.sum((x - 1.0) ** 2), self.C)
inference_algorithm = mgrad_gaussian(
lambda x: -0.5 * jnp.sum((x - 1.0) ** 2), self.C
)
inference_algorithm = inference_algorithm._replace(
step=functools.partial(
inference_algorithm.step,
delta=self.delta,
)
)

kernel = lambda key, x: step(key, x, self.delta)
initial_state = init(jnp.zeros((1,)))
initial_state = inference_algorithm.init(jnp.zeros((1,)))

states = self.variant(
functools.partial(inference_loop, kernel, self.sampling_steps),
_, states, _ = self.variant(
functools.partial(
run_inference_algorithm,
inference_algorithm=inference_algorithm,
num_steps=self.sampling_steps,
),
)(self.key, initial_state)

np.testing.assert_allclose(
Expand Down Expand Up @@ -647,20 +655,25 @@ def test_univariate_normal(
if algorithm == blackjax.rmh:
parameters["proposal_generator"] = rmh_proposal_distribution

algo = algorithm(self.normal_logprob, **parameters)
inference_algorithm = algorithm(self.normal_logprob, **parameters)
rng_key = self.key
if algorithm == blackjax.elliptical_slice:
algo = algorithm(lambda _: 1.0, **parameters)
inference_algorithm = algorithm(lambda x: jnp.ones_like(x), **parameters)
if algorithm == blackjax.ghmc:
rng_key, initial_state_key = jax.random.split(rng_key)
initial_state = algo.init(initial_position, initial_state_key)
initial_state = inference_algorithm.init(
initial_position, initial_state_key
)
else:
initial_state = algo.init(initial_position)
initial_state = inference_algorithm.init(initial_position)

inference_key, orbit_key = jax.random.split(rng_key)
kernel = algo.step
states = self.variant(
functools.partial(inference_loop, kernel, num_sampling_steps)
_, states, _ = self.variant(
functools.partial(
run_inference_algorithm,
inference_algorithm=inference_algorithm,
num_steps=num_sampling_steps,
)
)(inference_key, initial_state)

if algorithm == blackjax.orbital_hmc:
Expand Down Expand Up @@ -763,23 +776,31 @@ def test_mcse(self, algorithm, parameters, is_mass_matrix_diagonal):
inverse_mass_matrix = true_scale**2
else:
inverse_mass_matrix = true_cov
kernel = algorithm(
inference_algorithm = algorithm(
logdensity_fn,
inverse_mass_matrix=inverse_mass_matrix,
**parameters,
)
else:
kernel = algorithm(logdensity_fn, **parameters)
inference_algorithm = algorithm(logdensity_fn, **parameters)

num_chains = 10
initial_positions = jax.random.normal(pos_init_key, [num_chains, 2])
initial_states = jax.vmap(kernel.init, in_axes=(0,))(initial_positions)
initial_states = jax.vmap(inference_algorithm.init, in_axes=(0,))(
initial_positions
)
multi_chain_sample_key = jax.random.split(sample_key, num_chains)

inference_loop_multiple_chains = jax.vmap(
functools.partial(inference_loop, kernel.step, 2_000)
functools.partial(
run_inference_algorithm,
inference_algorithm=inference_algorithm,
num_steps=2_000,
)
)
_, states, _ = inference_loop_multiple_chains(
multi_chain_sample_key, initial_states
)
states = inference_loop_multiple_chains(multi_chain_sample_key, initial_states)

posterior_samples = states.position[:, -1000:]
posterior_delta = posterior_samples - true_loc
Expand Down
19 changes: 5 additions & 14 deletions tests/test_benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import pytest

import blackjax
from blackjax.util import run_inference_algorithm


def regression_logprob(log_scale, coefs, preds, x):
Expand All @@ -25,18 +26,6 @@ def regression_logprob(log_scale, coefs, preds, x):
return sum(x.sum() for x in [scale_prior, coefs_prior, logpdf])


def inference_loop(kernel, num_samples, rng_key, initial_state):
@jax.jit
def one_step(state, rng_key):
state, _ = kernel(rng_key, state)
return state, state

keys = jax.random.split(rng_key, num_samples)
_, states = jax.lax.scan(one_step, initial_state, keys)

return states


def run_regression(algorithm, **parameters):
key = jax.random.key(0)
rng_key, init_key0, init_key1 = jax.random.split(key, 3)
Expand All @@ -57,9 +46,11 @@ def run_regression(algorithm, **parameters):
(state, parameters), _ = warmup.run(
warmup_key, {"log_scale": 0.0, "coefs": 2.0}, 1000
)
kernel = algorithm(logdensity_fn, **parameters).step
inference_algorithm = algorithm(logdensity_fn, **parameters)

states = inference_loop(kernel, 10_000, inference_key, state)
_, states, _ = run_inference_algorithm(
inference_key, state, inference_algorithm, 10_000
)

return states

Expand Down

0 comments on commit b9f93fc

Please sign in to comment.