From 6b8b327bbfb68a387c5d09c7794a672b1d15979c Mon Sep 17 00:00:00 2001 From: Paul Scemama Date: Fri, 1 Dec 2023 02:13:53 +0000 Subject: [PATCH 1/8] Add function, modify to account for change --- blackjax/util.py | 49 ++++++++++++++++++++++++++-- docs/examples/quickstart.md | 64 ++++++++++++++----------------------- 2 files changed, 71 insertions(+), 42 deletions(-) diff --git a/blackjax/util.py b/blackjax/util.py index a3a7226a6..af334a1a3 100644 --- a/blackjax/util.py +++ b/blackjax/util.py @@ -1,13 +1,14 @@ """Utility functions for BlackJax.""" from functools import partial -from typing import Union +from typing import Tuple, Union import jax.numpy as jnp from jax import jit, lax from jax.flatten_util import ravel_pytree -from jax.random import normal +from jax.random import normal, split from jax.tree_util import tree_leaves +from blackjax.base import Info, SamplingAlgorithm, State, VIAlgorithm from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey @@ -136,3 +137,47 @@ def index_pytree(input_pytree: ArrayLikeTree) -> ArrayTree: (dim_input,) = flat_input.shape array = jnp.arange(dim_input, dtype=flat_input.dtype) return unravel_fn(array) + + +def run_inference_algorithm( + rng_key: PRNGKey, + initial_state_or_position: ArrayLikeTree, + inference_algorithm: Union[SamplingAlgorithm, VIAlgorithm], + num_steps: int, +) -> Tuple[State, State, Info]: + """Wrapper to run an inference algorithm. + + Parameters + ---------- + rng_key : PRNGKey + The random state used by JAX's random numbers generator. + initial_state_or_position: ArrayLikeTree + The initial state OR the initial position of the inference algorithm. If an initial position + is passed in, the function will automatically convert it into an initial state. + inference_algorithm : Union[SamplingAlgorithm, VIAlgorithm] + One of blackjax's sampling algorithms or variational inference algorithms. + num_steps : int + Number of learning steps. + + Returns + ------- + Tuple[State, State, Info] + 1. The final state of the inference algorithm. + 2. The history of states of the inference algorithm. + 3. The history of the info of the inference algorithm. + """ + try: + initial_state = inference_algorithm.init(initial_state_or_position) + except TypeError: + # initial_state is already in the right format. 
+ initial_state = initial_state_or_position + + keys = split(rng_key, num_steps) + + @jit + def one_step(state, rng_key): + state, info = inference_algorithm.step(rng_key, state) + return state, (state, info) + + final_state, (state_history, info_history) = lax.scan(one_step, initial_state, keys) + return final_state, state_history, info_history diff --git a/docs/examples/quickstart.md b/docs/examples/quickstart.md index 870e5df9a..f70241c85 100644 --- a/docs/examples/quickstart.md +++ b/docs/examples/quickstart.md @@ -17,7 +17,7 @@ BlackJAX is an MCMC sampling library based on [JAX](https://github.com/google/ja In this notebook we provide a simple example based on basic Hamiltonian Monte Carlo and the NUTS algorithm to showcase the architecture and interfaces in the library -```{code-cell} +```{code-cell} ipython3 import matplotlib.pyplot as plt import numpy as np @@ -26,9 +26,10 @@ import jax.numpy as jnp import jax.scipy.stats as stats import blackjax +from blackjax.util import run_inference_algorithm ``` -```{code-cell} +```{code-cell} ipython3 :tags: [remove-output] from datetime import date @@ -39,12 +40,12 @@ rng_key = jax.random.key(int(date.today().strftime("%Y%m%d"))) We'll generate observations from a normal distribution of known `loc` and `scale` to see if we can recover the parameters in sampling. **MCMC algorithms usually assume samples are being drawn from an unconstrained Euclidean space.** Hence why we'll log transform the scale parameter, so that sampling is done on the real line. Samples can be transformed back to their original space in post-processing. Let's take a decent-size dataset with 1,000 points: -```{code-cell} +```{code-cell} ipython3 loc, scale = 10, 20 observed = np.random.normal(loc, scale, size=1_000) ``` -```{code-cell} +```{code-cell} ipython3 def logdensity_fn(loc, log_scale, observed=observed): """Univariate Normal""" scale = jnp.exp(log_scale) @@ -59,7 +60,7 @@ logdensity = lambda x: logdensity_fn(**x) ### Sampler Parameters -```{code-cell} +```{code-cell} ipython3 inv_mass_matrix = np.array([0.5, 0.01]) num_integration_steps = 60 step_size = 1e-3 @@ -71,7 +72,7 @@ hmc = blackjax.hmc(logdensity, step_size, inv_mass_matrix, num_integration_steps The initial state of the HMC algorithm requires not only an initial position, but also the potential energy and gradient of the potential energy at this position (for example, in the context of Bayesian modeling, the output of the log posterior function evaluated at the initial position). BlackJAX provides a `new_state` function to initialize the state from an initial position. 
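For concreteness, here is a minimal, self-contained sketch of how the `run_inference_algorithm` wrapper added in this patch can be invoked. The target density, step size, mass matrix, and step count below are illustrative placeholders rather than values from the patch, and the HMC constructor is called positionally as in the quickstart:

```python
import jax
import jax.numpy as jnp

import blackjax
from blackjax.util import run_inference_algorithm

# Toy standard-normal target; any log-density function works here.
logdensity = lambda x: -0.5 * jnp.sum(x**2)

# Placeholder step size, inverse mass matrix, and number of integration steps.
hmc = blackjax.hmc(logdensity, 1e-3, jnp.ones(2), 60)

rng_key = jax.random.key(0)

# The wrapper accepts either a raw position or an already-initialized state:
# it first tries `inference_algorithm.init(...)` and falls back to using the
# argument as-is if that raises a TypeError.
final_state, state_history, info_history = run_inference_algorithm(
    rng_key, jnp.zeros(2), hmc, 1_000
)
```

The three return values mirror the docstring above: the last state, the stacked history of states, and the stacked history of per-step info.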
-```{code-cell} +```{code-cell} ipython3 initial_position = {"loc": 1.0, "log_scale": 1.0} initial_state = hmc.init(initial_position) initial_state @@ -82,37 +83,20 @@ initial_state The HMC kernel is easy to obtain: -```{code-cell} -hmc_kernel = jax.jit(hmc.step) -``` - -BlackJAX does not provide a default inference loop, but it easy to implement with JAX's `lax.scan`: - -```{code-cell} -def inference_loop(rng_key, kernel, initial_state, num_samples): - @jax.jit - def one_step(state, rng_key): - state, _ = kernel(rng_key, state) - return state, state - - keys = jax.random.split(rng_key, num_samples) - _, states = jax.lax.scan(one_step, initial_state, keys) - - return states -``` ++++ ### Inference -```{code-cell} +```{code-cell} ipython3 %%time rng_key, sample_key = jax.random.split(rng_key) -states = inference_loop(sample_key, hmc_kernel, initial_state, 10_000) +final_state, state_history, info_history = run_inference_algorithm(rng_key, initial_state, hmc, 10_000) -mcmc_samples = states.position +mcmc_samples = state_history.position mcmc_samples["scale"] = jnp.exp(mcmc_samples["log_scale"]).block_until_ready() ``` -```{code-cell} +```{code-cell} ipython3 :tags: [hide-input] fig, (ax, ax1) = plt.subplots(ncols=2, figsize=(15, 6)) @@ -129,29 +113,29 @@ ax1.set_ylabel("scale"); NUTS is a *dynamic* algorithm: the number of integration steps is determined at runtime. We still need to specify a step size and a mass matrix: -```{code-cell} +```{code-cell} ipython3 inv_mass_matrix = np.array([0.5, 0.01]) step_size = 1e-3 nuts = blackjax.nuts(logdensity, step_size, inv_mass_matrix) ``` -```{code-cell} +```{code-cell} ipython3 initial_position = {"loc": 1.0, "log_scale": 1.0} initial_state = nuts.init(initial_position) initial_state ``` -```{code-cell} +```{code-cell} ipython3 %%time rng_key, sample_key = jax.random.split(rng_key) -states = inference_loop(sample_key, nuts.step, initial_state, 4_000) +final_state, state_history, info_history = run_inference_algorithm(rng_key, initial_state, nuts, 4_000) -mcmc_samples = states.position +mcmc_samples = state_history.position mcmc_samples["scale"] = jnp.exp(mcmc_samples["log_scale"]).block_until_ready() ``` -```{code-cell} +```{code-cell} ipython3 :tags: [hide-input] fig, (ax, ax1) = plt.subplots(ncols=2, figsize=(15, 6)) @@ -170,7 +154,7 @@ Specifying the step size and inverse mass matrix is cumbersome. We can use Stan' The adaptation algorithm takes a function that returns a transition kernel given a step size and an inverse mass matrix: -```{code-cell} +```{code-cell} ipython3 %%time warmup = blackjax.window_adaptation(blackjax.nuts, logdensity) @@ -180,17 +164,17 @@ rng_key, warmup_key, sample_key = jax.random.split(rng_key, 3) We can use the obtained parameters to define a new kernel. 
Note that we do not have to use the same kernel that was used for the adaptation: -```{code-cell} +```{code-cell} ipython3 %%time -kernel = blackjax.nuts(logdensity, **parameters).step -states = inference_loop(sample_key, kernel, state, 1_000) +nuts = blackjax.nuts(logdensity, **parameters) +final_state, state_history, info_history = run_inference_algorithm(rng_key, state, nuts, 1_000) -mcmc_samples = states.position +mcmc_samples = state_history.position mcmc_samples["scale"] = jnp.exp(mcmc_samples["log_scale"]).block_until_ready() ``` -```{code-cell} +```{code-cell} ipython3 :tags: [hide-input] fig, (ax, ax1) = plt.subplots(ncols=2, figsize=(15, 6)) From 371540e0c35b1de3e314b7dcd57380ce4e30f537 Mon Sep 17 00:00:00 2001 From: Paul Scemama Date: Fri, 1 Dec 2023 23:25:49 +0000 Subject: [PATCH 2/8] Revert back quickstart.md --- docs/examples/quickstart.md | 64 +++++++++++++++++++++++-------------- 1 file changed, 40 insertions(+), 24 deletions(-) diff --git a/docs/examples/quickstart.md b/docs/examples/quickstart.md index f70241c85..870e5df9a 100644 --- a/docs/examples/quickstart.md +++ b/docs/examples/quickstart.md @@ -17,7 +17,7 @@ BlackJAX is an MCMC sampling library based on [JAX](https://github.com/google/ja In this notebook we provide a simple example based on basic Hamiltonian Monte Carlo and the NUTS algorithm to showcase the architecture and interfaces in the library -```{code-cell} ipython3 +```{code-cell} import matplotlib.pyplot as plt import numpy as np @@ -26,10 +26,9 @@ import jax.numpy as jnp import jax.scipy.stats as stats import blackjax -from blackjax.util import run_inference_algorithm ``` -```{code-cell} ipython3 +```{code-cell} :tags: [remove-output] from datetime import date @@ -40,12 +39,12 @@ rng_key = jax.random.key(int(date.today().strftime("%Y%m%d"))) We'll generate observations from a normal distribution of known `loc` and `scale` to see if we can recover the parameters in sampling. **MCMC algorithms usually assume samples are being drawn from an unconstrained Euclidean space.** Hence why we'll log transform the scale parameter, so that sampling is done on the real line. Samples can be transformed back to their original space in post-processing. Let's take a decent-size dataset with 1,000 points: -```{code-cell} ipython3 +```{code-cell} loc, scale = 10, 20 observed = np.random.normal(loc, scale, size=1_000) ``` -```{code-cell} ipython3 +```{code-cell} def logdensity_fn(loc, log_scale, observed=observed): """Univariate Normal""" scale = jnp.exp(log_scale) @@ -60,7 +59,7 @@ logdensity = lambda x: logdensity_fn(**x) ### Sampler Parameters -```{code-cell} ipython3 +```{code-cell} inv_mass_matrix = np.array([0.5, 0.01]) num_integration_steps = 60 step_size = 1e-3 @@ -72,7 +71,7 @@ hmc = blackjax.hmc(logdensity, step_size, inv_mass_matrix, num_integration_steps The initial state of the HMC algorithm requires not only an initial position, but also the potential energy and gradient of the potential energy at this position (for example, in the context of Bayesian modeling, the output of the log posterior function evaluated at the initial position). BlackJAX provides a `new_state` function to initialize the state from an initial position. 
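As a quick illustration of what that state contains, the sketch below initializes an HMC state and reads back its fields. The log-density and tuning values are placeholders, and the field names (`position`, `logdensity`, `logdensity_grad`) are assumptions based on recent BlackJAX releases, not something defined by this patch:

```python
import jax.numpy as jnp

import blackjax

# Placeholder log-density standing in for the notebook's posterior.
logdensity = lambda x: -0.5 * jnp.sum((x["loc"] - 1.0) ** 2)

hmc = blackjax.hmc(logdensity, 1e-3, jnp.array([0.5]), 60)

initial_state = hmc.init({"loc": 1.0})

# The state caches the log-density and its gradient at the current position,
# so the first kernel step does not have to recompute them.
print(initial_state.position)         # the initial position pytree
print(initial_state.logdensity)       # log-density evaluated at that position
print(initial_state.logdensity_grad)  # its gradient, same pytree structure
```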
-```{code-cell} ipython3 +```{code-cell} initial_position = {"loc": 1.0, "log_scale": 1.0} initial_state = hmc.init(initial_position) initial_state @@ -83,20 +82,37 @@ initial_state The HMC kernel is easy to obtain: -+++ +```{code-cell} +hmc_kernel = jax.jit(hmc.step) +``` + +BlackJAX does not provide a default inference loop, but it easy to implement with JAX's `lax.scan`: + +```{code-cell} +def inference_loop(rng_key, kernel, initial_state, num_samples): + @jax.jit + def one_step(state, rng_key): + state, _ = kernel(rng_key, state) + return state, state + + keys = jax.random.split(rng_key, num_samples) + _, states = jax.lax.scan(one_step, initial_state, keys) + + return states +``` ### Inference -```{code-cell} ipython3 +```{code-cell} %%time rng_key, sample_key = jax.random.split(rng_key) -final_state, state_history, info_history = run_inference_algorithm(rng_key, initial_state, hmc, 10_000) +states = inference_loop(sample_key, hmc_kernel, initial_state, 10_000) -mcmc_samples = state_history.position +mcmc_samples = states.position mcmc_samples["scale"] = jnp.exp(mcmc_samples["log_scale"]).block_until_ready() ``` -```{code-cell} ipython3 +```{code-cell} :tags: [hide-input] fig, (ax, ax1) = plt.subplots(ncols=2, figsize=(15, 6)) @@ -113,29 +129,29 @@ ax1.set_ylabel("scale"); NUTS is a *dynamic* algorithm: the number of integration steps is determined at runtime. We still need to specify a step size and a mass matrix: -```{code-cell} ipython3 +```{code-cell} inv_mass_matrix = np.array([0.5, 0.01]) step_size = 1e-3 nuts = blackjax.nuts(logdensity, step_size, inv_mass_matrix) ``` -```{code-cell} ipython3 +```{code-cell} initial_position = {"loc": 1.0, "log_scale": 1.0} initial_state = nuts.init(initial_position) initial_state ``` -```{code-cell} ipython3 +```{code-cell} %%time rng_key, sample_key = jax.random.split(rng_key) -final_state, state_history, info_history = run_inference_algorithm(rng_key, initial_state, nuts, 4_000) +states = inference_loop(sample_key, nuts.step, initial_state, 4_000) -mcmc_samples = state_history.position +mcmc_samples = states.position mcmc_samples["scale"] = jnp.exp(mcmc_samples["log_scale"]).block_until_ready() ``` -```{code-cell} ipython3 +```{code-cell} :tags: [hide-input] fig, (ax, ax1) = plt.subplots(ncols=2, figsize=(15, 6)) @@ -154,7 +170,7 @@ Specifying the step size and inverse mass matrix is cumbersome. We can use Stan' The adaptation algorithm takes a function that returns a transition kernel given a step size and an inverse mass matrix: -```{code-cell} ipython3 +```{code-cell} %%time warmup = blackjax.window_adaptation(blackjax.nuts, logdensity) @@ -164,17 +180,17 @@ rng_key, warmup_key, sample_key = jax.random.split(rng_key, 3) We can use the obtained parameters to define a new kernel. 
Note that we do not have to use the same kernel that was used for the adaptation: -```{code-cell} ipython3 +```{code-cell} %%time -nuts = blackjax.nuts(logdensity, **parameters) -final_state, state_history, info_history = run_inference_algorithm(rng_key, state, nuts, 1_000) +kernel = blackjax.nuts(logdensity, **parameters).step +states = inference_loop(sample_key, kernel, state, 1_000) -mcmc_samples = state_history.position +mcmc_samples = states.position mcmc_samples["scale"] = jnp.exp(mcmc_samples["log_scale"]).block_until_ready() ``` -```{code-cell} ipython3 +```{code-cell} :tags: [hide-input] fig, (ax, ax1) = plt.subplots(ncols=2, figsize=(15, 6)) From bc54b9a62332597b5b1ecf215272812e3e1fd9b8 Mon Sep 17 00:00:00 2001 From: Paul Scemama Date: Sun, 3 Dec 2023 20:44:46 +0000 Subject: [PATCH 3/8] Add run_inference wrapper to tests/mcmc/sampling; Get rid of arg types for run_inference wrapper --- blackjax/util.py | 15 +++---- tests/mcmc/test_sampling.py | 78 ++++++++++++++++++++----------------- 2 files changed, 51 insertions(+), 42 deletions(-) diff --git a/blackjax/util.py b/blackjax/util.py index af334a1a3..e5d510c32 100644 --- a/blackjax/util.py +++ b/blackjax/util.py @@ -1,6 +1,6 @@ """Utility functions for BlackJax.""" from functools import partial -from typing import Tuple, Union +from typing import Union import jax.numpy as jnp from jax import jit, lax @@ -140,11 +140,11 @@ def index_pytree(input_pytree: ArrayLikeTree) -> ArrayTree: def run_inference_algorithm( - rng_key: PRNGKey, - initial_state_or_position: ArrayLikeTree, - inference_algorithm: Union[SamplingAlgorithm, VIAlgorithm], - num_steps: int, -) -> Tuple[State, State, Info]: + rng_key, + initial_state_or_position, + inference_algorithm, + num_steps, +) -> tuple[State, State, Info]: """Wrapper to run an inference algorithm. Parameters @@ -169,8 +169,9 @@ def run_inference_algorithm( try: initial_state = inference_algorithm.init(initial_state_or_position) except TypeError: - # initial_state is already in the right format. + # We assume initial_state is already in the right format. 
initial_state = initial_state_or_position + initial_state = initial_state_or_position keys = split(rng_key, num_steps) diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py index db9ef9944..a09355aba 100644 --- a/tests/mcmc/test_sampling.py +++ b/tests/mcmc/test_sampling.py @@ -13,17 +13,7 @@ import blackjax import blackjax.diagnostics as diagnostics import blackjax.mcmc.random_walk - - -def inference_loop(kernel, num_samples, rng_key, initial_state): - def one_step(state, rng_key): - state, _ = kernel(rng_key, state) - return state, state - - keys = jax.random.split(rng_key, num_samples) - _, states = jax.lax.scan(one_step, initial_state, keys) - - return states +import blackjax.util as util def orbit_samples(orbits, weights, rng_key): @@ -154,8 +144,8 @@ def test_window_adaptation(self, case, is_mass_matrix_diagonal): ) algorithm = case["algorithm"](logposterior_fn, **parameters) - states = inference_loop( - algorithm.step, case["num_sampling_steps"], inference_key, state + _, states, _ = util.run_inference_algorithm( + inference_key, state, algorithm, case["num_sampling_steps"] ) coefs_samples = states.position["coefs"] @@ -177,7 +167,7 @@ def test_mala(self): mala = blackjax.mala(logposterior_fn, 1e-5) state = mala.init({"coefs": 1.0, "log_scale": 1.0}) - states = inference_loop(mala.step, 10_000, inference_key, state) + _, states, _ = util.run_inference_algorithm(inference_key, state, mala, 10_000) coefs_samples = states.position["coefs"][3000:] scale_samples = np.exp(states.position["log_scale"][3000:]) @@ -240,9 +230,11 @@ def test_pathfinder_adaptation( initial_position, num_warmup_steps, ) - kernel = algorithm(logposterior_fn, **parameters).step + kernel = algorithm(logposterior_fn, **parameters) - states = inference_loop(kernel, num_sampling_steps, inference_key, state) + _, states, _ = util.run_inference_algorithm( + inference_key, state, kernel, num_sampling_steps + ) coefs_samples = states.position["coefs"] scale_samples = np.exp(states.position["log_scale"]) @@ -277,12 +269,12 @@ def test_meads(self): initial_positions, num_steps=1000, ) - kernel = blackjax.ghmc(logposterior_fn, **parameters).step + kernel = blackjax.ghmc(logposterior_fn, **parameters) chain_keys = jax.random.split(inference_key, num_chains) - states = jax.vmap(lambda key, state: inference_loop(kernel, 100, key, state))( - chain_keys, last_states - ) + _, states, _ = jax.vmap( + lambda key, state: util.run_inference_algorithm(key, state, kernel, 100) + )(chain_keys, last_states) coefs_samples = states.position["coefs"] scale_samples = np.exp(states.position["log_scale"]) @@ -319,12 +311,12 @@ def test_chees(self, jitter_generator): optim=optax.adam(learning_rate=0.1), num_steps=1000, ) - kernel = blackjax.dynamic_hmc(logposterior_fn, **parameters).step + kernel = blackjax.dynamic_hmc(logposterior_fn, **parameters) chain_keys = jax.random.split(inference_key, num_chains) - states = jax.vmap(lambda key, state: inference_loop(kernel, 100, key, state))( - chain_keys, last_states - ) + _, states, _ = jax.vmap( + lambda key, state: util.run_inference_algorithm(key, state, kernel, 100) + )(chain_keys, last_states) coefs_samples = states.position["coefs"] scale_samples = np.exp(states.position["log_scale"]) @@ -519,13 +511,19 @@ def setUp(self): def test_latent_gaussian(self): from blackjax import mgrad_gaussian - init, step = mgrad_gaussian(lambda x: -0.5 * jnp.sum((x - 1.0) ** 2), self.C) + algorithm = mgrad_gaussian(lambda x: -0.5 * jnp.sum((x - 1.0) ** 2), self.C) + algorithm = 
algorithm._replace( + step=functools.partial(algorithm.step, delta=self.delta) + ) - kernel = lambda key, x: step(key, x, self.delta) - initial_state = init(jnp.zeros((1,))) + initial_state = algorithm.init(jnp.zeros((1,))) - states = self.variant( - functools.partial(inference_loop, kernel, self.sampling_steps), + _, states, _ = self.variant( + functools.partial( + util.run_inference_algorithm, + inference_algorithm=algorithm, + num_steps=self.sampling_steps, + ), )(self.key, initial_state) np.testing.assert_allclose( @@ -650,7 +648,7 @@ def test_univariate_normal( algo = algorithm(self.normal_logprob, **parameters) rng_key = self.key if algorithm == blackjax.elliptical_slice: - algo = algorithm(lambda _: 1.0, **parameters) + algo = algorithm(lambda x: jnp.ones_like(x), **parameters) if algorithm == blackjax.ghmc: rng_key, initial_state_key = jax.random.split(rng_key) initial_state = algo.init(initial_position, initial_state_key) @@ -658,9 +656,13 @@ def test_univariate_normal( initial_state = algo.init(initial_position) inference_key, orbit_key = jax.random.split(rng_key) - kernel = algo.step - states = self.variant( - functools.partial(inference_loop, kernel, num_sampling_steps) + kernel = algo + _, states, _ = self.variant( + functools.partial( + util.run_inference_algorithm, + inference_algorithm=kernel, + num_steps=num_sampling_steps, + ) )(inference_key, initial_state) if algorithm == blackjax.orbital_hmc: @@ -777,9 +779,15 @@ def test_mcse(self, algorithm, parameters, is_mass_matrix_diagonal): multi_chain_sample_key = jax.random.split(sample_key, num_chains) inference_loop_multiple_chains = jax.vmap( - functools.partial(inference_loop, kernel.step, 2_000) + functools.partial( + util.run_inference_algorithm, + inference_algorithm=kernel, + num_steps=2_000, + ) + ) + _, states, _ = inference_loop_multiple_chains( + multi_chain_sample_key, initial_states ) - states = inference_loop_multiple_chains(multi_chain_sample_key, initial_states) posterior_samples = states.position[:, -1000:] posterior_delta = posterior_samples - true_loc From 3dd6d3c3cc1557224f8033a4b154d37ab6cbf2ac Mon Sep 17 00:00:00 2001 From: Paul Scemama Date: Sun, 3 Dec 2023 20:46:19 +0000 Subject: [PATCH 4/8] Get rid of unused imports --- blackjax/util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/blackjax/util.py b/blackjax/util.py index e5d510c32..f667d147a 100644 --- a/blackjax/util.py +++ b/blackjax/util.py @@ -8,7 +8,7 @@ from jax.random import normal, split from jax.tree_util import tree_leaves -from blackjax.base import Info, SamplingAlgorithm, State, VIAlgorithm +from blackjax.base import Info, State from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey From dd7dabd56e54484a725bbb2489c3b60a6d5ae8d8 Mon Sep 17 00:00:00 2001 From: Paul Scemama Date: Sun, 3 Dec 2023 20:55:18 +0000 Subject: [PATCH 5/8] Add run_inference wrapper to tests/benchmark --- tests/test_benchmarks.py | 17 +++-------------- 1 file changed, 3 insertions(+), 14 deletions(-) diff --git a/tests/test_benchmarks.py b/tests/test_benchmarks.py index d64efa4cd..c5ef60c22 100644 --- a/tests/test_benchmarks.py +++ b/tests/test_benchmarks.py @@ -13,6 +13,7 @@ import pytest import blackjax +import blackjax.util as util def regression_logprob(log_scale, coefs, preds, x): @@ -25,18 +26,6 @@ def regression_logprob(log_scale, coefs, preds, x): return sum(x.sum() for x in [scale_prior, coefs_prior, logpdf]) -def inference_loop(kernel, num_samples, rng_key, initial_state): - @jax.jit - def one_step(state, rng_key): - state, 
_ = kernel(rng_key, state) - return state, state - - keys = jax.random.split(rng_key, num_samples) - _, states = jax.lax.scan(one_step, initial_state, keys) - - return states - - def run_regression(algorithm, **parameters): key = jax.random.key(0) rng_key, init_key0, init_key1 = jax.random.split(key, 3) @@ -57,9 +46,9 @@ def run_regression(algorithm, **parameters): (state, parameters), _ = warmup.run( warmup_key, {"log_scale": 0.0, "coefs": 2.0}, 1000 ) - kernel = algorithm(logdensity_fn, **parameters).step + kernel = algorithm(logdensity_fn, **parameters) - states = inference_loop(kernel, 10_000, inference_key, state) + _, states, _ = util.run_inference_algorithm(inference_key, state, kernel, 10_000) return states From 8df10e3235471b21b05817a99323c26852778484 Mon Sep 17 00:00:00 2001 From: Paul Scemama Date: Wed, 6 Dec 2023 00:17:50 +0000 Subject: [PATCH 6/8] Change import style --- tests/mcmc/test_sampling.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py index a09355aba..55edd5795 100644 --- a/tests/mcmc/test_sampling.py +++ b/tests/mcmc/test_sampling.py @@ -13,7 +13,7 @@ import blackjax import blackjax.diagnostics as diagnostics import blackjax.mcmc.random_walk -import blackjax.util as util +from blackjax.util import run_inference_algorithm def orbit_samples(orbits, weights, rng_key): @@ -144,7 +144,7 @@ def test_window_adaptation(self, case, is_mass_matrix_diagonal): ) algorithm = case["algorithm"](logposterior_fn, **parameters) - _, states, _ = util.run_inference_algorithm( + _, states, _ = run_inference_algorithm( inference_key, state, algorithm, case["num_sampling_steps"] ) @@ -167,7 +167,7 @@ def test_mala(self): mala = blackjax.mala(logposterior_fn, 1e-5) state = mala.init({"coefs": 1.0, "log_scale": 1.0}) - _, states, _ = util.run_inference_algorithm(inference_key, state, mala, 10_000) + _, states, _ = run_inference_algorithm(inference_key, state, mala, 10_000) coefs_samples = states.position["coefs"][3000:] scale_samples = np.exp(states.position["log_scale"][3000:]) @@ -232,7 +232,7 @@ def test_pathfinder_adaptation( ) kernel = algorithm(logposterior_fn, **parameters) - _, states, _ = util.run_inference_algorithm( + _, states, _ = run_inference_algorithm( inference_key, state, kernel, num_sampling_steps ) @@ -273,7 +273,7 @@ def test_meads(self): chain_keys = jax.random.split(inference_key, num_chains) _, states, _ = jax.vmap( - lambda key, state: util.run_inference_algorithm(key, state, kernel, 100) + lambda key, state: run_inference_algorithm(key, state, kernel, 100) )(chain_keys, last_states) coefs_samples = states.position["coefs"] @@ -315,7 +315,7 @@ def test_chees(self, jitter_generator): chain_keys = jax.random.split(inference_key, num_chains) _, states, _ = jax.vmap( - lambda key, state: util.run_inference_algorithm(key, state, kernel, 100) + lambda key, state: run_inference_algorithm(key, state, kernel, 100) )(chain_keys, last_states) coefs_samples = states.position["coefs"] @@ -337,7 +337,9 @@ def test_barker(self): barker = blackjax.barker_proposal(logposterior_fn, 1e-1) state = barker.init({"coefs": 1.0, "log_scale": 1.0}) - states = inference_loop(barker.step, 10_000, inference_key, state) + + _, states, _ = run_inference_algorithm(inference_key, state, barker, 10_000) + # states = inference_loop(barker.step, 10_000, inference_key, state) coefs_samples = states.position["coefs"][3000:] scale_samples = np.exp(states.position["log_scale"][3000:]) @@ -520,7 +522,7 @@ 
def test_latent_gaussian(self): _, states, _ = self.variant( functools.partial( - util.run_inference_algorithm, + run_inference_algorithm, inference_algorithm=algorithm, num_steps=self.sampling_steps, ), @@ -659,7 +661,7 @@ def test_univariate_normal( kernel = algo _, states, _ = self.variant( functools.partial( - util.run_inference_algorithm, + run_inference_algorithm, inference_algorithm=kernel, num_steps=num_sampling_steps, ) @@ -780,7 +782,7 @@ def test_mcse(self, algorithm, parameters, is_mass_matrix_diagonal): inference_loop_multiple_chains = jax.vmap( functools.partial( - util.run_inference_algorithm, + run_inference_algorithm, inference_algorithm=kernel, num_steps=2_000, ) From bbbb435bce0d1ffc07def07536314c61cdffeba8 Mon Sep 17 00:00:00 2001 From: Paul Scemama Date: Wed, 6 Dec 2023 00:56:55 +0000 Subject: [PATCH 7/8] Import style for benchmarks test; add wrapper to adaptation test --- tests/adaptation/test_adaptation.py | 14 ++++++-------- tests/test_benchmarks.py | 4 ++-- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/tests/adaptation/test_adaptation.py b/tests/adaptation/test_adaptation.py index 1b95b0115..d189d2e6b 100644 --- a/tests/adaptation/test_adaptation.py +++ b/tests/adaptation/test_adaptation.py @@ -6,6 +6,7 @@ import blackjax from blackjax.adaptation import window_adaptation +from blackjax.util import run_inference_algorithm @pytest.mark.parametrize( @@ -57,15 +58,12 @@ def test_chees_adaptation(): optim=optax.adamw(learning_rate=0.5), num_steps=num_burnin_steps, ) - kernel = blackjax.dynamic_hmc(logprob_fn, **parameters).step + kernel = blackjax.dynamic_hmc(logprob_fn, **parameters) - def one_step(states, rng_key): - keys = jax.random.split(rng_key, num_chains) - states, infos = jax.vmap(kernel)(keys, states) - return states, infos - - keys = jax.random.split(inference_key, num_results) - _, infos = jax.lax.scan(one_step, last_states, keys) + chain_keys = jax.random.split(inference_key, num_chains) + _, _, infos = jax.vmap( + lambda key, state: run_inference_algorithm(key, state, kernel, num_results) + )(chain_keys, last_states) harmonic_mean = 1.0 / jnp.mean(1.0 / infos.acceptance_rate) np.testing.assert_allclose(harmonic_mean, 0.75, rtol=1e-1) diff --git a/tests/test_benchmarks.py b/tests/test_benchmarks.py index c5ef60c22..92a173acb 100644 --- a/tests/test_benchmarks.py +++ b/tests/test_benchmarks.py @@ -13,7 +13,7 @@ import pytest import blackjax -import blackjax.util as util +from blackjax.util import run_inference_algorithm def regression_logprob(log_scale, coefs, preds, x): @@ -48,7 +48,7 @@ def run_regression(algorithm, **parameters): ) kernel = algorithm(logdensity_fn, **parameters) - _, states, _ = util.run_inference_algorithm(inference_key, state, kernel, 10_000) + _, states, _ = run_inference_algorithm(inference_key, state, kernel, 10_000) return states From 11ac8ef7197cb177ca1db523d142fdec56578407 Mon Sep 17 00:00:00 2001 From: Paul Scemama Date: Wed, 6 Dec 2023 01:32:48 +0000 Subject: [PATCH 8/8] Replace 'kernel' variable name with 'inference algorithm' when using wrapper run_inference_algorithm --- tests/adaptation/test_adaptation.py | 4 +- tests/mcmc/test_sampling.py | 59 +++++++++++++++++------------ tests/test_benchmarks.py | 6 ++- 3 files changed, 41 insertions(+), 28 deletions(-) diff --git a/tests/adaptation/test_adaptation.py b/tests/adaptation/test_adaptation.py index d189d2e6b..93bf418d2 100644 --- a/tests/adaptation/test_adaptation.py +++ b/tests/adaptation/test_adaptation.py @@ -58,11 +58,11 @@ def test_chees_adaptation(): 
optim=optax.adamw(learning_rate=0.5), num_steps=num_burnin_steps, ) - kernel = blackjax.dynamic_hmc(logprob_fn, **parameters) + algorithm = blackjax.dynamic_hmc(logprob_fn, **parameters) chain_keys = jax.random.split(inference_key, num_chains) _, _, infos = jax.vmap( - lambda key, state: run_inference_algorithm(key, state, kernel, num_results) + lambda key, state: run_inference_algorithm(key, state, algorithm, num_results) )(chain_keys, last_states) harmonic_mean = 1.0 / jnp.mean(1.0 / infos.acceptance_rate) diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py index 55edd5795..ec47f1180 100644 --- a/tests/mcmc/test_sampling.py +++ b/tests/mcmc/test_sampling.py @@ -142,10 +142,10 @@ def test_window_adaptation(self, case, is_mass_matrix_diagonal): case["initial_position"], case["num_warmup_steps"], ) - algorithm = case["algorithm"](logposterior_fn, **parameters) + inference_algorithm = case["algorithm"](logposterior_fn, **parameters) _, states, _ = run_inference_algorithm( - inference_key, state, algorithm, case["num_sampling_steps"] + inference_key, state, inference_algorithm, case["num_sampling_steps"] ) coefs_samples = states.position["coefs"] @@ -230,10 +230,10 @@ def test_pathfinder_adaptation( initial_position, num_warmup_steps, ) - kernel = algorithm(logposterior_fn, **parameters) + inference_algorithm = algorithm(logposterior_fn, **parameters) _, states, _ = run_inference_algorithm( - inference_key, state, kernel, num_sampling_steps + inference_key, state, inference_algorithm, num_sampling_steps ) coefs_samples = states.position["coefs"] @@ -269,11 +269,13 @@ def test_meads(self): initial_positions, num_steps=1000, ) - kernel = blackjax.ghmc(logposterior_fn, **parameters) + inference_algorithm = blackjax.ghmc(logposterior_fn, **parameters) chain_keys = jax.random.split(inference_key, num_chains) _, states, _ = jax.vmap( - lambda key, state: run_inference_algorithm(key, state, kernel, 100) + lambda key, state: run_inference_algorithm( + key, state, inference_algorithm, 100 + ) )(chain_keys, last_states) coefs_samples = states.position["coefs"] @@ -311,11 +313,13 @@ def test_chees(self, jitter_generator): optim=optax.adam(learning_rate=0.1), num_steps=1000, ) - kernel = blackjax.dynamic_hmc(logposterior_fn, **parameters) + inference_algorithm = blackjax.dynamic_hmc(logposterior_fn, **parameters) chain_keys = jax.random.split(inference_key, num_chains) _, states, _ = jax.vmap( - lambda key, state: run_inference_algorithm(key, state, kernel, 100) + lambda key, state: run_inference_algorithm( + key, state, inference_algorithm, 100 + ) )(chain_keys, last_states) coefs_samples = states.position["coefs"] @@ -339,7 +343,6 @@ def test_barker(self): state = barker.init({"coefs": 1.0, "log_scale": 1.0}) _, states, _ = run_inference_algorithm(inference_key, state, barker, 10_000) - # states = inference_loop(barker.step, 10_000, inference_key, state) coefs_samples = states.position["coefs"][3000:] scale_samples = np.exp(states.position["log_scale"][3000:]) @@ -513,17 +516,22 @@ def setUp(self): def test_latent_gaussian(self): from blackjax import mgrad_gaussian - algorithm = mgrad_gaussian(lambda x: -0.5 * jnp.sum((x - 1.0) ** 2), self.C) - algorithm = algorithm._replace( - step=functools.partial(algorithm.step, delta=self.delta) + inference_algorithm = mgrad_gaussian( + lambda x: -0.5 * jnp.sum((x - 1.0) ** 2), self.C + ) + inference_algorithm = inference_algorithm._replace( + step=functools.partial( + inference_algorithm.step, + delta=self.delta, + ) ) - initial_state = 
algorithm.init(jnp.zeros((1,))) + initial_state = inference_algorithm.init(jnp.zeros((1,))) _, states, _ = self.variant( functools.partial( run_inference_algorithm, - inference_algorithm=algorithm, + inference_algorithm=inference_algorithm, num_steps=self.sampling_steps, ), )(self.key, initial_state) @@ -647,22 +655,23 @@ def test_univariate_normal( if algorithm == blackjax.rmh: parameters["proposal_generator"] = rmh_proposal_distribution - algo = algorithm(self.normal_logprob, **parameters) + inference_algorithm = algorithm(self.normal_logprob, **parameters) rng_key = self.key if algorithm == blackjax.elliptical_slice: - algo = algorithm(lambda x: jnp.ones_like(x), **parameters) + inference_algorithm = algorithm(lambda x: jnp.ones_like(x), **parameters) if algorithm == blackjax.ghmc: rng_key, initial_state_key = jax.random.split(rng_key) - initial_state = algo.init(initial_position, initial_state_key) + initial_state = inference_algorithm.init( + initial_position, initial_state_key + ) else: - initial_state = algo.init(initial_position) + initial_state = inference_algorithm.init(initial_position) inference_key, orbit_key = jax.random.split(rng_key) - kernel = algo _, states, _ = self.variant( functools.partial( run_inference_algorithm, - inference_algorithm=kernel, + inference_algorithm=inference_algorithm, num_steps=num_sampling_steps, ) )(inference_key, initial_state) @@ -767,23 +776,25 @@ def test_mcse(self, algorithm, parameters, is_mass_matrix_diagonal): inverse_mass_matrix = true_scale**2 else: inverse_mass_matrix = true_cov - kernel = algorithm( + inference_algorithm = algorithm( logdensity_fn, inverse_mass_matrix=inverse_mass_matrix, **parameters, ) else: - kernel = algorithm(logdensity_fn, **parameters) + inference_algorithm = algorithm(logdensity_fn, **parameters) num_chains = 10 initial_positions = jax.random.normal(pos_init_key, [num_chains, 2]) - initial_states = jax.vmap(kernel.init, in_axes=(0,))(initial_positions) + initial_states = jax.vmap(inference_algorithm.init, in_axes=(0,))( + initial_positions + ) multi_chain_sample_key = jax.random.split(sample_key, num_chains) inference_loop_multiple_chains = jax.vmap( functools.partial( run_inference_algorithm, - inference_algorithm=kernel, + inference_algorithm=inference_algorithm, num_steps=2_000, ) ) diff --git a/tests/test_benchmarks.py b/tests/test_benchmarks.py index 92a173acb..d8f09cea0 100644 --- a/tests/test_benchmarks.py +++ b/tests/test_benchmarks.py @@ -46,9 +46,11 @@ def run_regression(algorithm, **parameters): (state, parameters), _ = warmup.run( warmup_key, {"log_scale": 0.0, "coefs": 2.0}, 1000 ) - kernel = algorithm(logdensity_fn, **parameters) + inference_algorithm = algorithm(logdensity_fn, **parameters) - _, states, _ = run_inference_algorithm(inference_key, state, kernel, 10_000) + _, states, _ = run_inference_algorithm( + inference_key, state, inference_algorithm, 10_000 + ) return states
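Taken together, the series replaces the per-file `inference_loop` helpers with the shared `run_inference_algorithm` wrapper, and multi-chain runs are expressed by `jax.vmap`-ping the wrapper over per-chain keys and states, as the updated MEADS and ChEES tests do. Below is a minimal sketch of that pattern; the MALA target and settings are illustrative only, and just the vmap structure mirrors the tests:

```python
import jax
import jax.numpy as jnp

import blackjax
from blackjax.util import run_inference_algorithm

# Illustrative target and step size; only the vmap pattern mirrors the tests.
logdensity = lambda x: -0.5 * jnp.sum(x**2)
inference_algorithm = blackjax.mala(logdensity, 1e-2)

num_chains, num_steps = 4, 500
init_key, run_key = jax.random.split(jax.random.key(0))

# One initial state and one PRNG key per chain.
initial_positions = jax.random.normal(init_key, (num_chains, 2))
initial_states = jax.vmap(inference_algorithm.init)(initial_positions)
chain_keys = jax.random.split(run_key, num_chains)

# One wrapper call per chain, batched with vmap; the state and info histories
# come back with a leading chain dimension.
_, states, infos = jax.vmap(
    lambda key, state: run_inference_algorithm(
        key, state, inference_algorithm, num_steps
    )
)(chain_keys, initial_states)

print(states.position.shape)  # (num_chains, num_steps, 2)
```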