diff --git a/blackjax/kernels.py b/blackjax/kernels.py
index dea854963..08c2db0fb 100644
--- a/blackjax/kernels.py
+++ b/blackjax/kernels.py
@@ -682,6 +682,12 @@ class AdaptationResults(NamedTuple):
     parameters: dict
 
 
+class AdaptationInfo(NamedTuple):
+    state: NamedTuple
+    info: NamedTuple
+    adaptation_state: NamedTuple
+
+
 def window_adaptation(
     algorithm: Union[hmc, nuts],
     logdensity_fn: Callable,
@@ -757,7 +763,7 @@ def one_step(carry, xs):
 
         return (
             (new_state, new_adaptation_state),
-            (new_state, info, new_adaptation_state),
+            AdaptationInfo(new_state, info, new_adaptation_state),
         )
 
     def run(rng_key: PRNGKey, position: PyTree, num_steps: int = 1000):
@@ -773,7 +779,7 @@ def run(rng_key: PRNGKey, position: PyTree, num_steps: int = 1000):
 
         keys = jax.random.split(rng_key, num_steps)
         schedule = adaptation.window_adaptation.schedule(num_steps)
-        last_state, adaptation_chain = jax.lax.scan(
+        last_state, info = jax.lax.scan(
             one_step_,
             (init_state, init_adaptation_state),
             (jnp.arange(num_steps), keys, schedule),
@@ -790,7 +796,14 @@ def run(rng_key: PRNGKey, position: PyTree, num_steps: int = 1000):
 
         def kernel(rng_key, state):
             return step_fn(rng_key, state, logdensity_fn, **parameters)
 
-        return AdaptationResults(last_chain_state, kernel, parameters)
+        return (
+            AdaptationResults(
+                last_chain_state,
+                kernel,
+                parameters,
+            ),
+            info,
+        )
 
     return AdaptationAlgorithm(run)
@@ -861,7 +874,7 @@ def kernel(rng_key, state):
             adaptation_state, new_states.position, new_states.logdensity_grad
         )
-        return (new_states, new_adaptation_state), (
+        return (new_states, new_adaptation_state), AdaptationInfo(
             new_states,
             info,
             new_adaptation_state,
         )
@@ -876,7 +889,7 @@ def run(rng_key: PRNGKey, positions: PyTree, num_steps: int = 1000):
         init_adaptation_state = init(positions, init_states.logdensity_grad)
 
         keys = jax.random.split(key_adapt, num_steps)
-        (last_states, last_adaptation_state), _ = jax.lax.scan(
+        (last_states, last_adaptation_state), info = jax.lax.scan(
             one_step, (init_states, init_adaptation_state), keys
         )
 
@@ -895,7 +908,7 @@ def kernel(rng_key, state):
                 **parameters,
             )
 
-        return AdaptationResults(last_states, kernel, parameters)
+        return AdaptationResults(last_states, kernel, parameters), info
 
     return AdaptationAlgorithm(run)  # type: ignore[arg-type]
@@ -1345,7 +1358,7 @@ def one_step(carry, rng_key):
         )
         return (
             (new_state, new_adaptation_state),
-            (new_state, info, new_adaptation_state.ss_state),
+            AdaptationInfo(new_state, info, new_adaptation_state),
         )
 
     def run(rng_key: PRNGKey, position: PyTree, num_steps: int = 400):
@@ -1366,7 +1379,7 @@ def run(rng_key: PRNGKey, position: PyTree, num_steps: int = 400):
         init_state = algorithm.init(init_position, logdensity_fn)
 
         keys = jax.random.split(rng_key, num_steps)
-        last_state, warmup_chain = jax.lax.scan(
+        last_state, info = jax.lax.scan(
             one_step,
             (init_state, init_warmup_state),
             keys,
@@ -1383,7 +1396,7 @@ def run(rng_key: PRNGKey, position: PyTree, num_steps: int = 400):
 
         def kernel(rng_key, state):
             return step_fn(rng_key, state, logdensity_fn, **parameters)
 
-        return AdaptationResults(last_chain_state, kernel, parameters)
+        return AdaptationResults(last_chain_state, kernel, parameters), info
 
     return AdaptationAlgorithm(run)
diff --git a/docs/examples/HierarchicalBNN.md b/docs/examples/HierarchicalBNN.md
index 554b6fc4f..42eb5926b 100644
--- a/docs/examples/HierarchicalBNN.md
+++ b/docs/examples/HierarchicalBNN.md
@@ -181,7 +181,7 @@ def fit_and_eval(
 
     # warm up
     adapt = blackjax.window_adaptation(blackjax.nuts, logprob)
-    final_state, kernel, _ = adapt.run(warmup_key, initial_position, num_warmup)
+    (final_state, kernel, _), _ = adapt.run(warmup_key, initial_position, num_warmup)
 
     # inference
     states = inference_loop(inference_key, kernel, final_state, num_samples)
diff --git a/docs/examples/change_of_variable_hmc.md b/docs/examples/change_of_variable_hmc.md
index bf97f2000..18aeb8bc7 100644
--- a/docs/examples/change_of_variable_hmc.md
+++ b/docs/examples/change_of_variable_hmc.md
@@ -301,7 +301,7 @@ init_params = jax.vmap(init_param_fn)(keys)
 
 @jax.vmap
 def call_warmup(seed, param):
-    initial_states, _, tuned_params = warmup.run(seed, param, 1000)
+    (initial_states, _, tuned_params), _ = warmup.run(seed, param, 1000)
     return initial_states, tuned_params
 
 initial_states, tuned_params = jax.jit(call_warmup)(keys, init_params)
@@ -468,7 +468,7 @@ init_params = jax.vmap(init_param_fn)(keys)
 
 @jax.vmap
 def call_warmup(seed, param):
-    initial_states, _, tuned_params = warmup.run(seed, param, 1000)
+    (initial_states, _, tuned_params), _ = warmup.run(seed, param, 1000)
     return initial_states, tuned_params
 
 initial_states, tuned_params = call_warmup(keys, init_params)
@@ -565,7 +565,7 @@ keys = jax.random.split(warmup_key, n_chains)
 
 @jax.vmap
 def call_warmup(seed, param):
-    initial_states, _, tuned_params = warmup.run(seed, param, 1000)
+    (initial_states, _, tuned_params), _ = warmup.run(seed, param, 1000)
     return initial_states, tuned_params
 
 initial_states, tuned_params = call_warmup(keys, init_params)
diff --git a/docs/examples/howto_use_aesara.md b/docs/examples/howto_use_aesara.md
index c25085da6..489b199e0 100644
--- a/docs/examples/howto_use_aesara.md
+++ b/docs/examples/howto_use_aesara.md
@@ -166,7 +166,9 @@ n_adapt = 3000
 n_samples = 1000
 
 adapt = blackjax.window_adaptation(blackjax.nuts, logdensity_fn)
-state, kernel, _ = adapt.run(rng_key, init_position, n_adapt)
+results, _ = adapt.run(rng_key, init_position, n_adapt)
+state = results.state
+kernel = results.kernel
 
 states, infos = inference_loop(
     rng_key, kernel, state, n_samples
diff --git a/docs/examples/howto_use_numpyro.md b/docs/examples/howto_use_numpyro.md
index 94e550f97..3aa8d96c0 100644
--- a/docs/examples/howto_use_numpyro.md
+++ b/docs/examples/howto_use_numpyro.md
@@ -94,7 +94,9 @@ num_warmup = 2000
 adapt = blackjax.window_adaptation(
     blackjax.nuts, logdensity_fn, target_acceptance_rate=0.8
 )
-last_state, kernel, _ = adapt.run(rng_key, initial_position, num_warmup)
+results, _ = adapt.run(rng_key, initial_position, num_warmup)
+last_state = results.state
+kernel = results.kernel
 ```
 
 Let us now perform inference with the tuned kernel:
diff --git a/docs/examples/howto_use_oryx.md b/docs/examples/howto_use_oryx.md
index 581b3b0b3..80b937832 100644
--- a/docs/examples/howto_use_oryx.md
+++ b/docs/examples/howto_use_oryx.md
@@ -139,7 +139,9 @@ import blackjax
 
 rng_key = jax.random.PRNGKey(0)
 adapt = blackjax.window_adaptation(blackjax.nuts, logdensity_fn)
-last_state, kernel, _ = adapt.run(rng_key, initial_weights, 100)
+results, _ = adapt.run(rng_key, initial_weights, 100)
+last_state = results.state
+kernel = results.kernel
 ```
 
 and sample from the model's posterior distribution:
diff --git a/docs/examples/howto_use_pymc.md b/docs/examples/howto_use_pymc.md
index 379b4b8b2..001c0fe94 100644
--- a/docs/examples/howto_use_pymc.md
+++ b/docs/examples/howto_use_pymc.md
@@ -70,7 +70,9 @@ init_position = [init_position_dict[rv] for rv in rvs]
 
 rng_key = jax.random.PRNGKey(1234)
 adapt = blackjax.window_adaptation(blackjax.nuts, logdensity_fn)
-last_state, kernel, _ = adapt.run(rng_key, init_position, 1000)
+results, _ = adapt.run(rng_key, init_position, 1000)
+last_state = results.state
+kernel = results.kernel
 ```
 
 Let us now perform inference with the tuned kernel:
diff --git a/docs/examples/howto_use_tfp.md b/docs/examples/howto_use_tfp.md
index 90c8b06ec..e457ab1c9 100644
--- a/docs/examples/howto_use_tfp.md
+++ b/docs/examples/howto_use_tfp.md
@@ -120,7 +120,7 @@ adapt = blackjax.window_adaptation(
     blackjax.hmc, logdensity_fn, num_integration_steps=3
 )
 
-last_state, kernel, _ = adapt.run(rng_key, initial_position, 1000)
+(last_state, kernel, _), _ = adapt.run(rng_key, initial_position, 1000)
 ```
 
 We can now perform inference with the tuned kernel:
diff --git a/tests/test_benchmarks.py b/tests/test_benchmarks.py
index 888145b4e..8833d1a4b 100644
--- a/tests/test_benchmarks.py
+++ b/tests/test_benchmarks.py
@@ -54,7 +54,9 @@ def run_regression(algorithm, **parameters):
         is_mass_matrix_diagonal=False,
         **parameters,
     )
-    state, kernel, _ = warmup.run(warmup_key, {"log_scale": 0.0, "coefs": 2.0}, 1000)
+    (state, kernel, _), _ = warmup.run(
+        warmup_key, {"log_scale": 0.0, "coefs": 2.0}, 1000
+    )
 
     states = inference_loop(kernel, 10_000, inference_key, state)
 
diff --git a/tests/test_compilation.py b/tests/test_compilation.py
index c80816b57..eb507133e 100644
--- a/tests/test_compilation.py
+++ b/tests/test_compilation.py
@@ -91,7 +91,7 @@ def logdensity_fn(x):
         target_acceptance_rate=0.8,
         num_integration_steps=10,
     )
-    state, kernel, _ = warmup.run(rng_key, 1.0, num_steps=100)
+    (state, kernel, _), _ = warmup.run(rng_key, 1.0, num_steps=100)
     kernel = jax.jit(kernel)
 
     for _ in range(10):
@@ -118,7 +118,7 @@ def logdensity_fn(x):
         logdensity_fn=logdensity_fn,
         target_acceptance_rate=0.8,
     )
-    state, kernel, _ = warmup.run(rng_key, 1.0, num_steps=100)
+    (state, kernel, _), _ = warmup.run(rng_key, 1.0, num_steps=100)
     step = jax.jit(kernel)
 
     for _ in range(10):
diff --git a/tests/test_sampling.py b/tests/test_sampling.py
index 17b203588..9ae5910b1 100644
--- a/tests/test_sampling.py
+++ b/tests/test_sampling.py
@@ -99,7 +99,7 @@ def test_window_adaptation(self, case, is_mass_matrix_diagonal):
             progress_bar=True,
             **case["parameters"],
         )
-        state, kernel, _ = warmup.run(
+        (state, kernel, _), _ = warmup.run(
             warmup_key,
             case["initial_position"],
             case["num_warmup_steps"],
@@ -164,7 +164,7 @@ def test_pathfinder_adaptation(
             logposterior_fn,
             **parameters,
         )
-        state, kernel, _ = warmup.run(
+        (state, kernel, _), _ = warmup.run(
             warmup_key,
             initial_position,
             num_warmup_steps,
@@ -200,7 +200,7 @@ def test_meads(self):
         log_scales = 1.0 + jax.random.normal(scale_key, (num_chains,))
         coefs = 4.0 + jax.random.normal(coefs_key, (num_chains,))
         initial_positions = {"log_scale": log_scales, "coefs": coefs}
-        last_states, kernel, _ = warmup.run(
+        (last_states, kernel, _), _ = warmup.run(
             warmup_key,
             initial_positions,
             num_steps=1000,
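For reference, the sketch below illustrates the calling convention this patch establishes: every adaptation `run` now returns a pair `(AdaptationResults, AdaptationInfo)` instead of a bare `(state, kernel, parameters)` tuple. The target log-density and initial position here are stand-ins chosen for illustration, not part of the patch:

```python
import jax
import jax.numpy as jnp
import blackjax

# Stand-in target: an isotropic Gaussian log-density.
def logdensity_fn(x):
    return -0.5 * jnp.sum(jnp.square(x))

rng_key = jax.random.PRNGKey(0)
adapt = blackjax.window_adaptation(blackjax.nuts, logdensity_fn)

# `run` now returns (AdaptationResults, AdaptationInfo); callers that only
# need the tuned kernel can discard the second element.
results, info = adapt.run(rng_key, jnp.zeros(2), num_steps=1000)

last_state = results.state       # final chain state after warmup
kernel = results.kernel          # step function with adapted parameters baked in
parameters = results.parameters  # adapted parameters, e.g. step size, mass matrix

# AdaptationInfo exposes the warmup trace: because it is the per-step output of
# jax.lax.scan, info.state, info.info and info.adaptation_state are stacked
# along a leading axis of length num_steps.
```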