From this paper. All the necessary ingredients are here and in Optax. This would give us a deep-learning-oriented example while demonstrating the flexibility of the kernel design.
Since this is a composite kernel, we should aim to make the design similar to other blackjax kernels. It should also be usable with any Optax optimiser and any SgMCMC algorithm:
import jax
from typing import NamedTuple, Tuple

class CyclicalSgMCMCState(NamedTuple):
    position: PyTree
    sgmcmc_state: SgMCMCState
    opt_state: OptaxState

class ScheduleState(NamedTuple):
    do_sample: bool
    is_sequence_start: bool
    step_size: float

# This kernel can be built by passing a SgMCMC kernel,
# an Optax kernel and a schedule by closure.
def cyclical_sgmcmc_kernel(
    rng_key, schedule_state: ScheduleState, state: CyclicalSgMCMCState, minibatch
) -> Tuple[CyclicalSgMCMCState, Info]:
    def sgmcmc_init_fn(state):
        raise NotImplementedError

    # Re-initialize the sampler state at the start of a sampling stage.
    state = jax.lax.cond(
        schedule_state.do_sample & schedule_state.is_sequence_start,
        sgmcmc_init_fn,
        lambda x: x,
        state,
    )

    def opt_init_fn(state):
        raise NotImplementedError

    # Re-initialize the optimizer state at the start of an exploration stage.
    state = jax.lax.cond(
        (~schedule_state.do_sample) & schedule_state.is_sequence_start,
        opt_init_fn,
        lambda x: x,
        state,
    )

    def sgmcmc_update_fn(rng_key, state, minibatch, step_size):
        raise NotImplementedError

    def opt_update_fn(_, state, minibatch, step_size):
        raise NotImplementedError

    # Take a sampling step or an optimization step depending on the schedule.
    new_state = jax.lax.cond(
        schedule_state.do_sample,
        sgmcmc_update_fn,
        opt_update_fn,
        rng_key, state, minibatch, schedule_state.step_size,
    )

    # info =
    # Must be something that's common to both SgMCMC and Optax
    # or callers will not be able to use `jax.lax.scan`.
    return new_state, (schedule_state.do_sample, schedule_state.step_size, info)
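For completeness, here is a rough sketch of how the sequence of ScheduleState values could be precomputed from the cyclical cosine schedule in the paper (step size alpha_k = alpha_0/2 * [cos(pi * mod(k, ceil(K/M)) / ceil(K/M)) + 1], with the first fraction of each cycle used for exploration/optimization and the rest for sampling). The helper name `build_schedule` and its arguments are illustrative only, not an existing blackjax or Optax API:

import jax.numpy as jnp

def build_schedule(num_steps, num_cycles=4, initial_step_size=1e-3, exploration_ratio=0.25):
    """Hypothetical helper: one ScheduleState entry per step, stored as arrays."""
    cycle_length = num_steps // num_cycles
    step = jnp.arange(num_steps)
    # Position within the current cycle, in [0, 1).
    cycle_fraction = (step % cycle_length) / cycle_length
    # Cosine step-size decay within each cycle.
    step_size = 0.5 * initial_step_size * (jnp.cos(jnp.pi * cycle_fraction) + 1.0)
    # Optimize during the first part of each cycle, sample during the rest.
    do_sample = cycle_fraction >= exploration_ratio
    # Flag the first step of each exploration or sampling stage so the kernel
    # knows when to re-initialize the corresponding state.
    is_sequence_start = jnp.concatenate(
        [jnp.array([True]), do_sample[1:] != do_sample[:-1]]
    )
    return ScheduleState(do_sample, is_sequence_start, step_size)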
We can additionally add the ScheduleState to the CyclicalSgMCMCState to make the algorithm more self-contained.
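With the schedule kept as a per-step input (the version sketched above), a caller could then drive the composite kernel with `jax.lax.scan` roughly as follows. This is only a usage sketch; `build_schedule`, `init_state` and `minibatches` are placeholders and the kernel itself still has to be filled in:

def one_step(state, xs):
    rng_key, schedule_state, minibatch = xs
    new_state, info = cyclical_sgmcmc_kernel(rng_key, schedule_state, state, minibatch)
    return new_state, info

num_steps = 10_000
keys = jax.random.split(jax.random.PRNGKey(0), num_steps)
schedule = build_schedule(num_steps)  # ScheduleState of arrays, one entry per step
# `minibatches` is assumed to be a pytree of stacked arrays with a leading
# `num_steps` axis; `init_state` is an already-initialized CyclicalSgMCMCState.
final_state, infos = jax.lax.scan(one_step, init_state, (keys, schedule, minibatches))

Because both branches of `jax.lax.cond` (and every iteration of `scan`) must return the same pytree structure, the simplest choice for `info` is probably a small set of scalars, e.g. the `do_sample` flag and the step size, as in the return statement above.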
There is an implementation in torch here: https://github.com/ruqizhang/csgmcmc