
Implement cyclical SGLD #209

Closed
rlouf opened this issue May 19, 2022 · 2 comments


rlouf commented May 19, 2022

From this paper. All the necessary ingredients are here and in Optax. This would give us a deep-learning-oriented example while demonstrating the flexibility of the kernel design.

There is a PyTorch implementation here: https://github.com/ruqizhang/csgmcmc

@rlouf rlouf added help wanted Extra attention is needed sampler Issue related to samplers labels May 19, 2022
@rlouf rlouf added documentation Improvements or additions to documentation and removed sampler Issue related to samplers labels Jun 24, 2022
@rlouf rlouf self-assigned this Oct 2, 2022
rlouf commented Oct 2, 2022

As a composite kernel, we should aim to make the design similar to other blackjax kernels. It should also be usable with any Optax optimizer and any SgMCMC algorithm:

from typing import NamedTuple, Tuple

import jax


class CyclicalSgMCMCState(NamedTuple):
    position: PyTree
    sgmcmc_state: SgMCMCState
    opt_state: OptaxState

class ScheduleState(NamedTuple):
    do_sample: bool
    is_sequence_start: bool
    step_size: float

# This kernel can be built by passing
# a SgMCMC kernel, an Optax optimizer
# and a schedule by closure
def cyclical_sgmcmc_kernel(
    rng_key, schedule_state: ScheduleState, state: CyclicalSgMCMCState, minibatch
) -> Tuple[CyclicalSgMCMCState, Info]:

    def sgmcmc_init_fn(state):
        raise NotImplementedError

    state = jax.lax.cond(
        schedule_state.do_sample & schedule_state.is_sequence_start,
        sgmcmc_init_fn,
        lambda x: x,
        state,
    )

    def opt_init_fn(state):
        raise NotImplementedError

    state = jax.lax.cond(
        ~schedule_state.do_sample & schedule_state.is_sequence_start,
        opt_init_fn,
        lambda x: x,
        state,
    )

    def sgmcmc_update_fn(rng_key, state, minibatch, step_size):
        raise NotImplementedError

    def opt_update_fn(_, state, minibatch, step_size):
        raise NotImplementedError

    new_state = jax.lax.cond(
        schedule_state.do_sample,
        sgmcmc_update_fn,
        opt_update_fn,
        rng_key, state, minibatch, schedule_state.step_size,
    )

    # info = 
    # Must be something that's common to both
    # SgMCMC and Optax or callers will
    # not be able to use  `jax.lax.scan`

    return new_state, (schedule_state.do_sample, schedule_state.step_size, info)

We can additionally add the `ScheduleState` to the `CyclicalSgMCMCState` to make the algorithm more self-contained.
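For reference, the schedule that would produce `ScheduleState` values could follow the cyclical cosine step-size decay from the cSG-MCMC paper. Below is a minimal sketch; the function name, the `exploration_ratio` split between optimization and sampling phases, and the return layout are illustrative assumptions, not a fixed API:

```python
import jax.numpy as jnp


def cyclical_schedule(step, num_steps, num_cycles, initial_step_size,
                      exploration_ratio=0.25):
    """Cosine step-size schedule, sketched after Zhang et al.'s cSG-MCMC.

    Returns the step size at `step` along with `do_sample` and
    `is_sequence_start` flags matching the `ScheduleState` fields above.
    """
    cycle_length = num_steps // num_cycles
    position_in_cycle = step % cycle_length
    # The cosine term decays from 2 to ~0 over a cycle, so the step size
    # decays from `initial_step_size` towards 0.
    step_size = 0.5 * initial_step_size * (
        jnp.cos(jnp.pi * position_in_cycle / cycle_length) + 1.0
    )
    # Optimize (explore) at the start of each cycle, sample afterwards.
    do_sample = position_in_cycle >= exploration_ratio * cycle_length
    is_sequence_start = position_in_cycle == 0
    return step_size, do_sample, is_sequence_start
```

A `jax.vmap` or a plain loop over `step` would then precompute the whole schedule before passing each `ScheduleState` to the kernel inside `jax.lax.scan`.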

rlouf commented Jan 13, 2023

This was done in the Sampling Book: https://blackjax-devs.github.io/sampling-book/algorithms/cyclical_sgld.html

@rlouf rlouf closed this as completed Jan 13, 2023