Add the control variates gradient estimator #299

Merged
rlouf merged 9 commits into blackjax-devs:main from cv-gradient-estimator on Nov 20, 2022

Conversation

rlouf
Member

@rlouf rlouf commented Sep 19, 2022

We add the control variates gradient estimator for stochastic gradient MCMC algorithms. Control variates require one gradient estimation on the whole dataset, which raises two questions that may be answered in subsequent PRs:

  1. Shouldn't we compute logposterior_center in a separate init function, and thus propagate a GradientState?
  2. How can we let users distribute this computation as they wish? This operation may need to be distributed with pmap, and we should allow that.

This PR is part of an effort to port SGMCMCJAX to blackjax; see #289.
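
For context, here is a minimal, self-contained sketch of the kind of estimator this PR describes. All names (logprior_fn, loglikelihood_fn, data_size, centering_position, full_data) are illustrative placeholders rather than the merged API, and the merged code may differ in detail; the point is the structure: one full-dataset gradient at a centering position, corrected at every step by the difference of two minibatch gradients.

import jax
import jax.numpy as jnp


def cv_gradient_estimator(
    logprior_fn, loglikelihood_fn, data_size, centering_position, full_data
):
    """Sketch of a control variates estimator of the log-posterior gradient."""

    def logposterior_estimate(position, minibatch):
        # Unbiased minibatch estimate of the log-posterior (up to a constant).
        batch_size = jax.tree_util.tree_leaves(minibatch)[0].shape[0]
        batch_loglik = jax.vmap(loglikelihood_fn, in_axes=(None, 0))(position, minibatch)
        return logprior_fn(position) + (data_size / batch_size) * jnp.sum(batch_loglik)

    # The expensive part: one gradient over the whole dataset, at the centering position.
    logposterior_grad_center = jax.grad(logposterior_estimate)(centering_position, full_data)

    def estimate(position, minibatch):
        grad_position = jax.grad(logposterior_estimate)(position, minibatch)
        grad_center = jax.grad(logposterior_estimate)(centering_position, minibatch)
        # full-data gradient at the center
        #   + (minibatch gradient at the position - minibatch gradient at the center)
        return jax.tree_util.tree_map(
            lambda full_c, g, g_c: full_c + g - g_c,
            logposterior_grad_center,
            grad_position,
            grad_center,
        )

    return estimate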

@rlouf rlouf force-pushed the cv-gradient-estimator branch from 400ecdd to 45375a0 Compare September 19, 2022 11:03
@codecov

codecov bot commented Sep 19, 2022

Codecov Report

Merging #299 (8684689) into main (becd2d2) will decrease coverage by 0.05%.
The diff coverage is 100.00%.

@@            Coverage Diff             @@
##             main     #299      +/-   ##
==========================================
- Coverage   89.79%   89.73%   -0.06%     
==========================================
  Files          45       45              
  Lines        2166     2134      -32     
==========================================
- Hits         1945     1915      -30     
+ Misses        221      219       -2     
Impacted Files Coverage Δ
blackjax/kernels.py 99.55% <100.00%> (+0.78%) ⬆️
blackjax/mcmc/diffusions.py 100.00% <100.00%> (ø)
blackjax/mcmc/mala.py 100.00% <100.00%> (ø)
blackjax/sgmcmc/__init__.py 100.00% <100.00%> (ø)
blackjax/sgmcmc/diffusions.py 100.00% <100.00%> (ø)
blackjax/sgmcmc/gradients.py 100.00% <100.00%> (ø)
blackjax/sgmcmc/sghmc.py 100.00% <100.00%> (ø)
blackjax/sgmcmc/sgld.py 100.00% <100.00%> (ø)


@rlouf rlouf force-pushed the cv-gradient-estimator branch from b3d5ee9 to ad0802d Compare September 24, 2022 22:41
@rlouf rlouf force-pushed the cv-gradient-estimator branch from ad0802d to 5410741 Compare October 7, 2022 12:52
@junpenglao junpenglao self-assigned this Oct 7, 2022
@rlouf rlouf requested a review from junpenglao October 7, 2022 13:58
@rlouf rlouf self-assigned this Oct 7, 2022
@rlouf
Member Author

rlouf commented Oct 7, 2022

A few thoughts:

  • I'm on the fence when it comes to keeping CV and implementing SVRG; not only do they require keeping track of a GradientState, they also need the full dataset at initialization (CV) or at every step (SVRG), which can be prohibitive in some scenarios. I would only consider keeping them in an Optax-like API where gradients are computed outside of the integrators. Optax has a control variates API we can draw inspiration from.
  • We cannot adopt an Optax-like API because some palindromic integrators (BADODAB for instance) require two gradient evaluations per step (see the sketch below). Maybe there's something to learn from Optax's MultiStep interface?
  • Second-order methods like AMAGOLD (#375) may put additional constraints on the API, so we may want to sketch an implementation before moving forward.

All in all, if we didn't have methods where several gradient evaluations are needed to get one sample, our life would be much easier. Then again, maybe none of that matters, since the complexity only affects the internals.
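
To make the second point concrete, here is a hedged sketch (illustrative names only, and a leapfrog-style splitting rather than BADODAB itself) of why a plain Optax-like API, where gradients are computed once outside the step, does not fit: a single palindromic step consumes a gradient both before and after the position update.

def palindromic_step(position, momentum, grad_fn, minibatch, step_size):
    # First gradient evaluation, at the current position.
    momentum = momentum + 0.5 * step_size * grad_fn(position, minibatch)
    position = position + step_size * momentum
    # Second gradient evaluation, at the updated position: it cannot be
    # precomputed outside the step, which is what rules out passing
    # gradients in from the outside the way Optax does.
    momentum = momentum + 0.5 * step_size * grad_fn(position, minibatch)
    return position, momentum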

Comment on lines 123 to 125
logposterior_grad_estimator_fn = grad_estimator(
logprior_fn, loglikelihood_fn, data_size
).estimate
Member

Suggested change
logposterior_grad_estimator_fn = grad_estimator(
logprior_fn, loglikelihood_fn, data_size
).estimate
simple_grad_estimator = grad_estimator(
logprior_fn, loglikelihood_fn, data_size
).estimate
logposterior_grad_estimator_fn = lambda *x: simple_grad_estimator(*x)[0]

Member

Not very happy with this suggestion either, so feel free to resolve.

Member Author

I agree that my code is not very satisfying either.

junpenglao
junpenglao previously approved these changes Oct 8, 2022
@junpenglao
Member

junpenglao commented Oct 8, 2022

High-level question: here you also revert the change from #293 where you turned the step size into a callable (so users can control how the step size changes during sampling). What is the reasoning behind the revert?

@rlouf
Member Author

rlouf commented Oct 8, 2022

> High-level question: here you also revert the change from #293 where you turned the step size into a callable (so users can control how the step size changes during sampling). What is the reasoning behind the revert?

That's an important question. The current high-level interface of SGLD is:

sgld = blackjax.sgld(grad_estimator, schedule)

Or for a constant schedule

sgld = blackjax.sgld(grad_estimator, 1e-3)

And then to take a step

sgld.step(rng_key, state, minibatch)

Internally this forces us to increment a counter in the state, which I really dislike. For users this interface very quickly becomes impractical, as I saw when implementing Cyclical SGLD. Having the schedule baked in is also a common criticism of Optax. I much prefer the interface:

sgld = blackjax.sgld(grad_estimator)
...
state = sgld.step(rng_key, state, minibatch, step_size)

My ideal lower-level interface would be:

sgld = blackjax.sgld()
...
step_size = next(schedule)
minibatch = next(dataset)
gradients = grad_estimator(state, minibatch)
state = sgld.step(rng_key, state, gradients, step_size)

But I've expressed why that's difficult above.

@junpenglao
Member

junpenglao commented Oct 8, 2022

So for low- to mid-level usage, users might be able to do something like:

cosine_decay_scheduler = optax.cosine_decay_schedule(0.0001, decay_steps=total_steps, alpha=0.95)

for i in ...:  # could be in a jax.lax.scan as well
  step_size = cosine_decay_scheduler(i)
  minibatch = ...
  gradients = grad_estimator(state, minibatch)
  state = sgld.step(rng_key, state, gradients, step_size)

control_variate,
logposterior_grad_estimate,
grad_estimator_state.control_variate_grad,
logposterior_grad_center_estimate,
bstaber

Hi,

I've been implementing control variates (and the SVRG algorithm) within BlackJAX and I'm happy to see that you've been recently working on them as well. I think there is a mistake in the tree_map at the end of the new gradient estimator: the full centered gradient is supposed to be given as the third input of your control_variate function, whereas it is given as the second input in the tree_map. Is this correct? Thanks!

Member Author

Thanks! I'd be curious to see your implementation, especially SVRG!

@bstaber bstaber Oct 8, 2022

I would be happy to discuss the SVRG algorithm because I'm not convinced by my implementation! My implementation of the "cv" algorithm is similar to yours. I had trouble implementing the SVRG algorithm (with lax.cond for instance), so in the end I decided to implement it as a loop of n "CV" algorithms of m iterations each. Here is a snippet for the SGLD-SVRG algorithm, where update_freq is the frequency at which the centered gradient is updated:

class sgldsvrg:

    init = staticmethod(sgmcmc.sgldcv.init)
    kernel = staticmethod(sgmcmc.sgldcv.kernel)

    def __new__(  # type: ignore[misc]
        cls,
        grad_estimator_fn: Callable,
        schedule_fn: Callable,
        train_dataset,
        batch_loader,
        update_freq: int,
    ) -> SamplingAlgorithm:

        step = cls.kernel(grad_estimator_fn)

        def init_fn(position: PyTree, c_position: PyTree, c_full_loglike_grad: PyTree, data_batch: PyTree):
            return cls.init(position, c_position, c_full_loglike_grad, data_batch, grad_estimator_fn)

        def step_fn(rng_key: PRNGKey, state):
            step_size = schedule_fn(state.step)
            def svrg_kernel_step(state, rng_key):
                batch = next(batch_loader)
                new_state = step(rng_key, state, batch, step_size)
                return new_state, new_state
            keys = jax.random.split(rng_key, update_freq)
            last_svrg_state, svrg_states = jax.lax.scan(svrg_kernel_step, state, keys)
            c_full_logprob_grad = grad_estimator_fn(last_svrg_state.position, train_dataset)
            updated_state = sgmcmc.sgldcv.SGLDCVState(last_svrg_state.step, 
                                                      last_svrg_state.position, 
                                                      last_svrg_state.batch_logprob_grad, 
                                                      last_svrg_state.position, 
                                                      c_full_logprob_grad, 
                                                      last_svrg_state.batch_logprob_grad)
            return updated_state, svrg_states
            
        return SamplingAlgorithm(init_fn, step_fn)

Do you think that this kind of implementation is efficient? I had trouble making this work with a lax.cond for updating the centered gradients (like in SGMCMCJAX, if I'm not wrong).

Thanks a lot for this amazing package!

Member Author

I took some time to think about the general API, and here is what I came up with for SVRG; everything else is a simplified version of this (no update for CV, no update and no init for the simple estimator). What do you think?

import jax
import blackjax


svrg = blackjax.sgmcmc.gradients.svrg_estimator(logprior_fn, loglikelihood_fn)
sghmc = blackjax.sghmc(svrg)

grad_state = svrg.init(centering_position, data)

minibatch = next(data)
step_size = next(schedule)
position, grad_state, info = sghmc.step(  # uses svrg.grad internally
    rng_key,
    position,
    grad_state,
    minibatch,
    step_size
)

# Perform several SGMCMC steps
# And update the control variate
grad_state = svrg.update(grad_state, position, data)

_, rng_key = jax.random.split(rng_key)
minibatch = next(data)
step_size = next(schedule)
position, grad_state, info = sghmc.step(
    rng_key,
    position,
    grad_state,
    minibatch,
    step_size
)

Member

# Perform several SGMCMC steps
# And update the control variate
grad_state = svrg.update(grad_state, position, data)

Is this going to be in a for loop, or is the number of steps already encoded in the init of svrg?

@rlouf
Member Author

rlouf commented Oct 21, 2022

@junpenglao @bstaber The behavior in a loop was not very clear indeed, so let me give a full example. I modified (hopefully improved) the behavior slightly. For the simple Robbins-Monro estimator we have:

import jax
import blackjax
import blackjax.sgmcmc.gradients as gradients  # need an alias


schedule: Generator[float]
data: PyTree
batches: Generator[jax.numpy.DeviceArray]
position: PyTree

# Get the simple gradient estimator and the SGHMC algorithm
grad_estimator = gradients.simple_estimator(logprior_fn, loglikelihood_fn, num_examples)
sghmc = blackjax.sgmcmc.sghmc()

rng_key = jax.random.PRNGKey(0)
for step in range(num_training_steps):
    _, rng_key = jax.random.split(rng_key)

    minibatch = next(batches)
    step_size = next(schedule)
    position, grad_state, info = sghmc.step(  # uses the gradient estimator internally
        rng_key,
        position,
        grad_estimator,
        minibatch,
        step_size
    )

Now for the Control Variates estimator:

import jax
import blackjax
import blackjax.sgmcmc.gradients as gradients  # need an alias


schedule: Generator[float]
data: PyTree
batches: Generator[jax.numpy.DeviceArray]
position: PyTree
centering_position: PyTree


# Get the CV gradient estimator and SGHMC algorithm
cv = gradients.cv(logprior_fn, loglikelihood_fn, num_examples)
sghmc = blackjax.sgmcmc.sghmc()

# Initialize the gradient state
# (SGHMC state is simply the position)
grad_estimator = cv.init(centering_position, data)

rng_key = jax.random.PRNGKey(0)
for step in range(num_training_steps):
    _, rng_key = jax.random.split(rng_key)

    minibatch = next(batches)
    step_size = next(schedule)
    position, grad_state, info = sghmc.step(
        rng_key,
        position,
        grad_estimator,
        minibatch,
        step_size
    )

SVRG is a CV estimator with updates. @bstaber's intuition is correct, and we can re-use the same code as for CV; we just need to re-initialize the control variate every cv_update_rate steps:

import jax
import blackjax


schedule: Generator[float]
data: PyTree
batches: Generator[jax.numpy.DeviceArray]
position: PyTree
centering_position: PyTree
cv_update_rate: int


# Get the CV gradient estimator and SGHMC algorithm
svrg = gradients.cv(logprior_fn, loglikelihood_fn, num_examples)
sghmc = blackjax.sgmcmc.sghmc()

# Initialize the gradient state
# (SGHMC state is simply the position)
grad_estimator = svrg.init(centering_position, data)

rng_key = jax.random.PRNGKey(0)
for step in range(num_training_steps):
    _, rng_key = jax.random.split(rng_key)

    minibatch = next(batches)
    step_size = next(schedule)
    position, grad_state, info = sghmc.step(
        rng_key,
        position,
        grad_estimator,
        minibatch,
        step_size
    )

    # SVRG is nothing more than CV that you can update
    if step % cv_update_rate == 0:
        grad_estimator = svrg.init(centering_position, data)

While it is naively tempting to compute the gradient estimate outside of sghmc.step, SGHMC (and some other algorithms) needs to compute the gradient several times before returning a sample; the situation is not quite the same as with Optax.

We may go a step further to remove the awkwardness that @junpenglao saw in the code, and make CV effectively a wrapper around the Robbins-Monro estimator:

grad_estimator = gradients.simple_estimator(logprior_fn, loglikelihood_fn, num_examples)
cv_grad_estimator = gradients.cv(grad_estimator, centering_position, data)

For SVRG, again, we need to rebuild the estimator every cv_update_rate steps:

    if step % cv_update_rate == 0:
        cv_grad_estimator = gradients.cv(grad_estimator, position, data)

@rlouf
Member Author

rlouf commented Oct 22, 2022

@junpenglao This is ready for review. I implemented the control variates as a wrapper around the simple estimator, as discussed above. The code is much cleaner than before, and SVRG is obtained by calling gradients.control_variates again on the gradient estimator within the sampling loop. So we can tick this off the list in #289 as well once this PR is merged.
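
To summarize the merged design with a hedged usage sketch (function and argument names follow the discussion above and may not match the final signatures exactly; logprior_fn, loglikelihood_fn, data_size, data, centering_position, position, step, and cv_update_rate are placeholders):

import blackjax.sgmcmc.gradients as gradients

# Simple (Robbins-Monro) estimator of the log-posterior gradient.
grad_estimator = gradients.grad_estimator(logprior_fn, loglikelihood_fn, data_size)

# Control variates wrap the simple estimator around a centering position,
# paying one full-dataset gradient up front.
cv_grad_estimator = gradients.control_variates(grad_estimator, centering_position, data)

# SVRG falls out by rebuilding the wrapper periodically inside the sampling loop:
if step % cv_update_rate == 0:
    cv_grad_estimator = gradients.control_variates(grad_estimator, position, data)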

@rlouf rlouf force-pushed the cv-gradient-estimator branch 4 times, most recently from f5fc4d1 to a766c22 Compare October 26, 2022 20:57
@rlouf rlouf added the refactoring, sgmcmc, and enhancement labels Oct 27, 2022
@rlouf rlouf requested a review from junpenglao October 28, 2022 08:35
@rlouf
Member Author

rlouf commented Nov 9, 2022

@junpenglao ping

@rlouf rlouf mentioned this pull request Nov 10, 2022
@rlouf rlouf force-pushed the cv-gradient-estimator branch from 284adf0 to 9b2b73c Compare November 20, 2022 10:35
@rlouf rlouf force-pushed the cv-gradient-estimator branch from 9b2b73c to d996dbe Compare November 20, 2022 16:49
@rlouf rlouf merged commit e190313 into blackjax-devs:main Nov 20, 2022
@rlouf rlouf deleted the cv-gradient-estimator branch November 20, 2022 19:14