-
Notifications
You must be signed in to change notification settings - Fork 107
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add SGNHT #515
Add SGNHT #515
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for this PR! Sorry for the delay, busy doing some restructuring.
Everything looks good, docs are great, tests are good, and the general structure of the algorithm looks great.
I agree that the parametrization from Ma et al for SGHMC is best, let's keep the rescaling.
You'll only need to follow the new structure of the API introduced in #501 and rebase to the latest commit. All you really need to do is move what you put in kernels.py
to sgmcmc/snht.py
and use the naming discussed in #280 for your algorithms. Everything else should be rebased to the latest commit history without any conflicts. You probably don't need it, but here is a basic skeleton for the new structure of sampling algorithms.
Since sgld
and sghmc
don't need a state (only the position needs to be carried over to the next iteration) they also don't need an initializer, hence why they don't return a MCMCSamplingAlgorithm
. For now, we'll just keep it like this. Even though returning the momentum could be useful for debugging (if someone complains we'll change it).
I can't believe I was just getting the hang of the BlackJAX structure and then you went and changed it! 😆 Just kidding, big fan of the changes. Hopefully, I have adopted them correctly, let me know what you think. As discussed I have included the rescaling of the parameters for SGHMC. I also updated the docs for SGLC and SGHMC to say they return a "step function" rather than an "MCMCSamplingAlgorithm", but not sure if there is a better terminology. IMO it might be nice for SGLD and SGHMC to return MCMCSamplingAlgorithm (perhaps with a dummy |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just two easy naming changes and we are ready to merge!
IMO it might be nice for SGLD and SGHMC to return MCMCSamplingAlgorithm (perhaps with a dummy
init
function) so that users can switch between samplers seamlessly.
This is an excellent point, completely agree. If you can/want to do the changes, go ahead with another PR (leave the doc updates of this PR), else open an issue and we'll take care of it 👍
Good spot! Should be good to go now Yep I'm happy to do a new PR unifying SGLD and SGHMC to also return an |
Codecov Report
@@ Coverage Diff @@
## main #515 +/- ##
==========================================
+ Coverage 99.28% 99.30% +0.02%
==========================================
Files 47 48 +1
Lines 1947 2021 +74
==========================================
+ Hits 1933 2007 +74
Misses 14 14
|
* add sgnht * reformat * Restructure kernels * Reformat * Clean * Rename step to kernel
Add Stochastic gradient Nosé-Hoover thermostat sampler of Ding et al, discussed in #289.
I have also rescaled the
alpha
,beta
parameters forsghmc
(and the newsgnht
) to align with the description in Ma et al, which IMO has better interpretability (beta
is variance of stochastic gradient). But I understand that this could modify existing code so can be reverted if need be. I also refactoreddiffusions.sghmc
to read more clearly as an Euler solver.I also noticed that the
sgld
andsghmc
kernels return thestep
functions themselves rather than anMCMCSamplingAlgorithm
(and that this is in disagreement with the docs).A few important guidelines and requirements before we can merge your PR:
main
commit;pre-commit
is installed and configured on your machine, and you ran it before opening the PR;