Return adaptation extra information #466
Conversation
Force-pushed from efa0d99 to 130e322
+1 to deprecating returning a kernel. We had a similar discussion some time back, because returning a kernel makes it difficult to vmap the warmup result (which motivated the change in #276). I think it should be reasonable to do this instead:

```python
warmup = blackjax.window_adaptation(blackjax.nuts, logdensity_fn)
(state, parameters), info = warmup.run(rng_key, initial_position, 1000)
kernel = warmup._kernel_fn(**parameters)

# The below is not guaranteed to work (it only works if the kernels
# belong to the same family, like hmc/nuts)
nuts = blackjax.nuts(logdensity_fn, **parameters)
hmc = blackjax.hmc(logdensity_fn, num_integration_steps=10, **parameters)
```

WDYT?
Codecov Report

```
@@            Coverage Diff             @@
##             main     #466      +/-   ##
==========================================
- Coverage   99.16%   99.16%   -0.01%
==========================================
  Files          48       48
  Lines        1923     1918       -5
==========================================
- Hits         1907     1902       -5
  Misses         16       16
```
Basically, the abstraction is (with mutation happening at different steps):

```python
import functools

# API-level MCMC kernel
next_key, state = kernel.one_step(rng_key, last_state, **param)

# API-level warmup kernel
next_param = tuning.one_step(param, state)

# User-level sampling
tuned_param = blackjax.some_warmup_routine(...)
kernel_one_step = functools.partial(kernel.one_step, **tuned_param)
for i in range(total_samples):
    rng_key, state = kernel_one_step(rng_key, state)
```
```python
(new_state, new_adaptation_state),
(new_state, info, new_adaptation_state),
AdaptationInfo(new_state, info, new_adaptation_state),
```
Here (and below) we are returning two tuples/namedtuples with a similar structure. Should we just return `AdaptationInfo` instead and do the unpacking below?
What are you suggesting here exactly? I'm not sure I understand.
I am suggesting to return:

```python
return AdaptationInfo(new_state, info, new_adaptation_state)
```
I like to only carry over the values that are necessary in the scan loop. This often leads to repetition, but it hasn't bothered me so far. What bothers you here?
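For context, a minimal sketch of the scan pattern being discussed; `kernel_step`, `adapt_step`, and the initial values are placeholders rather than Blackjax API, and `AdaptationInfo` is the namedtuple from the diff above. The carry holds only what the next iteration needs, while the per-step values are stacked by `scan` into the accumulated output:

```python
import jax

num_steps = 1000
keys = jax.random.split(jax.random.PRNGKey(0), num_steps)

def one_step(carry, rng_key):
    state, adaptation_state = carry
    new_state, info = kernel_step(rng_key, state)  # placeholder MCMC transition
    new_adaptation_state = adapt_step(adaptation_state, new_state, info)  # placeholder update
    # Carry only what the next iteration needs; scan stacks the second
    # output across iterations into a single AdaptationInfo pytree.
    return (new_state, new_adaptation_state), AdaptationInfo(
        new_state, info, new_adaptation_state
    )

(last_state, last_adaptation_state), adaptation_info = jax.lax.scan(
    one_step, (initial_state, initial_adaptation_state), keys
)
```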
Oh right, that's for use in scan. Yeah, then it is almost unavoidable.
> `kernel = warmup._kernel_fn(**parameters)`

We might as well call
Force-pushed from e76f9e3 to 5f615bc
We currently only return the last state, the values of the parameters, and the adapted kernel. However, the full chain and intermediate adaptation states can be useful when debugging inference. In addition, adaptation currently returns a `kernel` whose parameters have already been specified; this is a bit too high-level for Blackjax and can make vmap-ing adaptation difficult. Finally, MEADS is currently only implemented as an adaptation scheme for GHMC, so we change its name to reflect this.

In this PR we make `window_adaptation`, `meads_adaptation` and `pathfinder_adaptation` return this extra information and no longer return the kernel directly. Closes #401.

The adaptation state and info are returned in the following `NamedTuple`s:
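The definitions themselves are missing from this extract; a minimal sketch, with field names inferred from the `AdaptationInfo(new_state, info, new_adaptation_state)` call in the diff above (the actual definitions in the PR may differ):

```python
from typing import NamedTuple

class AdaptationInfo(NamedTuple):
    state: NamedTuple             # chain states traversed during adaptation
    info: NamedTuple              # kernel info at each adaptation step
    adaptation_state: NamedTuple  # intermediate adaptation states
```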
so the adaptation is run e.g. as:
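The example block is also missing here; based on the snippets earlier in the thread, the usage presumably looks something like this (the exact signature may differ in the final version of the PR):

```python
warmup = blackjax.window_adaptation(blackjax.nuts, logdensity_fn)
(state, parameters), info = warmup.run(rng_key, initial_position, 1000)
```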
IMO this feels clunky and a bit too high-level for Blackjax, mainly because we return a kernel. I'm wondering if we should deprecate this way of doing things and instead favor the more explicit:
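The code after the colon did not survive extraction; judging from the review comments above, the more explicit alternative would be along these lines:

```python
# Sketch of the explicit alternative: build the sampling kernel yourself
# from the adapted parameters rather than receiving it pre-built from
# the adaptation routine.
warmup = blackjax.window_adaptation(blackjax.nuts, logdensity_fn)
(state, parameters), info = warmup.run(rng_key, initial_position, 1000)
nuts = blackjax.nuts(logdensity_fn, **parameters)
```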