Simplify run_inference_algorithm #714

Merged 21 commits on Aug 12, 2024

Changes from 19 commits
8 changes: 4 additions & 4 deletions blackjax/adaptation/mclmc_adaptation.py
@@ -20,7 +20,7 @@
 from jax.flatten_util import ravel_pytree
 
 from blackjax.diagnostics import effective_sample_size
-from blackjax.util import pytree_size, streaming_average_update
+from blackjax.util import incremental_value_update, pytree_size
 
 
 class MCLMCAdaptationState(NamedTuple):
@@ -199,9 +199,9 @@ def step(iteration_state, weight_and_key):
 
         x = ravel_pytree(state.position)[0]
         # update the running average of x, x^2
-        streaming_avg = streaming_average_update(
-            current_value=jnp.array([x, jnp.square(x)]),
-            previous_weight_and_average=streaming_avg,
+        streaming_avg = incremental_value_update(
+            expectation=jnp.array([x, jnp.square(x)]),
+            incremental_val=streaming_avg,
             weight=(1 - mask) * success * params.step_size,
             zero_prevention=mask,
         )
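For context, the running average updated here tracks E[x] and E[x²] of the flattened position, from which the adaptation can form a variance estimate E[x²] − E[x]². A minimal sketch of the renamed API with hypothetical values (not part of this diff):

```python
import jax.numpy as jnp
from blackjax.util import incremental_value_update

# streaming_avg is (total weight, running average of [x, x^2])
streaming_avg = (0.0, jnp.zeros((2, 2)))

for x in (jnp.array([1.0, -2.0]), jnp.array([3.0, 0.0])):  # hypothetical positions
    streaming_avg = incremental_value_update(
        expectation=jnp.array([x, jnp.square(x)]),
        incremental_val=streaming_avg,
        weight=1.0,
    )

total, (mean_x, mean_x2) = streaming_avg
variance = mean_x2 - jnp.square(mean_x)  # elementwise variance estimate: [1., 1.]
```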
157 changes: 100 additions & 57 deletions blackjax/util.py
@@ -3,12 +3,11 @@
 from functools import partial
 from typing import Callable, Union
 
-import jax
 import jax.numpy as jnp
 from jax import jit, lax
 from jax.flatten_util import ravel_pytree
 from jax.random import normal, split
-from jax.tree_util import tree_leaves
+from jax.tree_util import tree_leaves, tree_map
 
 from blackjax.base import SamplingAlgorithm, VIAlgorithm
 from blackjax.progress_bar import gen_scan_fn
@@ -149,9 +148,7 @@ def run_inference_algorithm(
     initial_state: ArrayLikeTree = None,
     initial_position: ArrayLikeTree = None,
     progress_bar: bool = False,
-    transform: Callable = lambda x: x,
-    return_state_history=True,
-    expectation: Callable = lambda x: x,
+    transform: Callable = lambda state, info: (state, info),
 ) -> tuple:
     """Wrapper to run an inference algorithm.
 
@@ -166,35 +163,22 @@
     initial_state
         The initial state of the inference algorithm.
     initial_position
-        The initial position of the inference algorithm. This is used when the initial
-        state is not provided.
+        The initial position of the inference algorithm. This is used when the initial state is not provided.
     inference_algorithm
         One of blackjax's sampling algorithms or variational inference algorithms.
     num_steps
         Number of MCMC steps.
     progress_bar
         Whether to display a progress bar.
     transform
-        A transformation of the trace of states to be returned. This is useful for
+        A transformation of the trace of states (and info) to be returned. This is useful for
         computing deterministic variables, or returning a subset of the states.
         By default, the states are returned as is.
-    expectation
-        A function that computes the expectation of the state. This is done
-        incrementally, so doesn't require storing all the states.
-    return_state_history
-        if False, `run_inference_algorithm` will only return an expectation of the value
-        of transform, and return that average instead of the full set of samples. This
-        is useful when memory is a bottleneck.
 
     Returns
     -------
-    If return_state_history is True:
     1. The final state.
-    2. The trace of the state.
-    3. The trace of the info of the inference algorithm for diagnostics.
-    If return_state_history is False:
-    1. This is the expectation of state over the chain. Otherwise the final state.
-    2. The final state of the inference algorithm.
+    2. The history of states.
     """
 
     if initial_state is None and initial_position is None:
@@ -212,58 +196,117 @@
 
     keys = split(rng_key, num_steps)
 
-    def one_step(average_and_state, xs, return_state):
+    def one_step(state, xs):
         _, rng_key = xs
-        average, state = average_and_state
         state, info = inference_algorithm.step(rng_key, state)
-        average = streaming_average_update(expectation(transform(state)), average)
-        if return_state:
-            return (average, state), (transform(state), info)
-        else:
-            return (average, state), None
+        return state, transform(state, info)
 
-    one_step = jax.jit(partial(one_step, return_state=return_state_history))
-
-    xs = (jnp.arange(num_steps), keys)
     scan_fn = gen_scan_fn(num_steps, progress_bar)
-    ((_, average), final_state), history = scan_fn(
-        one_step,
-        ((0, expectation(transform(initial_state))), initial_state),
-        xs,
-    )
-
-    if not return_state_history:
-        return average, transform(final_state)
-    else:
-        state_history, info_history = history
-        return transform(final_state), state_history, info_history
+    xs = jnp.arange(num_steps), keys
+    final_state, history = scan_fn(one_step, initial_state, xs)
 
+    return final_state, history
 
 
-def streaming_average_update(
-    current_value, previous_weight_and_average, weight=1.0, zero_prevention=0.0
+def store_only_expectation_values(
+    sampling_algorithm,
+    state_transform=lambda x: x,
+    incremental_value_transform=lambda x: x,
+    burn_in=0,
 ):
+    """Take a sampling algorithm and construct from it a new sampling algorithm object.
+    The new sampling algorithm has the same kernel, but only stores the streaming
+    expectation values of some observables rather than the full states, in order to
+    save memory.
+
+    At each step i, it saves incremental_value_transform(E[state_transform(x)]), where
+    the expectation is computed over the samples up to the i-th sample.
+
+    Example:
+
+    .. code::
+
+        init_key, state_key, run_key = jax.random.split(jax.random.PRNGKey(0), 3)
+        model = StandardNormal(2)
+        initial_position = model.sample_init(init_key)
+        initial_state = blackjax.mcmc.mclmc.init(
+            position=initial_position, logdensity_fn=model.logdensity_fn, rng_key=state_key
+        )
+        integrator_type = "mclachlan"
+        L = 1.0
+        step_size = 0.1
+        num_steps = 4
+
+        integrator = map_integrator_type_to_integrator['mclmc'][integrator_type]
+
+        sampling_alg = blackjax.mclmc(
+            model.logdensity_fn,
+            L=L,
+            step_size=step_size,
+            integrator=integrator,
+        )
+
+        state_transform = lambda state: state.position
+        memory_efficient_sampling_alg, transform = store_only_expectation_values(
+            sampling_algorithm=sampling_alg,
+            state_transform=state_transform)
+
+        initial_state = memory_efficient_sampling_alg.init(initial_state)
+
+        final_state, trace_at_every_step = run_inference_algorithm(
+            rng_key=run_key,
+            initial_state=initial_state,
+            inference_algorithm=memory_efficient_sampling_alg,
+            num_steps=num_steps,
+            transform=transform,
+            progress_bar=True,
+        )
+    """
+
+    def init_fn(state):
+        averaging_state = (0.0, state_transform(state))
+        return (state, averaging_state)
+
+    def update_fn(rng_key, state_and_incremental_val):
+        state, averaging_state = state_and_incremental_val
+        state, info = sampling_algorithm.step(
+            rng_key, state
+        )  # update the state with the sampling algorithm
+        averaging_state = incremental_value_update(
+            state_transform(state),
+            averaging_state,
+            weight=(
+                averaging_state[0] >= burn_in
+            ),  # if we want to eliminate some number of steps as a burn-in
+            zero_prevention=1e-10 * (burn_in > 0),
+        )
+        # update the expectation value with the running average
+        return (state, averaging_state), info
+
+    def transform(state_and_incremental_val, info):
+        (state, (_, incremental_value)) = state_and_incremental_val
+        return incremental_value_transform(incremental_value), info
+
+    return SamplingAlgorithm(init_fn, update_fn), transform
+
+
+def incremental_value_update(
+    expectation, incremental_val, weight=1.0, zero_prevention=0.0
+):
     """Compute the streaming average of a function O(x) using a weight.
     Parameters:
     ----------
-    current_value
-        the current value of the function that we want to take average of
-    previous_weight_and_average
-        tuple of (previous_weight, previous_average) where previous_weight is the
-        sum of weights and average is the current estimated average
+    expectation
+        the value of the expectation at the current timestep
+    incremental_val
+        tuple of (total, average) where total is the sum of weights and average is the current average
     weight
         weight of the current state
     zero_prevention
         small value to prevent division by zero
     Returns:
     ----------
-    new total weight and streaming average
+    new streaming average
     """
-    previous_weight, previous_average = previous_weight_and_average
-    current_weight = previous_weight + weight
-    current_average = jax.tree.map(
-        lambda x, avg: (previous_weight * avg + weight * x)
-        / (current_weight + zero_prevention),
-        current_value,
-        previous_average,
-    )
-    return current_weight, current_average
+    total, average = incremental_val
+    average = tree_map(
+        lambda exp, av: (total * av + weight * exp)
+        / (total + weight + zero_prevention),
+        expectation,
+        average,
+    )
+    total += weight
+    incremental_val = total, average
+    return incremental_val
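A quick numeric check of the update rule above (a sketch; with unit weights the streaming average of 2.0 and 4.0 should come out to 3.0):

```python
import jax.numpy as jnp
from blackjax.util import incremental_value_update

# (total weight, running average): start from zero accumulated weight
incremental_val = (0.0, jnp.array(0.0))

for x in (2.0, 4.0):
    incremental_val = incremental_value_update(
        expectation=jnp.array(x),
        incremental_val=incremental_val,
        weight=1.0,
    )

total, average = incremental_val
# total == 2.0, average == 3.0: the running mean of [2.0, 4.0]
```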
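With the simplified signature, calling the wrapper looks like the sketch below; the standard-normal target and the NUTS parameters are illustrative assumptions, not part of this diff:

```python
import jax
import jax.numpy as jnp
import blackjax
from blackjax.util import run_inference_algorithm

logdensity_fn = lambda x: -0.5 * jnp.sum(jnp.square(x))  # standard normal target
algorithm = blackjax.nuts(
    logdensity_fn, step_size=0.5, inverse_mass_matrix=jnp.ones(2)
)

# transform defaults to lambda state, info: (state, info); pass e.g.
# lambda state, info: (state.position, info) to keep only the positions.
final_state, (state_history, info_history) = run_inference_algorithm(
    rng_key=jax.random.PRNGKey(0),
    inference_algorithm=algorithm,
    num_steps=100,
    initial_position=jnp.zeros(2),
)
```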
2 changes: 1 addition & 1 deletion tests/adaptation/test_adaptation.py
@@ -90,7 +90,7 @@ def test_chees_adaptation(adaptation_filters):
     algorithm = blackjax.dynamic_hmc(logprob_fn, **parameters)
 
     chain_keys = jax.random.split(inference_key, num_chains)
-    _, _, infos = jax.vmap(
+    _, (_, infos) = jax.vmap(
         lambda key, state: run_inference_algorithm(
             rng_key=key,
             initial_state=state,
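The one-line test change reflects the new return structure: run_inference_algorithm now returns (final_state, history), where history is whatever transform emits at each step, (state, info) by default. A toy stand-in (not blackjax's API) showing the same unpacking under vmap:

```python
import jax
import jax.numpy as jnp

def toy_run(key, state, num_steps=5):
    # Mimics the new wrapper: scan a kernel, return (final_state, history)
    # where history stacks transform(state, info) = (state, info) per step.
    def one_step(state, key):
        state = state + jax.random.normal(key)  # pretend MCMC move
        info = {"accept": jnp.array(1.0)}       # pretend diagnostics
        return state, (state, info)
    keys = jax.random.split(key, num_steps)
    final_state, history = jax.lax.scan(one_step, state, keys)
    return final_state, history

keys = jax.random.split(jax.random.PRNGKey(0), 4)  # 4 chains
_, (_, infos) = jax.vmap(toy_run, in_axes=(0, 0))(keys, jnp.zeros(4))
print(infos["accept"].shape)  # (4, 5): chains x steps
```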