Return adaptation extra information
We currently only return the last state, the parameter values, and the
adapted kernel. However, the full chain and the intermediate adaptation
states can be useful when debugging inference.

In this PR we make `window_adaptation`, `meads_adaptation` and
`pathfinder_adaptation` return this extra information.
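
As a minimal sketch of the new calling convention (the toy standard-normal
log-density and the `warmup_info` name below are illustrative, not part of the
commit):

```python
import jax
import jax.numpy as jnp
import blackjax


def logdensity_fn(x):
    # Toy target for illustration only: a standard normal.
    return -0.5 * jnp.square(x).sum()


rng_key = jax.random.PRNGKey(0)
warmup = blackjax.window_adaptation(blackjax.nuts, logdensity_fn)

# `run` now returns the usual results plus an `AdaptationInfo` NamedTuple
# whose fields (state, info, adaptation_state) are stacked over the warmup
# steps by `jax.lax.scan`.
(state, kernel, parameters), warmup_info = warmup.run(rng_key, 1.0, 1000)

# The full warmup chain is now available for debugging, e.g. the position
# at every warmup step:
warmup_positions = warmup_info.state.position
```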
rlouf committed Jan 16, 2023
1 parent feb810f commit 130e322
Showing 11 changed files with 49 additions and 24 deletions.
31 changes: 22 additions & 9 deletions blackjax/kernels.py
@@ -682,6 +682,12 @@ class AdaptationResults(NamedTuple):
     parameters: dict


+class AdaptationInfo(NamedTuple):
+    state: NamedTuple
+    info: NamedTuple
+    adaptation_state: NamedTuple
+
+
 def window_adaptation(
     algorithm: Union[hmc, nuts],
     logdensity_fn: Callable,
@@ -757,7 +763,7 @@ def one_step(carry, xs):

         return (
             (new_state, new_adaptation_state),
-            (new_state, info, new_adaptation_state),
+            AdaptationInfo(new_state, info, new_adaptation_state),
         )

     def run(rng_key: PRNGKey, position: PyTree, num_steps: int = 1000):
@@ -773,7 +779,7 @@ def run(rng_key: PRNGKey, position: PyTree, num_steps: int = 1000):

         keys = jax.random.split(rng_key, num_steps)
         schedule = adaptation.window_adaptation.schedule(num_steps)
-        last_state, adaptation_chain = jax.lax.scan(
+        last_state, info = jax.lax.scan(
             one_step_,
             (init_state, init_adaptation_state),
             (jnp.arange(num_steps), keys, schedule),
@@ -790,7 +796,14 @@ def run(rng_key: PRNGKey, position: PyTree, num_steps: int = 1000):
         def kernel(rng_key, state):
             return step_fn(rng_key, state, logdensity_fn, **parameters)

-        return AdaptationResults(last_chain_state, kernel, parameters)
+        return (
+            AdaptationResults(
+                last_chain_state,
+                kernel,
+                parameters,
+            ),
+            info,
+        )

     return AdaptationAlgorithm(run)

@@ -861,7 +874,7 @@ def kernel(rng_key, state):
             adaptation_state, new_states.position, new_states.logdensity_grad
         )

-        return (new_states, new_adaptation_state), (
+        return (new_states, new_adaptation_state), AdaptationInfo(
             new_states,
             info,
             new_adaptation_state,
@@ -876,7 +889,7 @@ def run(rng_key: PRNGKey, positions: PyTree, num_steps: int = 1000):
         init_adaptation_state = init(positions, init_states.logdensity_grad)

         keys = jax.random.split(key_adapt, num_steps)
-        (last_states, last_adaptation_state), _ = jax.lax.scan(
+        (last_states, last_adaptation_state), info = jax.lax.scan(
             one_step, (init_states, init_adaptation_state), keys
         )

@@ -895,7 +908,7 @@ def kernel(rng_key, state):
                 **parameters,
             )

-        return AdaptationResults(last_states, kernel, parameters)
+        return AdaptationResults(last_states, kernel, parameters), info

     return AdaptationAlgorithm(run)  # type: ignore[arg-type]

@@ -1345,7 +1358,7 @@ def one_step(carry, rng_key):
         )
         return (
             (new_state, new_adaptation_state),
-            (new_state, info, new_adaptation_state.ss_state),
+            AdaptationInfo(new_state, info, new_adaptation_state),
         )

     def run(rng_key: PRNGKey, position: PyTree, num_steps: int = 400):
@@ -1366,7 +1379,7 @@ def run(rng_key: PRNGKey, position: PyTree, num_steps: int = 400):
         init_state = algorithm.init(init_position, logdensity_fn)

         keys = jax.random.split(rng_key, num_steps)
-        last_state, warmup_chain = jax.lax.scan(
+        last_state, info = jax.lax.scan(
             one_step,
             (init_state, init_warmup_state),
             keys,
@@ -1383,7 +1396,7 @@ def run(rng_key: PRNGKey, position: PyTree, num_steps: int = 400):
         def kernel(rng_key, state):
             return step_fn(rng_key, state, logdensity_fn, **parameters)

-        return AdaptationResults(last_chain_state, kernel, parameters)
+        return AdaptationResults(last_chain_state, kernel, parameters), info

     return AdaptationAlgorithm(run)
2 changes: 1 addition & 1 deletion docs/examples/HierarchicalBNN.md
@@ -181,7 +181,7 @@ def fit_and_eval(

     # warm up
     adapt = blackjax.window_adaptation(blackjax.nuts, logprob)
-    final_state, kernel, _ = adapt.run(warmup_key, initial_position, num_warmup)
+    (final_state, kernel, _), _ = adapt.run(warmup_key, initial_position, num_warmup)

     # inference
     states = inference_loop(inference_key, kernel, final_state, num_samples)
6 changes: 3 additions & 3 deletions docs/examples/change_of_variable_hmc.md
@@ -301,7 +301,7 @@ init_params = jax.vmap(init_param_fn)(keys)

 @jax.vmap
 def call_warmup(seed, param):
-    initial_states, _, tuned_params = warmup.run(seed, param, 1000)
+    (initial_states, _, tuned_params), _ = warmup.run(seed, param, 1000)
     return initial_states, tuned_params

 initial_states, tuned_params = jax.jit(call_warmup)(keys, init_params)
@@ -468,7 +468,7 @@ init_params = jax.vmap(init_param_fn)(keys)

 @jax.vmap
 def call_warmup(seed, param):
-    initial_states, _, tuned_params = warmup.run(seed, param, 1000)
+    (initial_states, _, tuned_params), _ = warmup.run(seed, param, 1000)
     return initial_states, tuned_params

 initial_states, tuned_params = call_warmup(keys, init_params)
@@ -565,7 +565,7 @@ keys = jax.random.split(warmup_key, n_chains)

 @jax.vmap
 def call_warmup(seed, param):
-    initial_states, _, tuned_params = warmup.run(seed, param, 1000)
+    (initial_states, _, tuned_params), _ = warmup.run(seed, param, 1000)
     return initial_states, tuned_params

 initial_states, tuned_params = call_warmup(keys, init_params)
4 changes: 3 additions & 1 deletion docs/examples/howto_use_aesara.md
@@ -166,7 +166,9 @@ n_adapt = 3000
 n_samples = 1000

 adapt = blackjax.window_adaptation(blackjax.nuts, logdensity_fn)
-state, kernel, _ = adapt.run(rng_key, init_position, n_adapt)
+results, _ = adapt.run(rng_key, init_position, n_adapt)
+state = results.state
+kernel = results.kernel

 states, infos = inference_loop(
     rng_key, kernel, state, n_samples
6 changes: 5 additions & 1 deletion docs/examples/howto_use_numpyro.md
@@ -1,3 +1,5 @@
+state = results.state
+kernel = results.kernel
 ---
 jupytext:
   text_representation:
@@ -94,7 +96,9 @@ num_warmup = 2000

 adapt = blackjax.window_adaptation(
     blackjax.nuts, logdensity_fn, target_acceptance_rate=0.8
 )
-last_state, kernel, _ = adapt.run(rng_key, initial_position, num_warmup)
+results, _ = adapt.run(rng_key, initial_position, num_warmup)
+last_state = results.state
+kernel = results.kernel
 ```

 Let us now perform inference with the tuned kernel:
4 changes: 3 additions & 1 deletion docs/examples/howto_use_oryx.md
@@ -139,7 +139,9 @@ import blackjax

 rng_key = jax.random.PRNGKey(0)
 adapt = blackjax.window_adaptation(blackjax.nuts, logdensity_fn)
-last_state, kernel, _ = adapt.run(rng_key, initial_weights, 100)
+results, _ = adapt.run(rng_key, initial_weights, 100)
+last_state = results.state
+kernel = results.kernel
 ```

 and sample from the model's posterior distribution:
4 changes: 3 additions & 1 deletion docs/examples/howto_use_pymc.md
@@ -70,7 +70,9 @@ init_position = [init_position_dict[rv] for rv in rvs]

 rng_key = jax.random.PRNGKey(1234)
 adapt = blackjax.window_adaptation(blackjax.nuts, logdensity_fn)
-last_state, kernel, _ = adapt.run(rng_key, init_position, 1000)
+results, _ = adapt.run(rng_key, init_position, 1000)
+last_state = results.state
+kernel = results.kernel
 ```

 Let us now perform inference with the tuned kernel:
2 changes: 1 addition & 1 deletion docs/examples/howto_use_tfp.md
@@ -120,7 +120,7 @@ adapt = blackjax.window_adaptation(
     blackjax.hmc, logdensity_fn, num_integration_steps=3
 )

-last_state, kernel, _ = adapt.run(rng_key, initial_position, 1000)
+(last_state, kernel, _), _ = adapt.run(rng_key, initial_position, 1000)
 ```

 We can now perform inference with the tuned kernel:
4 changes: 3 additions & 1 deletion tests/test_benchmarks.py
@@ -54,7 +54,9 @@ def run_regression(algorithm, **parameters):
         is_mass_matrix_diagonal=False,
         **parameters,
     )
-    state, kernel, _ = warmup.run(warmup_key, {"log_scale": 0.0, "coefs": 2.0}, 1000)
+    (state, kernel, _), _ = warmup.run(
+        warmup_key, {"log_scale": 0.0, "coefs": 2.0}, 1000
+    )

     states = inference_loop(kernel, 10_000, inference_key, state)

4 changes: 2 additions & 2 deletions tests/test_compilation.py
@@ -91,7 +91,7 @@ def logdensity_fn(x):
         target_acceptance_rate=0.8,
         num_integration_steps=10,
     )
-    state, kernel, _ = warmup.run(rng_key, 1.0, num_steps=100)
+    (state, kernel, _), _ = warmup.run(rng_key, 1.0, num_steps=100)
     kernel = jax.jit(kernel)

     for _ in range(10):
@@ -118,7 +118,7 @@ def logdensity_fn(x):
         logdensity_fn=logdensity_fn,
         target_acceptance_rate=0.8,
     )
-    state, kernel, _ = warmup.run(rng_key, 1.0, num_steps=100)
+    (state, kernel, _), _ = warmup.run(rng_key, 1.0, num_steps=100)
     step = jax.jit(kernel)

     for _ in range(10):
6 changes: 3 additions & 3 deletions tests/test_sampling.py
@@ -99,7 +99,7 @@ def test_window_adaptation(self, case, is_mass_matrix_diagonal):
             progress_bar=True,
             **case["parameters"],
         )
-        state, kernel, _ = warmup.run(
+        (state, kernel, _), _ = warmup.run(
             warmup_key,
             case["initial_position"],
             case["num_warmup_steps"],
@@ -164,7 +164,7 @@ def test_pathfinder_adaptation(
             logposterior_fn,
             **parameters,
         )
-        state, kernel, _ = warmup.run(
+        (state, kernel, _), _ = warmup.run(
             warmup_key,
             initial_position,
             num_warmup_steps,
@@ -200,7 +200,7 @@ def test_meads(self):
         log_scales = 1.0 + jax.random.normal(scale_key, (num_chains,))
         coefs = 4.0 + jax.random.normal(coefs_key, (num_chains,))
         initial_positions = {"log_scale": log_scales, "coefs": coefs}
-        last_states, kernel, _ = warmup.run(
+        (last_states, kernel, _), _ = warmup.run(
             warmup_key,
             initial_positions,
             num_steps=1000,
