diff --git a/bayeux/_src/mcmc/blackjax.py b/bayeux/_src/mcmc/blackjax.py index 9faa219..add24fe 100644 --- a/bayeux/_src/mcmc/blackjax.py +++ b/bayeux/_src/mcmc/blackjax.py @@ -202,13 +202,16 @@ def _blackjax_inference( # return from `run_inference_algorithm` changes from # `_, states, infos` to `_, (states, infos)`. This one weird # trick handles both cases. - _, *states_and_infos = blackjax.util.run_inference_algorithm( + ret = blackjax.util.run_inference_algorithm( rng_key=seed, inference_algorithm=inference_algorithm, num_steps=num_draws, progress_bar=False, **{_INFERENCE_KWARG: adapt_state}) - return states_and_infos + if len(ret) == 2: # For newer blackjax versions (1.2.4+) + return ret[1] + else: # Delete this once blackjax 1.2.4 is stable + return ret[1:] def _blackjax_inference_loop(