diff --git a/blackjax/kernels.py b/blackjax/kernels.py
index ac55abec2..9057a8f71 100644
--- a/blackjax/kernels.py
+++ b/blackjax/kernels.py
@@ -668,9 +668,12 @@ class csgld:
 
     Parameters
     ----------
-    logdensity_estimator_fn
+    logdensity_estimator
         A function that returns an estimation of the model's logdensity given
         a position and a batch of data.
+    gradient_estimator
+        A function that takes a position, a batch of data and returns an estimation
+        of the gradient of the log-density at this position.
     zeta
         Hyperparameter that controls the geometric property of the flattened
         density. If `zeta=0` the function reduces to the SGLD step function.
@@ -700,7 +703,8 @@ class csgld:
 
     def __new__(  # type: ignore[misc]
         cls,
-        logdensity_estimator_fn: Callable,
+        logdensity_estimator: Callable,
+        gradient_estimator: Callable,
         zeta: float = 1,
         temperature: float = 0.01,
         num_partitions: int = 512,
@@ -722,7 +726,8 @@ def step_fn(
             return step(
                 rng_key,
                 state,
-                logdensity_estimator_fn,
+                logdensity_estimator,
+                gradient_estimator,
                 minibatch,
                 step_size_diff,
                 step_size_stoch,
diff --git a/blackjax/sgmcmc/csgld.py b/blackjax/sgmcmc/csgld.py
index c53bc6548..06a427492 100644
--- a/blackjax/sgmcmc/csgld.py
+++ b/blackjax/sgmcmc/csgld.py
@@ -64,7 +64,8 @@ def kernel(num_partitions=512, energy_gap=10, min_energy=0) -> Callable:
     def one_step(
         rng_key: PRNGKey,
         state: ContourSGLDState,
-        logdensity_estimator_fn: Callable,
+        logdensity_estimator: Callable,
+        gradient_estimator: Callable,
         minibatch: PyTree,
         step_size_diff: float,  # step size for Langevin diffusion
         step_size_stoch: float = 1e-3,  # step size for stochastic approximation
@@ -95,9 +96,12 @@ def one_step(
             State of the pseudo-random number generator.
         state
             Current state of the CSGLD sampler
-        logdensity_estimator_fn
+        logdensity_estimator
             Function that returns an estimation of the value of the density
             function at the current position.
+        gradient_estimator
+            A function that takes a position, a batch of data and returns an estimation
+            of the gradient of the log-density at this position.
         minibatch
             Minibatch of data.
         step_size_diff
@@ -123,7 +127,7 @@ def one_step(
             / energy_gap
         )
 
-        logprob_grad = jax.grad(logdensity_estimator_fn)(position, minibatch)
+        logprob_grad = gradient_estimator(position, minibatch)
         position = integrator(
             rng_key,
             position,
@@ -133,7 +137,7 @@ def one_step(
         )
 
         # Update the stochastic approximation to the energy histogram
-        neg_logprob = -logdensity_estimator_fn(position, minibatch)
+        neg_logprob = -logdensity_estimator(position, minibatch)
         idx = jax.lax.min(
             jax.lax.max(
                 jax.lax.floor((neg_logprob - min_energy) / energy_gap + 1).astype(
diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py
index a1d4f51bf..81c8ecf81 100644
--- a/tests/mcmc/test_sampling.py
+++ b/tests/mcmc/test_sampling.py
@@ -249,7 +249,10 @@ def test_linear_regression_contour_sgld(self):
         logdensity_fn = blackjax.sgmcmc.logdensity_estimator(
             self.logprior_fn, self.loglikelihood_fn, data_size
         )
-        csgld = blackjax.csgld(logdensity_fn)
+        grad_fn = blackjax.sgmcmc.grad_estimator(
+            self.logprior_fn, self.loglikelihood_fn, data_size
+        )
+        csgld = blackjax.csgld(logdensity_fn, grad_fn)
         _, rng_key = jax.random.split(rng_key)
         data_batch = X_data[:100, :]
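For context, a minimal end-to-end sketch of the updated csgld API. This is not part of the patch: the model functions, step sizes, and initial position are illustrative placeholders, and the `csgld.init`/`csgld.step` call signatures are inferred from the kernel changes above. `blackjax.sgmcmc.logdensity_estimator` and `blackjax.sgmcmc.grad_estimator` are used exactly as in the updated test.

import jax
import jax.numpy as jnp
import blackjax

rng_key = jax.random.PRNGKey(0)
data = jax.random.normal(rng_key, (1000, 1))
data_size = data.shape[0]

def logprior_fn(position):
    return -0.5 * jnp.sum(position**2)

def loglikelihood_fn(position, x):
    # Per-datapoint log-likelihood; the estimators average it over the minibatch.
    return jnp.sum(jax.scipy.stats.norm.logpdf(x, loc=position, scale=1.0))

logdensity = blackjax.sgmcmc.logdensity_estimator(
    logprior_fn, loglikelihood_fn, data_size
)
grad_fn = blackjax.sgmcmc.grad_estimator(
    logprior_fn, loglikelihood_fn, data_size
)

# After this change the gradient estimator is passed explicitly, instead of
# being derived inside the kernel via jax.grad(logdensity_estimator).
csgld = blackjax.csgld(logdensity, grad_fn)

state = csgld.init(jnp.zeros((1,)))
_, step_key = jax.random.split(rng_key)
minibatch = data[:100]
# step(rng_key, state, minibatch, step_size_diff, step_size_stoch)
state = jax.jit(csgld.step)(step_key, state, minibatch, 1e-3, 1e-2)

Passing the gradient estimator separately lets callers substitute estimators that are not the exact gradient of the logdensity estimate (e.g. control-variate gradients), which `jax.grad(logdensity_estimator_fn)` inside the kernel could not express.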