Skip to content

Commit

Permalink
Merge branch 'main' into refactor-mala-mgrad
Browse files Browse the repository at this point in the history
  • Loading branch information
albcab authored Apr 14, 2023
2 parents 6ba04f9 + 85a67e6 commit cadafb2
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 14 deletions.
21 changes: 12 additions & 9 deletions blackjax/kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,7 +535,6 @@ class sgld:
.. code::
sgld = blackjax.sgld(grad_fn)
state = sgld.init(position)
Assuming we have an iterator `batches` that yields batches of data we can
perform one step:
Expand All @@ -544,14 +543,14 @@ class sgld:
step_size = 1e-3
minibatch = next(batches)
new_state = sgld.step(rng_key, state, minibatch, step_size)
new_position = sgld.step(rng_key, position, minibatch, step_size)
Kernels are not jit-compiled by default so you will need to do it manually:
.. code::
step = jax.jit(sgld.step)
new_state, info = step(rng_key, state, minibatch, step_size)
new_position, info = step(rng_key, position, minibatch, step_size)
Parameters
----------
Expand Down Expand Up @@ -611,7 +610,6 @@ class sghmc:
.. code::
sghmc = blackjax.sghmc(grad_estimator, num_integration_steps)
state = sghmc.init(position)
Assuming we have an iterator `batches` that yields batches of data we can
perform one step:
Expand All @@ -620,14 +618,14 @@ class sghmc:
step_size = 1e-3
minibatch = next(batches)
new_state = sghmc.step(rng_key, state, minibatch, step_size)
new_position = sghmc.step(rng_key, position, minibatch, step_size)
Kernels are not jit-compiled by default so you will need to do it manually:
.. code::
step = jax.jit(sghmc.step)
new_state, info = step(rng_key, state, minibatch, step_size)
new_position, info = step(rng_key, position, minibatch, step_size)
Parameters
----------
Expand Down Expand Up @@ -668,9 +666,12 @@ class csgld:
Parameters
----------
logdensity_estimator_fn
logdensity_estimator
A function that returns an estimation of the model's logdensity given
a position and a batch of data.
gradient_estimator
A function that takes a position and a batch of data and returns an estimate
of the gradient of the log-density at this position.
zeta
Hyperparameter that controls the geometric property of the flattened
density. If `zeta=0` the function reduces to the SGLD step function.
Expand Down Expand Up @@ -700,7 +701,8 @@ class csgld:

def __new__( # type: ignore[misc]
cls,
logdensity_estimator_fn: Callable,
logdensity_estimator: Callable,
gradient_estimator: Callable,
zeta: float = 1,
temperature: float = 0.01,
num_partitions: int = 512,
Expand All @@ -722,7 +724,8 @@ def step_fn(
return step(
rng_key,
state,
logdensity_estimator_fn,
logdensity_estimator,
gradient_estimator,
minibatch,
step_size_diff,
step_size_stoch,
Expand Down
12 changes: 8 additions & 4 deletions blackjax/sgmcmc/csgld.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,8 @@ def kernel(num_partitions=512, energy_gap=10, min_energy=0) -> Callable:
def one_step(
rng_key: PRNGKey,
state: ContourSGLDState,
logdensity_estimator_fn: Callable,
logdensity_estimator: Callable,
gradient_estimator: Callable,
minibatch: PyTree,
step_size_diff: float, # step size for Langevin diffusion
step_size_stoch: float = 1e-3, # step size for stochastic approximation
Expand Down Expand Up @@ -95,9 +96,12 @@ def one_step(
State of the pseudo-random number generator.
state
Current state of the CSGLD sampler
logdensity_estimator_fn
logdensity_estimator
Function that returns an estimation of the value of the density
function at the current position.
gradient_estimator
A function that takes a position and a batch of data and returns an estimate
of the gradient of the log-density at this position.
minibatch
Minibatch of data.
step_size_diff
Expand All @@ -123,7 +127,7 @@ def one_step(
/ energy_gap
)

logprob_grad = jax.grad(logdensity_estimator_fn)(position, minibatch)
logprob_grad = gradient_estimator(position, minibatch)
position = integrator(
rng_key,
position,
Expand All @@ -133,7 +137,7 @@ def one_step(
)

# Update the stochastic approximation to the energy histogram
neg_logprob = -logdensity_estimator_fn(position, minibatch)
neg_logprob = -logdensity_estimator(position, minibatch)
idx = jax.lax.min(
jax.lax.max(
jax.lax.floor((neg_logprob - min_energy) / energy_gap + 1).astype(
Expand Down
5 changes: 4 additions & 1 deletion tests/mcmc/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,10 @@ def test_linear_regression_contour_sgld(self):
logdensity_fn = blackjax.sgmcmc.logdensity_estimator(
self.logprior_fn, self.loglikelihood_fn, data_size
)
csgld = blackjax.csgld(logdensity_fn)
grad_fn = blackjax.sgmcmc.grad_estimator(
self.logprior_fn, self.loglikelihood_fn, data_size
)
csgld = blackjax.csgld(logdensity_fn, grad_fn)

_, rng_key = jax.random.split(rng_key)
data_batch = X_data[:100, :]
Expand Down

0 comments on commit cadafb2

Please sign in to comment.