Simplify run_inference_algorithm #714

Merged
merged 21 commits into from Aug 12, 2024
8 changes: 4 additions & 4 deletions blackjax/adaptation/mclmc_adaptation.py
@@ -20,7 +20,7 @@
from jax.flatten_util import ravel_pytree

from blackjax.diagnostics import effective_sample_size
from blackjax.util import pytree_size, streaming_average_update
from blackjax.util import incremental_value_update, pytree_size


class MCLMCAdaptationState(NamedTuple):
@@ -199,9 +199,9 @@ def step(iteration_state, weight_and_key):

x = ravel_pytree(state.position)[0]
# update the running average of x, x^2
streaming_avg = streaming_average_update(
current_value=jnp.array([x, jnp.square(x)]),
previous_weight_and_average=streaming_avg,
streaming_avg = incremental_value_update(
expectation=jnp.array([x, jnp.square(x)]),
incremental_val=streaming_avg,
weight=(1 - mask) * success * params.step_size,
zero_prevention=mask,
)
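For orientation, here is a minimal sketch of the renamed helper's calling convention, mirroring the adaptation code above. It is not part of the diff; the position vector and the weight value are made up.

```python
import jax.numpy as jnp

from blackjax.util import incremental_value_update

# Hypothetical flattened position; in the adaptation loop this comes from
# ravel_pytree(state.position)[0].
x = jnp.array([0.3, -1.1])

# (total weight so far, running average of [x, x^2])
streaming_avg = (0.0, jnp.zeros((2, 2)))

streaming_avg = incremental_value_update(
    expectation=jnp.array([x, jnp.square(x)]),  # was: current_value
    incremental_val=streaming_avg,              # was: previous_weight_and_average
    weight=0.1,  # made-up stand-in for (1 - mask) * success * params.step_size
)
```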
163 changes: 102 additions & 61 deletions blackjax/util.py
@@ -3,7 +3,6 @@
from functools import partial
from typing import Callable, Union

import jax
import jax.numpy as jnp
from jax import jit, lax
from jax.flatten_util import ravel_pytree
@@ -149,9 +148,7 @@
initial_state: ArrayLikeTree = None,
initial_position: ArrayLikeTree = None,
progress_bar: bool = False,
transform: Callable = lambda x: x,
return_state_history=True,
expectation: Callable = lambda x: x,
transform: Callable = lambda state, info: (state, info),
) -> tuple:
"""Wrapper to run an inference algorithm.

@@ -166,104 +163,148 @@
initial_state
The initial state of the inference algorithm.
initial_position
The initial position of the inference algorithm. This is used when the initial
state is not provided.
The initial position of the inference algorithm. This is used when the initial state is not provided.
inference_algorithm
One of blackjax's sampling algorithms or variational inference algorithms.
num_steps
Number of MCMC steps.
progress_bar
Whether to display a progress bar.
transform
A transformation of the trace of states to be returned. This is useful for
A transformation of the trace of states (and info) to be returned. This is useful for
computing deterministic variables, or returning a subset of the states.
By default, the states are returned as is.
expectation
A function that computes the expectation of the state. This is done
incrementally, so it doesn't require storing all the states.
return_state_history
if False, `run_inference_algorithm` returns only the expectation of the
transformed values, instead of the full set of samples. This is useful
when memory is a bottleneck.

Returns
-------
If return_state_history is True:
1. The final state.
2. The trace of the state.
2. The trace of transform(state, info).
3. The trace of the info of the inference algorithm for diagnostics.
If return_state_history is False:
1. This is the expectation of state over the chain. Otherwise the final state.
2. The final state of the inference algorithm.
"""

if initial_state is None and initial_position is None:
raise ValueError(
"Either `initial_state` or `initial_position` must be provided."
)
raise ValueError("Either initial_state or initial_position must be provided.")

if initial_state is not None and initial_position is not None:
raise ValueError(
"Only one of `initial_state` or `initial_position` must be provided."
"Only one of initial_state or initial_position must be provided."
)

if initial_state is None:
rng_key, init_key = split(rng_key, 2)
rng_key, init_key = split(rng_key, 2)
if initial_position is not None:
initial_state = inference_algorithm.init(initial_position, init_key)

keys = split(rng_key, num_steps)

def one_step(average_and_state, xs, return_state):
def one_step(state, xs):
_, rng_key = xs
average, state = average_and_state
state, info = inference_algorithm.step(rng_key, state)
average = streaming_average_update(expectation(transform(state)), average)
if return_state:
return (average, state), (transform(state), info)
else:
return (average, state), None
return state, transform(state, info)

one_step = jax.jit(partial(one_step, return_state=return_state_history))

xs = (jnp.arange(num_steps), keys)
scan_fn = gen_scan_fn(num_steps, progress_bar)
((_, average), final_state), history = scan_fn(
one_step,
((0, expectation(transform(initial_state))), initial_state),
xs,
)

if not return_state_history:
return average, transform(final_state)
else:
state_history, info_history = history
return transform(final_state), state_history, info_history
xs = jnp.arange(num_steps), keys
final_state, history = scan_fn(one_step, initial_state, xs)

return final_state, history
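With `return_state_history` and `expectation` gone, the wrapper now always returns the pair `(final_state, history)`, where `history` is the scanned trace of `transform(state, info)`. A minimal sketch of the new calling convention, using a hypothetical MALA setup similar to the tests further down:

```python
import jax
import jax.numpy as jnp

import blackjax
from blackjax.util import run_inference_algorithm

# Hypothetical target and sampler; any blackjax SamplingAlgorithm works here.
logdensity_fn = lambda x: -0.5 * jnp.sum(x**2)
algorithm = blackjax.mala(logdensity_fn, 1e-2)
initial_state = algorithm.init(jnp.zeros(2))

final_state, (position_trace, info_trace) = run_inference_algorithm(
    rng_key=jax.random.PRNGKey(0),
    initial_state=initial_state,
    inference_algorithm=algorithm,
    num_steps=1_000,
    # transform now receives (state, info); return whatever should be traced.
    transform=lambda state, info: (state.position, info),
)
```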


def streaming_average_update(
current_value, previous_weight_and_average, weight=1.0, zero_prevention=0.0
def store_only_expectation_values(
sampling_algorithm,
state_transform=lambda x: x,
incremental_value_transform=lambda x: x,
burn_in=0,
):
"""Takes a sampling algorithm and constructs from it a new sampling algorithm object. The new sampling algorithm has the same
kernel but only stores the streaming expectation values of some observables, not the full states; to save memory.

It saves incremental_value_transform(E[state_transform(x)]) at each step i, where expectation is computed with samples up to i-th sample.

Example:

.. code::

init_key, state_key, run_key = jax.random.split(jax.random.PRNGKey(0), 3)
model = StandardNormal(2)
initial_position = model.sample_init(init_key)
initial_state = blackjax.mcmc.mclmc.init(
position=initial_position, logdensity_fn=model.logdensity_fn, rng_key=state_key
)
integrator_type = "mclachlan"
L = 1.0
step_size = 0.1
num_steps = 4

integrator = map_integrator_type_to_integrator['mclmc'][integrator_type]
state_transform = lambda state: state.position

# construct the sampler whose expectations we want to track
sampling_alg = blackjax.mclmc(
    model.logdensity_fn,
    L=L,
    step_size=step_size,
    integrator=integrator,
)

memory_efficient_sampling_alg, transform = store_only_expectation_values(
    sampling_algorithm=sampling_alg,
    state_transform=state_transform)

initial_state = memory_efficient_sampling_alg.init(initial_state)

final_state, trace_at_every_step = run_inference_algorithm(
    rng_key=run_key,
    initial_state=initial_state,
    inference_algorithm=memory_efficient_sampling_alg,
    num_steps=num_steps,
    transform=transform,
    progress_bar=True,
)
"""

def init_fn(state):
averaging_state = (0.0, state_transform(state))
return (state, averaging_state)

def update_fn(rng_key, state_and_incremental_val):
state, averaging_state = state_and_incremental_val
state, info = sampling_algorithm.step(
rng_key, state
) # update the state with the sampling algorithm
averaging_state = incremental_value_update(
state_transform(state),
averaging_state,
weight=(
averaging_state[0] >= burn_in
), # If we want to eliminate some number of steps as a burn-in
zero_prevention=1e-10 * (burn_in > 0),
)
# return the updated state together with the running expectation
return (state, averaging_state), info

def transform(state_and_incremental_val, info):
(state, (_, incremental_value)) = state_and_incremental_val
return incremental_value_transform(incremental_value), info

return SamplingAlgorithm(init_fn, update_fn), transform


def incremental_value_update(
expectation, incremental_val, weight=1.0, zero_prevention=0.0
):
"""Compute the streaming average of a function O(x) using a weight.
Parameters
----------
current_value
the current value of the function that we want to take average of
previous_weight_and_average
tuple of (previous_weight, previous_average) where previous_weight is the
sum of weights and average is the current estimated average
expectation
the value of the expectation at the current timestep
incremental_val
tuple of (total, average) where total is the sum of weights and average is the current average
weight
weight of the current state
zero_prevention
small value to prevent division by zero
Returns
-------
new total weight and streaming average
new streaming average
"""
previous_weight, previous_average = previous_weight_and_average
current_weight = previous_weight + weight
current_average = jax.tree.map(
    lambda x, avg: (previous_weight * avg + weight * x)
    / (current_weight + zero_prevention),
    current_value,
    previous_average,
)
return current_weight, current_average

flat_expectation, unravel_fn = ravel_pytree(expectation)
total, average = incremental_val
flat_average, _ = ravel_pytree(average)
average = (total * flat_average + weight * flat_expectation) / (
    total + weight + zero_prevention
)
total += weight
incremental_val = (total, unravel_fn(average))
return incremental_val
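A quick sanity check, not part of the PR: with unit weights the recurrence average_n = (total_{n-1} * average_{n-1} + x_n) / (total_{n-1} + 1) reproduces the plain arithmetic mean of the stream.

```python
import jax.numpy as jnp

from blackjax.util import incremental_value_update

xs = jnp.array([1.0, 4.0, 2.0, 7.0])
state = (0.0, 0.0)  # (total weight, running average)
for x in xs:
    state = incremental_value_update(x, state, weight=1.0)

total, average = state
assert total == 4.0
assert jnp.allclose(average, xs.mean())  # running average equals the batch mean
```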
2 changes: 1 addition & 1 deletion tests/adaptation/test_adaptation.py
@@ -90,7 +90,7 @@ def test_chees_adaptation(adaptation_filters):
algorithm = blackjax.dynamic_hmc(logprob_fn, **parameters)

chain_keys = jax.random.split(inference_key, num_chains)
_, _, infos = jax.vmap(
_, (_, infos) = jax.vmap(
lambda key, state: run_inference_algorithm(
rng_key=key,
initial_state=state,
31 changes: 19 additions & 12 deletions tests/mcmc/test_sampling.py
@@ -135,12 +135,12 @@ def run_mclmc(
sqrt_diag_cov=blackjax_mclmc_sampler_params.sqrt_diag_cov,
)

_, samples, _ = run_inference_algorithm(
_, samples = run_inference_algorithm(
rng_key=run_key,
initial_state=blackjax_state_after_tuning,
inference_algorithm=sampling_alg,
num_steps=num_steps,
transform=lambda x: x.position,
transform=lambda state, info: state.position,
)

return samples
@@ -197,7 +197,7 @@ def check_attrs(attribute, keyset):
for i, attribute in enumerate(["state", "info", "adaptation_state"]):
check_attrs(attribute, keysets[i])

_, states, _ = run_inference_algorithm(
_, (states, _) = run_inference_algorithm(
rng_key=inference_key,
initial_state=state,
inference_algorithm=inference_algorithm,
@@ -223,10 +223,11 @@ def test_mala(self):

mala = blackjax.mala(logposterior_fn, 1e-5)
state = mala.init({"coefs": 1.0, "log_scale": 1.0})
_, states, _ = run_inference_algorithm(
_, states = run_inference_algorithm(
rng_key=inference_key,
initial_state=state,
inference_algorithm=mala,
transform=lambda state, info: state,
num_steps=10_000,
)

@@ -375,11 +376,12 @@ def test_pathfinder_adaptation(
)
inference_algorithm = algorithm(logposterior_fn, **parameters)

_, states, _ = run_inference_algorithm(
_, states = run_inference_algorithm(
rng_key=inference_key,
initial_state=state,
inference_algorithm=inference_algorithm,
num_steps=num_sampling_steps,
transform=lambda state, info: state,
)

coefs_samples = states.position["coefs"]
@@ -418,11 +420,12 @@ def test_meads(self):
inference_algorithm = blackjax.ghmc(logposterior_fn, **parameters)

chain_keys = jax.random.split(inference_key, num_chains)
_, states, _ = jax.vmap(
_, states = jax.vmap(
lambda key, state: run_inference_algorithm(
rng_key=key,
initial_state=state,
inference_algorithm=inference_algorithm,
transform=lambda state, info: state,
num_steps=100,
)
)(chain_keys, last_states)
@@ -465,11 +468,12 @@ def test_chees(self, jitter_generator):
inference_algorithm = blackjax.dynamic_hmc(logposterior_fn, **parameters)

chain_keys = jax.random.split(inference_key, num_chains)
_, states, _ = jax.vmap(
_, states = jax.vmap(
lambda key, state: run_inference_algorithm(
rng_key=key,
initial_state=state,
inference_algorithm=inference_algorithm,
transform=lambda state, info: state,
num_steps=100,
)
)(chain_keys, last_states)
@@ -494,10 +498,11 @@ def test_barker(self):
barker = blackjax.barker_proposal(logposterior_fn, 1e-1)
state = barker.init({"coefs": 1.0, "log_scale": 1.0})

_, states, _ = run_inference_algorithm(
_, states = run_inference_algorithm(
rng_key=inference_key,
initial_state=state,
inference_algorithm=barker,
transform=lambda state, info: state,
num_steps=10_000,
)

@@ -679,10 +684,11 @@ def test_latent_gaussian(self):

initial_state = inference_algorithm.init(jnp.zeros((1,)))

_, states, _ = self.variant(
_, states = self.variant(
functools.partial(
run_inference_algorithm,
inference_algorithm=inference_algorithm,
transform=lambda state, info: state,
num_steps=self.sampling_steps,
),
)(rng_key=self.key, initial_state=initial_state)
@@ -724,7 +730,7 @@ def univariate_normal_test_case(
**kwargs,
):
inference_key, orbit_key = jax.random.split(rng_key)
_, states, _ = self.variant(
_, (states, info) = self.variant(
functools.partial(
run_inference_algorithm,
inference_algorithm=inference_algorithm,
@@ -855,7 +861,7 @@ def postprocess_samples(states, key):
20_000,
burnin,
postprocess_samples,
transform=lambda x: (x.positions, x.weights),
transform=lambda state, info: ((state.positions, state.weights), info),
)

@chex.all_variants(with_pmap=False)
@@ -997,10 +1003,11 @@ def test_mcse(self, algorithm, parameters, is_mass_matrix_diagonal):
functools.partial(
run_inference_algorithm,
inference_algorithm=inference_algorithm,
transform=lambda state, info: state,
num_steps=2_000,
)
)
_, states, _ = inference_loop_multiple_chains(
_, states = inference_loop_multiple_chains(
rng_key=multi_chain_sample_key, initial_state=initial_states
)
