From 40d167287ec41aefbfda0118e1dc87d35ae8cb1f Mon Sep 17 00:00:00 2001 From: Colin Carroll Date: Thu, 16 May 2024 12:20:24 -0700 Subject: [PATCH] Fix bayeux after blackjax update. PiperOrigin-RevId: 634490737 --- bayeux/_src/mcmc/blackjax.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/bayeux/_src/mcmc/blackjax.py b/bayeux/_src/mcmc/blackjax.py index 4b22e4d..62dd4bf 100644 --- a/bayeux/_src/mcmc/blackjax.py +++ b/bayeux/_src/mcmc/blackjax.py @@ -38,6 +38,13 @@ } +def _convert_algorithm(algorithm): + # Remove this after blackjax is stable + if hasattr(algorithm, "differentiable"): + return algorithm.differentiable + return algorithm + + def get_extra_kwargs(kwargs): defaults = { "chain_method": "vectorized", @@ -64,8 +71,8 @@ def get_kwargs(self, **kwargs): adapt_fn, algorithm, constrained_log_density, extra_parameters | kwargs) return {adapt_fn: adaptation_kwargs, "adapt.run": run_kwargs, - algorithm: get_algorithm_kwargs( - algorithm, constrained_log_density, kwargs), + _convert_algorithm(algorithm): get_algorithm_kwargs( + _convert_algorithm(algorithm), constrained_log_density, kwargs), "extra_parameters": extra_parameters} def __call__(self, seed, **kwargs): @@ -171,7 +178,7 @@ def _blackjax_inference( (states, infos), adaptation_parameters """ - algorithm_kwargs = kwargs[algorithm] | adapt_parameters + algorithm_kwargs = kwargs[_convert_algorithm(algorithm)] | adapt_parameters inference_algorithm = algorithm(**algorithm_kwargs) _, states, infos = blackjax.util.run_inference_algorithm( rng_key=seed, @@ -257,8 +264,8 @@ def get_adaptation_kwargs(adaptation_algorithm, algorithm, log_density, kwargs): adaptation_required.remove("algorithm") adaptation_kwargs["algorithm"] = algorithm adaptation_kwargs = ( - get_algorithm_kwargs(algorithm, log_density, kwargs) | adaptation_kwargs - ) + get_algorithm_kwargs(_convert_algorithm(algorithm), log_density, kwargs) + | adaptation_kwargs) adaptation_required = adaptation_required - adaptation_kwargs.keys()