Add the control variates gradient estimator #299
Conversation
Codecov Report

```diff
@@            Coverage Diff             @@
##             main     #299      +/-   ##
==========================================
- Coverage   89.79%   89.73%   -0.06%
==========================================
  Files          45       45
  Lines        2166     2134      -32
==========================================
- Hits         1945     1915      -30
+ Misses        221      219       -2
```
A few thoughts:

All in all, if we didn't have methods where several gradient evaluations are needed to get one sample, our life would be much easier. But then again, maybe none of that matters, since the complexity only affects the internals.
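To make that concrete, here is a purely illustrative sketch (not Blackjax's actual kernel, arrays only, friction term omitted) of why an SGHMC-style method consumes several gradient evaluations per sample:

```python
import jax
import jax.numpy as jnp

def sghmc_like_step(rng_key, position, momentum, grad_estimator, minibatch,
                    step_size, num_integration_steps=10):
    # Purely illustrative: each inner iteration needs a fresh gradient estimate,
    # so one "sample" costs `num_integration_steps` gradient evaluations.
    for key in jax.random.split(rng_key, num_integration_steps):
        grad = grad_estimator(position, minibatch)
        noise = jnp.sqrt(2.0 * step_size) * jax.random.normal(key, momentum.shape)
        momentum = momentum + step_size * grad + noise
        position = position + step_size * momentum
    return position, momentum
```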
blackjax/sgmcmc/gradients.py (Outdated)

```python
logposterior_grad_estimator_fn = grad_estimator(
    logprior_fn, loglikelihood_fn, data_size
).estimate
```
Suggested change:

```diff
-logposterior_grad_estimator_fn = grad_estimator(
-    logprior_fn, loglikelihood_fn, data_size
-).estimate
+simple_grad_estimator = grad_estimator(
+    logprior_fn, loglikelihood_fn, data_size
+).estimate
+logposterior_grad_estimator_fn = lambda *x: simple_grad_estimator(*x)[0]
```
Not very happy with this suggestion either, so feel free to resolve.
I agree that my code is not very satisfying either.

High-level question: here you also revert the change from #293 where you turned the step size into a callable (so the user can control how the step size changes during sampling). What is the reasoning behind the revert?
That's an important question. The current high-level interface of SgLD is:

```python
sgld = blackjax.sgld(grad_estimator, schedule)
```

Or, for a constant schedule:

```python
sgld = blackjax.sgld(grad_estimator, 1e-3)
```

And then to take a step:

```python
sgld.step(rng_key, state, minibatch)
```

Internally this forces us to increment a counter in the state, which I really dislike. For users this interface very quickly becomes impractical, as I have seen when implementing Cyclical SgLD. Having the schedule baked in is also a common criticism of Optax. I much prefer the interface:

```python
sgld = blackjax.sgld(grad_estimator)
...
state = sgld.step(rng_key, state, minibatch, step_size)
```

My ideal lower-level interface would be:

```python
sgld = blackjax.sgld()
...
step_size = next(schedule)
minibatch = next(dataset)
gradients = grad_estimator(state, minibatch)
state = sgld.step(rng_key, state, gradients, step_size)
```

But I've expressed why that's difficult above.
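For concreteness, a rough sketch of the two options, assuming array positions for simplicity; the names are illustrative and this is not Blackjax's actual code:

```python
import jax
import jax.numpy as jnp
from typing import Callable, NamedTuple

# Option 1: the schedule is baked into the kernel, so the state must carry a step counter.
class StateWithCounter(NamedTuple):
    step: int
    position: jnp.ndarray

def sgld_with_builtin_schedule(grad_estimator: Callable, schedule: Callable):
    def step(rng_key, state: StateWithCounter, minibatch):
        step_size = schedule(state.step)
        grad = grad_estimator(state.position, minibatch)
        noise = jax.random.normal(rng_key, state.position.shape)
        position = state.position + 0.5 * step_size * grad + jnp.sqrt(step_size) * noise
        return StateWithCounter(state.step + 1, position)
    return step

# Option 2: the step size is passed explicitly at each call; no counter in the state.
def sgld_with_explicit_step_size(grad_estimator: Callable):
    def step(rng_key, position, minibatch, step_size):
        grad = grad_estimator(position, minibatch)
        noise = jax.random.normal(rng_key, position.shape)
        return position + 0.5 * step_size * grad + jnp.sqrt(step_size) * noise
    return step
```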
So for low- to mid-level usage, the user might be able to do something like:

```python
cosine_decay_scheduler = optax.cosine_decay_schedule(0.0001, decay_steps=total_steps, alpha=0.95)

for i in ...:  # could be in a jax.scan as well
    step_size = cosine_decay_scheduler(i)
    minibatch = ...
    gradients = grad_estimator(state, minibatch)
    state = sgld.step(rng_key, state, gradients, step_size)
```
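And roughly the same loop written with `jax.lax.scan` (a sketch that assumes `sgld`, `grad_estimator`, `state`, `rng_key`, `total_steps`, and a pre-batched `minibatches` array with leading dimension `total_steps` are defined as above):

```python
import jax
import jax.numpy as jnp
import optax

cosine_decay_scheduler = optax.cosine_decay_schedule(0.0001, decay_steps=total_steps, alpha=0.95)

def one_step(state, inputs):
    rng_key, i, minibatch = inputs
    step_size = cosine_decay_scheduler(i)
    gradients = grad_estimator(state, minibatch)
    new_state = sgld.step(rng_key, state, gradients, step_size)
    return new_state, new_state

keys = jax.random.split(rng_key, total_steps)
steps = jnp.arange(total_steps)
# `minibatches` is assumed to be a pre-batched array with leading dimension `total_steps`
last_state, states = jax.lax.scan(one_step, state, (keys, steps, minibatches))
```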
blackjax/sgmcmc/gradients.py (Outdated)

```python
    control_variate,
    logposterior_grad_estimate,
    grad_estimator_state.control_variate_grad,
    logposterior_grad_center_estimate,
```
Hi,

I've been implementing control variates (and the SVRG algorithm) within BlackJax and I'm happy to see that you've been recently working on them as well. I think that there is a mistake in the tree_map at the end of the new gradient estimator: the full centered gradient is supposed to be given as the third input of your control_variate function, whereas it is given as the second input in the tree_map. Is this correct? Thanks!
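In other words, a sketch only, reusing the names from the quoted diff (the assignment target is just illustrative), the fix would be to swap the last two arguments:

```python
logposterior_grad = jax.tree_util.tree_map(
    control_variate,
    logposterior_grad_estimate,
    logposterior_grad_center_estimate,
    grad_estimator_state.control_variate_grad,  # full centered gradient as the third input
)
```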
Thanks! I'd be curious to see your implementation, especially SVRG!
I would be happy to discuss the SVRG algorithm because I'm not convinced by my implementation! My implementation of the "cv" algorithm is similar to yours. I had trouble implementing the SVRG algorithm (with lax.cond for instance), but in the end I decided to implement it as a loop of n "CV" algorithms of m iterations each. Here is a snippet for the SGLD-SVRG algorithm, where update_freq is the frequency at which the centered gradient is updated:
```python
# Assumes `jax`, `Callable`, `PyTree`, `PRNGKey`, `SamplingAlgorithm`,
# and a local `sgmcmc.sgldcv` module are in scope.
class sgldsvrg:
    init = staticmethod(sgmcmc.sgldcv.init)
    kernel = staticmethod(sgmcmc.sgldcv.kernel)

    def __new__(  # type: ignore[misc]
        cls,
        grad_estimator_fn: Callable,
        schedule_fn: Callable,
        train_dataset,
        batch_loader,
        update_freq: int,
    ) -> SamplingAlgorithm:
        step = cls.kernel(grad_estimator_fn)

        def init_fn(position: PyTree, c_position: PyTree, c_full_loglike_grad: PyTree, data_batch: PyTree):
            return cls.init(position, c_position, c_full_loglike_grad, data_batch, grad_estimator_fn)

        def step_fn(rng_key: PRNGKey, state):
            step_size = schedule_fn(state.step)

            def svrg_kernel_step(state, rng_key):
                batch = next(batch_loader)
                new_state = step(rng_key, state, batch, step_size)
                return new_state, new_state

            keys = jax.random.split(rng_key, update_freq)
            last_svrg_state, svrg_states = jax.lax.scan(svrg_kernel_step, state, keys)
            c_full_logprob_grad = grad_estimator_fn(last_svrg_state.position, train_dataset)
            updated_state = sgmcmc.sgldcv.SGLDCVState(
                last_svrg_state.step,
                last_svrg_state.position,
                last_svrg_state.batch_logprob_grad,
                last_svrg_state.position,
                c_full_logprob_grad,
                last_svrg_state.batch_logprob_grad,
            )
            return updated_state, svrg_states

        return SamplingAlgorithm(init_fn, step_fn)
```
Do you think that this kind of implementation is efficient? I had trouble making this work with a `lax.cond` for updating the centered gradients (like in SGMCMCJAX, if I'm not wrong).
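A `lax.cond`-based update along those lines could look roughly like this (a sketch only, reusing `grad_estimator_fn`, `train_dataset`, `update_freq` and the `SGLDCVState` layout from the snippet above; not tested):

```python
import jax

def maybe_update_center(step, state, grad_estimator_fn, train_dataset, update_freq):
    # Recompute the full-data gradient at the current position every `update_freq` steps,
    # otherwise return the state unchanged (both branches keep the same pytree structure).
    def update(state):
        c_full_grad = grad_estimator_fn(state.position, train_dataset)
        return sgmcmc.sgldcv.SGLDCVState(
            state.step,
            state.position,
            state.batch_logprob_grad,
            state.position,            # new centering position
            c_full_grad,               # new full-data gradient at the center
            state.batch_logprob_grad,
        )

    return jax.lax.cond(step % update_freq == 0, update, lambda s: s, state)
```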
Thanks a lot for this amazing package!
I took some time to think about the general API, and here is what I came up with for SVRG; everything else is a simplified version of this (no update for CV, no update and no init for the simple estimator). What do you think?
```python
import jax
import blackjax

svrg = blackjax.sgmcmc.gradients.svrg_estimator(logprior_fn, loglikelihood_fn)
sghmc = blackjax.sghmc(svrg)

grad_state = svrg.init(centering_position, data)

minibatch = next(data)
step_size = next(schedule)
position, grad_state, info = sghmc.step(  # uses svrg.grad internally
    rng_key,
    position,
    grad_state,
    minibatch,
    step_size
)

# Perform several SGMCMC steps
# And update the control variate
grad_state = svrg.update(grad_state, position, data)

_, rng_key = jax.random.split(rng_key)
minibatch = next(data)
step_size = next(schedule)
position, grad_state, info = sghmc.step(
    rng_key,
    position,
    grad_state,
    minibatch,
    step_size
)
```
```python
# Perform several SGMCMC steps
# And update the control variate
grad_state = svrg.update(grad_state, position, data)
```

Is this going to be in a for loop, or is the number of steps already encoded in the init of svrg?
@junpenglao @bstaber The behavior in a loop was not very clear indeed, so let me give a full example. I modified (hopefully improved) the behavior slightly. For the simple Robbins-Monro estimator we have:

```python
import jax
import blackjax
import blackjax.sgmcmc.gradients as gradients  # need an alias

schedule: Generator[float]
data: PyTree
batches: Generator[jax.numpy.DeviceArray]
position: PyTree

# Get the gradient estimator and SGHMC algorithm
grad_estimator = gradients.simple_estimator(logprior_fn, loglikelihood_fn, num_examples)
sghmc = blackjax.sgmcmc.sghmc()

rng_key = jax.random.PRNGKey(0)
for step in range(num_training_steps):
    _, rng_key = jax.random.split(rng_key)
    minibatch = next(batches)
    step_size = next(schedule)
    position, grad_state, info = sghmc.step(  # uses the gradient estimator internally
        rng_key,
        position,
        grad_estimator,
        minibatch,
        step_size
    )
```

Now for the Control Variates estimator:

```python
import jax
import blackjax
import blackjax.sgmcmc.gradients as gradients  # need an alias

schedule: Generator[float]
data: PyTree
batches: Generator[jax.numpy.DeviceArray]
position: PyTree
centering_position: PyTree

# Get the CV gradient estimator and SGHMC algorithm
cv = gradients.cv(logprior_fn, loglikelihood_fn, num_examples)
sghmc = blackjax.sgmcmc.sghmc()

# Initialize the gradient state
# (SGHMC state is simply the position)
grad_estimator = cv.init(centering_position, data)

rng_key = jax.random.PRNGKey(0)
for step in range(num_training_steps):
    _, rng_key = jax.random.split(rng_key)
    minibatch = next(batches)
    step_size = next(schedule)
    position, grad_state, info = sghmc.step(
        rng_key,
        position,
        grad_estimator,
        minibatch,
        step_size
    )
```

SVRG is a CV estimator with updates. @bstaber's intuition is correct, and we can re-use the same code as for CV; we just need to re-initialize the control variate every `cv_update_rate` steps:

```python
import jax
import blackjax
import blackjax.sgmcmc.gradients as gradients  # need an alias

schedule: Generator[float]
data: PyTree
batches: Generator[jax.numpy.DeviceArray]
position: PyTree
centering_position: PyTree
cv_update_rate: int

# Get the CV gradient estimator and SGHMC algorithm
svrg = gradients.cv(logprior_fn, loglikelihood_fn, num_examples)
sghmc = blackjax.sgmcmc.sghmc()

# Initialize the gradient state
# (SGHMC state is simply the position)
grad_estimator = svrg.init(centering_position, data)

rng_key = jax.random.PRNGKey(0)
for step in range(num_training_steps):
    _, rng_key = jax.random.split(rng_key)
    minibatch = next(batches)
    step_size = next(schedule)
    position, grad_state, info = sghmc.step(
        rng_key,
        position,
        grad_estimator,
        minibatch,
        step_size
    )
    # SVRG is nothing more than CV that you can update
    if step % cv_update_rate == 0:
        grad_estimator = svrg.init(centering_position, data)
```

While it is naively tempting to compute the gradient estimate outside of …

We may go a step further to remove the awkwardness that @junpenglao saw in the code, and make CV effectively a wrapper around the Robbins-Monro estimator:

```python
grad_estimator = gradients.simple_estimator(logprior_fn, loglikelihood_fn, num_examples)
cv_grad_estimator = gradients.cv(grad_estimator, centering_position, data)
```

For SVRG, again, we need to rebuild the estimator every `cv_update_rate` steps:

```python
if step % cv_update_rate == 0:
    cv_grad_estimator = gradients.cv(grad_estimator, position, data)
```
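For illustration, a minimal sketch of what such a wrapper could look like, assuming the simple estimator is called as `grad_estimator(position, minibatch)` and returns the log-posterior gradient estimate as a pytree (names and signature are illustrative, not the final PR API):

```python
import jax

def cv(grad_estimator, centering_position, data):
    # Full-data gradient at the centering position, computed once when the wrapper is built.
    full_grad_center = grad_estimator(centering_position, data)

    def cv_grad_estimator(position, minibatch):
        grad_estimate = grad_estimator(position, minibatch)
        center_grad_estimate = grad_estimator(centering_position, minibatch)
        # Control variates estimate: full gradient at the center
        # plus the difference of the two minibatch estimates.
        return jax.tree_util.tree_map(
            lambda fg, g, cg: fg + g - cg,
            full_grad_center,
            grad_estimate,
            center_grad_estimate,
        )

    return cv_grad_estimator
```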
@junpenglao This is ready for review. I implemented the control variates as a wrapper around the simple estimator, as discussed above. The code is much cleaner than before, and svrg is obtained by calling …
@junpenglao ping
This simplifies the solvers a lot.
This is impractical in practice.
We add the control variates gradient estimator for stochastic gradient MCMC algorithms. Control Variates require one gradient estimation on the whole dataset, which begs two questions that may be answered in subsequent PRs:

- Should we compute `logposterior_center` in a separate `init` function, and thus propagate a `GradientState`?
- Computing the gradient on the whole dataset may call for `pmap`, and we should allow that.

This PR is part of an effort to port SGMCMCJAX to blackjax, see #289.
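For reference, the control variates construction behind this estimator, written in generic notation (not the PR's code): let $\hat{g}(\theta)$ be the simple minibatch estimator of the log-posterior gradient and $\hat{\theta}$ the centering position. Then

$$
\tilde{g}(\theta) = \nabla \log \pi(\hat{\theta}) + \hat{g}(\theta) - \hat{g}(\hat{\theta}),
$$

where $\nabla \log \pi(\hat{\theta})$ is the full-dataset gradient computed once at $\hat{\theta}$. The estimator remains unbiased because $\mathbb{E}[\hat{g}(\theta) - \hat{g}(\hat{\theta})] = \nabla \log \pi(\theta) - \nabla \log \pi(\hat{\theta})$, and its variance is small whenever $\theta$ stays close to $\hat{\theta}$.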