From 130e3227dce1e2bb74f270c8c24ba6db3623d4de Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?R=C3=A9mi=20Louf?=
Date: Mon, 16 Jan 2023 10:27:00 +0100
Subject: [PATCH] Return adaptation extra information

We currently only return the last state, the values of the parameters,
and the adapted kernel. However, the full chain and the intermediate
adaptation states can be useful when debugging inference. In this PR we
make `window_adaptation`, `meads_adaptation` and `pathfinder_adaptation`
return this extra information.
---
 blackjax/kernels.py                     | 31 ++++++++++++++++++---------
 docs/examples/HierarchicalBNN.md        |  2 +-
 docs/examples/change_of_variable_hmc.md |  6 +++---
 docs/examples/howto_use_aesara.md       |  4 +++-
 docs/examples/howto_use_numpyro.md      |  4 +++-
 docs/examples/howto_use_oryx.md         |  4 +++-
 docs/examples/howto_use_pymc.md         |  4 +++-
 docs/examples/howto_use_tfp.md          |  2 +-
 tests/test_benchmarks.py                |  4 +++-
 tests/test_compilation.py               |  4 ++--
 tests/test_sampling.py                  |  6 +++---
 11 files changed, 47 insertions(+), 24 deletions(-)

diff --git a/blackjax/kernels.py b/blackjax/kernels.py
index dea854963..08c2db0fb 100644
--- a/blackjax/kernels.py
+++ b/blackjax/kernels.py
@@ -682,6 +682,12 @@ class AdaptationResults(NamedTuple):
     parameters: dict
 
 
+class AdaptationInfo(NamedTuple):
+    state: NamedTuple
+    info: NamedTuple
+    adaptation_state: NamedTuple
+
+
 def window_adaptation(
     algorithm: Union[hmc, nuts],
     logdensity_fn: Callable,
@@ -757,7 +763,7 @@ def one_step(carry, xs):
 
         return (
             (new_state, new_adaptation_state),
-            (new_state, info, new_adaptation_state),
+            AdaptationInfo(new_state, info, new_adaptation_state),
         )
 
     def run(rng_key: PRNGKey, position: PyTree, num_steps: int = 1000):
@@ -773,7 +779,7 @@ def run(rng_key: PRNGKey, position: PyTree, num_steps: int = 1000):
 
         keys = jax.random.split(rng_key, num_steps)
         schedule = adaptation.window_adaptation.schedule(num_steps)
-        last_state, adaptation_chain = jax.lax.scan(
+        last_state, info = jax.lax.scan(
            one_step_,
            (init_state, init_adaptation_state),
            (jnp.arange(num_steps), keys, schedule),
@@ -790,7 +796,14 @@ def run(rng_key: PRNGKey, position: PyTree, num_steps: int = 1000):
 
         def kernel(rng_key, state):
             return step_fn(rng_key, state, logdensity_fn, **parameters)
 
-        return AdaptationResults(last_chain_state, kernel, parameters)
+        return (
+            AdaptationResults(
+                last_chain_state,
+                kernel,
+                parameters,
+            ),
+            info,
+        )
 
     return AdaptationAlgorithm(run)
@@ -861,7 +874,7 @@ def kernel(rng_key, state):
             adaptation_state, new_states.position, new_states.logdensity_grad
         )
 
-        return (new_states, new_adaptation_state), (
+        return (new_states, new_adaptation_state), AdaptationInfo(
             new_states,
             info,
             new_adaptation_state,
@@ -876,7 +889,7 @@ def run(rng_key: PRNGKey, positions: PyTree, num_steps: int = 1000):
         init_adaptation_state = init(positions, init_states.logdensity_grad)
 
         keys = jax.random.split(key_adapt, num_steps)
-        (last_states, last_adaptation_state), _ = jax.lax.scan(
+        (last_states, last_adaptation_state), info = jax.lax.scan(
             one_step, (init_states, init_adaptation_state), keys
         )
 
@@ -895,7 +908,7 @@ def kernel(rng_key, state):
                 **parameters,
             )
 
-        return AdaptationResults(last_states, kernel, parameters)
+        return AdaptationResults(last_states, kernel, parameters), info
 
     return AdaptationAlgorithm(run)  # type: ignore[arg-type]
@@ -1345,7 +1358,7 @@ def one_step(carry, rng_key):
         )
         return (
             (new_state, new_adaptation_state),
-            (new_state, info, new_adaptation_state.ss_state),
+            AdaptationInfo(new_state, info, new_adaptation_state),
         )
 
     def run(rng_key: PRNGKey, position: PyTree, num_steps: int = 400):
@@ -1366,7 +1379,7 @@ def run(rng_key: PRNGKey, position: PyTree, num_steps: int = 400):
         init_state = algorithm.init(init_position, logdensity_fn)
 
         keys = jax.random.split(rng_key, num_steps)
-        last_state, warmup_chain = jax.lax.scan(
+        last_state, info = jax.lax.scan(
             one_step,
             (init_state, init_warmup_state),
             keys,
@@ -1383,7 +1396,7 @@ def run(rng_key: PRNGKey, position: PyTree, num_steps: int = 400):
 
         def kernel(rng_key, state):
             return step_fn(rng_key, state, logdensity_fn, **parameters)
 
-        return AdaptationResults(last_chain_state, kernel, parameters)
+        return AdaptationResults(last_chain_state, kernel, parameters), info
 
     return AdaptationAlgorithm(run)
diff --git a/docs/examples/HierarchicalBNN.md b/docs/examples/HierarchicalBNN.md
index 554b6fc4f..42eb5926b 100644
--- a/docs/examples/HierarchicalBNN.md
+++ b/docs/examples/HierarchicalBNN.md
@@ -181,7 +181,7 @@ def fit_and_eval(
 
     # warm up
     adapt = blackjax.window_adaptation(blackjax.nuts, logprob)
-    final_state, kernel, _ = adapt.run(warmup_key, initial_position, num_warmup)
+    (final_state, kernel, _), _ = adapt.run(warmup_key, initial_position, num_warmup)
 
     # inference
     states = inference_loop(inference_key, kernel, final_state, num_samples)
diff --git a/docs/examples/change_of_variable_hmc.md b/docs/examples/change_of_variable_hmc.md
index bf97f2000..18aeb8bc7 100644
--- a/docs/examples/change_of_variable_hmc.md
+++ b/docs/examples/change_of_variable_hmc.md
@@ -301,7 +301,7 @@ init_params = jax.vmap(init_param_fn)(keys)
 
 @jax.vmap
 def call_warmup(seed, param):
-    initial_states, _, tuned_params = warmup.run(seed, param, 1000)
+    (initial_states, _, tuned_params), _ = warmup.run(seed, param, 1000)
     return initial_states, tuned_params
 
 initial_states, tuned_params = jax.jit(call_warmup)(keys, init_params)
@@ -468,7 +468,7 @@ init_params = jax.vmap(init_param_fn)(keys)
 
 @jax.vmap
 def call_warmup(seed, param):
-    initial_states, _, tuned_params = warmup.run(seed, param, 1000)
+    (initial_states, _, tuned_params), _ = warmup.run(seed, param, 1000)
     return initial_states, tuned_params
 
 initial_states, tuned_params = call_warmup(keys, init_params)
@@ -565,7 +565,7 @@ keys = jax.random.split(warmup_key, n_chains)
 
 @jax.vmap
 def call_warmup(seed, param):
-    initial_states, _, tuned_params = warmup.run(seed, param, 1000)
+    (initial_states, _, tuned_params), _ = warmup.run(seed, param, 1000)
     return initial_states, tuned_params
 
 initial_states, tuned_params = call_warmup(keys, init_params)
diff --git a/docs/examples/howto_use_aesara.md b/docs/examples/howto_use_aesara.md
index c25085da6..489b199e0 100644
--- a/docs/examples/howto_use_aesara.md
+++ b/docs/examples/howto_use_aesara.md
@@ -166,7 +166,9 @@ n_adapt = 3000
 n_samples = 1000
 
 adapt = blackjax.window_adaptation(blackjax.nuts, logdensity_fn)
-state, kernel, _ = adapt.run(rng_key, init_position, n_adapt)
+results, _ = adapt.run(rng_key, init_position, n_adapt)
+state = results.state
+kernel = results.kernel
 
 states, infos = inference_loop(
     rng_key, kernel, state, n_samples
diff --git a/docs/examples/howto_use_numpyro.md b/docs/examples/howto_use_numpyro.md
index 94e550f97..3aa8d96c0 100644
--- a/docs/examples/howto_use_numpyro.md
+++ b/docs/examples/howto_use_numpyro.md
@@ -94,7 +94,9 @@ num_warmup = 2000
 adapt = blackjax.window_adaptation(
     blackjax.nuts, logdensity_fn, target_acceptance_rate=0.8
 )
-last_state, kernel, _ = adapt.run(rng_key, initial_position, num_warmup)
+results, _ = adapt.run(rng_key, initial_position, num_warmup)
+last_state = results.state
+kernel = results.kernel
 ```
 
 Let us now perform inference with the tuned kernel:
diff --git a/docs/examples/howto_use_oryx.md b/docs/examples/howto_use_oryx.md
index 581b3b0b3..80b937832 100644
--- a/docs/examples/howto_use_oryx.md
+++ b/docs/examples/howto_use_oryx.md
@@ -139,7 +139,9 @@ import blackjax
 
 rng_key = jax.random.PRNGKey(0)
 adapt = blackjax.window_adaptation(blackjax.nuts, logdensity_fn)
-last_state, kernel, _ = adapt.run(rng_key, initial_weights, 100)
+results, _ = adapt.run(rng_key, initial_weights, 100)
+last_state = results.state
+kernel = results.kernel
 ```
 
 and sample from the model's posterior distribution:
diff --git a/docs/examples/howto_use_pymc.md b/docs/examples/howto_use_pymc.md
index 379b4b8b2..001c0fe94 100644
--- a/docs/examples/howto_use_pymc.md
+++ b/docs/examples/howto_use_pymc.md
@@ -70,7 +70,9 @@ init_position = [init_position_dict[rv] for rv in rvs]
 
 rng_key = jax.random.PRNGKey(1234)
 adapt = blackjax.window_adaptation(blackjax.nuts, logdensity_fn)
-last_state, kernel, _ = adapt.run(rng_key, init_position, 1000)
+results, _ = adapt.run(rng_key, init_position, 1000)
+last_state = results.state
+kernel = results.kernel
 ```
 
 Let us now perform inference with the tuned kernel:
diff --git a/docs/examples/howto_use_tfp.md b/docs/examples/howto_use_tfp.md
index 90c8b06ec..e457ab1c9 100644
--- a/docs/examples/howto_use_tfp.md
+++ b/docs/examples/howto_use_tfp.md
@@ -120,7 +120,7 @@ adapt = blackjax.window_adaptation(
     blackjax.hmc, logdensity_fn, num_integration_steps=3
 )
 
-last_state, kernel, _ = adapt.run(rng_key, initial_position, 1000)
+(last_state, kernel, _), _ = adapt.run(rng_key, initial_position, 1000)
 ```
 
 We can now perform inference with the tuned kernel:
diff --git a/tests/test_benchmarks.py b/tests/test_benchmarks.py
index 888145b4e..8833d1a4b 100644
--- a/tests/test_benchmarks.py
+++ b/tests/test_benchmarks.py
@@ -54,7 +54,9 @@ def run_regression(algorithm, **parameters):
         is_mass_matrix_diagonal=False,
         **parameters,
     )
-    state, kernel, _ = warmup.run(warmup_key, {"log_scale": 0.0, "coefs": 2.0}, 1000)
+    (state, kernel, _), _ = warmup.run(
+        warmup_key, {"log_scale": 0.0, "coefs": 2.0}, 1000
+    )
 
     states = inference_loop(kernel, 10_000, inference_key, state)
 
diff --git a/tests/test_compilation.py b/tests/test_compilation.py
index c80816b57..eb507133e 100644
--- a/tests/test_compilation.py
+++ b/tests/test_compilation.py
@@ -91,7 +91,7 @@ def logdensity_fn(x):
         target_acceptance_rate=0.8,
         num_integration_steps=10,
     )
-    state, kernel, _ = warmup.run(rng_key, 1.0, num_steps=100)
+    (state, kernel, _), _ = warmup.run(rng_key, 1.0, num_steps=100)
     kernel = jax.jit(kernel)
 
     for _ in range(10):
@@ -118,7 +118,7 @@ def logdensity_fn(x):
         logdensity_fn=logdensity_fn,
         target_acceptance_rate=0.8,
     )
-    state, kernel, _ = warmup.run(rng_key, 1.0, num_steps=100)
+    (state, kernel, _), _ = warmup.run(rng_key, 1.0, num_steps=100)
     step = jax.jit(kernel)
 
     for _ in range(10):
diff --git a/tests/test_sampling.py b/tests/test_sampling.py
index 17b203588..9ae5910b1 100644
--- a/tests/test_sampling.py
+++ b/tests/test_sampling.py
@@ -99,7 +99,7 @@ def test_window_adaptation(self, case, is_mass_matrix_diagonal):
             progress_bar=True,
             **case["parameters"],
         )
-        state, kernel, _ = warmup.run(
+        (state, kernel, _), _ = warmup.run(
             warmup_key,
             case["initial_position"],
             case["num_warmup_steps"],
@@ -164,7 +164,7 @@ def test_pathfinder_adaptation(
             logposterior_fn,
             **parameters,
        )
-        state, kernel, _ = warmup.run(
+        (state, kernel, _), _ = warmup.run(
             warmup_key,
             initial_position,
             num_warmup_steps,
@@ -200,7 +200,7 @@ def test_meads(self):
         log_scales = 1.0 + jax.random.normal(scale_key, (num_chains,))
         coefs = 4.0 + jax.random.normal(coefs_key, (num_chains,))
         initial_positions = {"log_scale": log_scales, "coefs": coefs}
-        last_states, kernel, _ = warmup.run(
+        (last_states, kernel, _), _ = warmup.run(
             warmup_key,
             initial_positions,
             num_steps=1000,
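
Note for reviewers: after this patch, every adaptation `run` returns a pair `(AdaptationResults, AdaptationInfo)` instead of a bare `AdaptationResults`. Below is a minimal usage sketch of the new signature; the target density is a hypothetical standard normal used only for illustration, everything else follows the code changed above:

```python
import jax
import jax.numpy as jnp

import blackjax


# Hypothetical toy target for illustration; not part of this patch.
def logdensity_fn(x):
    return -0.5 * jnp.sum(x**2)


rng_key = jax.random.PRNGKey(0)
adapt = blackjax.window_adaptation(blackjax.nuts, logdensity_fn)

# New return signature: (AdaptationResults, AdaptationInfo).
results, info = adapt.run(rng_key, jnp.zeros(2), num_steps=500)
state, kernel, parameters = results

# Each field of `info` is stacked by jax.lax.scan along a leading axis of
# length num_steps, so the full warmup trajectory is available for debugging.
warmup_positions = info.state.position       # full warmup chain, shape (500, 2)
intermediate_states = info.adaptation_state  # per-step adaptation states
```

Callers that previously wrote `state, kernel, _ = adapt.run(...)` must unpack the extra level, as the docs and tests updated above show; returning the scanned-over chain rather than discarding it is what makes the intermediate adaptation states inspectable after warmup.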