Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pass grad_estimator to the CSGLD kernel directly #518

Merged
merged 1 commit into from
Apr 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions blackjax/kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -668,9 +668,12 @@ class csgld:

Parameters
----------
logdensity_estimator_fn
logdensity_estimator
A function that returns an estimation of the model's logdensity given
a position and a batch of data.
gradient_estimator
A function that takes a position, a batch of data and returns an estimation
of the gradient of the log-density at this position.
zeta
Hyperparameter that controls the geometric property of the flattened
density. If `zeta=0` the function reduces to the SGLD step function.
Expand Down Expand Up @@ -700,7 +703,8 @@ class csgld:

def __new__( # type: ignore[misc]
cls,
logdensity_estimator_fn: Callable,
logdensity_estimator: Callable,
gradient_estimator: Callable,
zeta: float = 1,
temperature: float = 0.01,
num_partitions: int = 512,
Expand All @@ -722,7 +726,8 @@ def step_fn(
return step(
rng_key,
state,
logdensity_estimator_fn,
logdensity_estimator,
gradient_estimator,
minibatch,
step_size_diff,
step_size_stoch,
Expand Down
12 changes: 8 additions & 4 deletions blackjax/sgmcmc/csgld.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,8 @@ def kernel(num_partitions=512, energy_gap=10, min_energy=0) -> Callable:
def one_step(
rng_key: PRNGKey,
state: ContourSGLDState,
logdensity_estimator_fn: Callable,
logdensity_estimator: Callable,
gradient_estimator: Callable,
minibatch: PyTree,
step_size_diff: float, # step size for Langevin diffusion
step_size_stoch: float = 1e-3, # step size for stochastic approximation
Expand Down Expand Up @@ -95,9 +96,12 @@ def one_step(
State of the pseudo-random number generator.
state
Current state of the CSGLD sampler
logdensity_estimator_fn
logdensity_estimator
Function that returns an estimation of the value of the density
function at the current position.
gradient_estimator
A function that takes a position, a batch of data and returns an estimation
of the gradient of the log-density at this position.
minibatch
Minibatch of data.
step_size_diff
Expand All @@ -123,7 +127,7 @@ def one_step(
/ energy_gap
)

logprob_grad = jax.grad(logdensity_estimator_fn)(position, minibatch)
logprob_grad = gradient_estimator(position, minibatch)
position = integrator(
rng_key,
position,
Expand All @@ -133,7 +137,7 @@ def one_step(
)

# Update the stochastic approximation to the energy histogram
neg_logprob = -logdensity_estimator_fn(position, minibatch)
neg_logprob = -logdensity_estimator(position, minibatch)
idx = jax.lax.min(
jax.lax.max(
jax.lax.floor((neg_logprob - min_energy) / energy_gap + 1).astype(
Expand Down
5 changes: 4 additions & 1 deletion tests/mcmc/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,10 @@ def test_linear_regression_contour_sgld(self):
logdensity_fn = blackjax.sgmcmc.logdensity_estimator(
self.logprior_fn, self.loglikelihood_fn, data_size
)
csgld = blackjax.csgld(logdensity_fn)
grad_fn = blackjax.sgmcmc.grad_estimator(
self.logprior_fn, self.loglikelihood_fn, data_size
)
csgld = blackjax.csgld(logdensity_fn, grad_fn)

_, rng_key = jax.random.split(rng_key)
data_batch = X_data[:100, :]
Expand Down