
Return adaptation extra information #466

Merged · @rlouf merged 3 commits into blackjax-devs:main from return-warmup-info on Jan 17, 2023

Conversation

@rlouf (Member) commented on Jan 16, 2023

We currently only return the last state, the values of the parameters, and the adapted kernel. However, the full chain and the intermediate adaptation states can be useful when debugging inference.

In this PR we make `window_adaptation`, `meads_adaptation` and `pathfinder_adaptation` return this extra information. Closes #401.

The adaptation results and info are returned in the following `NamedTuple`s:

from typing import Callable, NamedTuple

from blackjax.types import PyTree


class AdaptationResults(NamedTuple):
    state: PyTree
    kernel: Callable
    parameters: dict


class AdaptationInfo(NamedTuple):
    state: NamedTuple
    info: NamedTuple
    adaptation_state: NamedTuple

so the adaptation is run, e.g., as:

warmup = blackjax.window_adaptation(blackjax.nuts, logdensity_fn)
(state, kernel, parameters), info = warmup.run(rng_key, initial_position, 1000)

# OR
results, info = warmup.run(rng_key, initial_position, 1000)
state = results.state
kernel = results.kernel

IMO this feels clunky and a bit too high-level for Blackjax, mainly because we return a kernel. I'm wondering if we should deprecate this way of doing things and instead favor the more explicit:

warmup = blackjax.window_adaptation(blackjax.nuts, logdensity_fn)
(state, parameters), info = warmup.run(rng_key, initial_position, 1000)
nuts = blackjax.nuts(logdensity_fn, **parameters)
hmc = blackjax.hmc(logdensity_fn, num_integration_steps=10, **parameters)  # we can also do this
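
For illustration, here is a minimal end-to-end sketch of what sampling could look like under this proposed signature. The toy log-density and the `inference_loop` helper are illustrative only, not part of Blackjax:

import jax
import jax.numpy as jnp
import blackjax

# Toy standard-normal target; stands in for the user's log-density.
logdensity_fn = lambda x: -0.5 * jnp.sum(x**2)
warmup_key, sample_key = jax.random.split(jax.random.PRNGKey(0))
initial_position = jnp.zeros(3)

warmup = blackjax.window_adaptation(blackjax.nuts, logdensity_fn)
(state, parameters), info = warmup.run(warmup_key, initial_position, 1000)
nuts = blackjax.nuts(logdensity_fn, **parameters)

# Hypothetical helper (not part of Blackjax): run the tuned kernel with lax.scan.
def inference_loop(rng_key, step_fn, initial_state, num_samples):
    def one_step(state, key):
        state, info = step_fn(key, state)
        return state, (state, info)

    keys = jax.random.split(rng_key, num_samples)
    return jax.lax.scan(one_step, initial_state, keys)

last_state, (states, infos) = inference_loop(sample_key, nuts.step, state, 1000)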

@rlouf added the labels refactoring (Change that adds no functionality but improves code quality) and adaptation (Issue related to the adaptation) on Jan 16, 2023
@rlouf requested a review from @junpenglao on January 16, 2023
@rlouf force-pushed the return-warmup-info branch from efa0d99 to 130e322 on January 16, 2023
@junpenglao (Member)

+1 to deprecating returning a kernel. We had a similar discussion some time back, because returning a kernel makes it difficult to vmap the warmup result (which motivated the change in #276).

I think it should be reasonable to do this instead:

warmup = blackjax.window_adaptation(blackjax.nuts, logdensity_fn)
(state, parameters), info = warmup.run(rng_key, initial_position, 1000)
kernel = warmup._kernel_fn(**parameters)

# The below is not guaranteed to work (it only works if the kernels are within the same family, like hmc/nuts)
nuts = blackjax.nuts(logdensity_fn, **parameters)
hmc = blackjax.hmc(logdensity_fn, num_integration_steps=10, **parameters)

WDYT?

@codecov (bot) commented Jan 16, 2023

Codecov Report

Merging #466 (5f615bc) into main (a4ee853) will decrease coverage by 0.01%.
The diff coverage is 100.00%.

@@            Coverage Diff             @@
##             main     #466      +/-   ##
==========================================
- Coverage   99.16%   99.16%   -0.01%     
==========================================
  Files          48       48              
  Lines        1923     1918       -5     
==========================================
- Hits         1907     1902       -5     
  Misses         16       16              
Impacted Files Coverage Δ
blackjax/__init__.py 100.00% <ø> (ø)
blackjax/adaptation/meads_adaptation.py 100.00% <ø> (ø)
blackjax/adaptation/__init__.py 100.00% <100.00%> (ø)
blackjax/kernels.py 99.18% <100.00%> (-0.02%) ⬇️


@junpenglao (Member)

Basically, the abstraction is (with mutation happening at different steps):

# API-level MCMC kernel
next_key, state = kernel.one_step(rng_key, last_state, **param)
# API-level warmup kernel
next_param = tuning.one_step(param, state)

# User-level sampling
tuned_param = blackjax.some_warmup_routine(...)
kernel_one_step = partial(kernel.one_step, **tuned_param)  # functools.partial binds the tuned parameters
for i in range(total_samples):
    rng_key, state = kernel_one_step(rng_key, state)

Comment on lines 765 to +767
(new_state, new_adaptation_state),
(new_state, info, new_adaptation_state),
AdaptationInfo(new_state, info, new_adaptation_state),
@junpenglao (Member)
Here (and below) we are returning 2 tuple/namedtuple with similar structure - should we just return AdaptationInfo instead and do unpacking below?

@rlouf (Member, Author)
What are you suggesting here exactly? I'm not sure I understand.

@junpenglao (Member)
I am suggesting to return

        return AdaptationInfo(new_state, info, new_adaptation_state)

@rlouf (Member, Author)

I like to only carry over the values that are necessary in the scan loop. This often leads to repetition, but it hasn't bothered me so far. What bothers you here?

@junpenglao (Member)
Oh right, that's for use in scan; yeah, then it is almost unavoidable.
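
As an aside, here is a self-contained toy sketch of the scan pattern discussed in this thread (illustrative only, not the actual Blackjax code): the carry only holds what the next iteration needs, while the stacked per-step output is the AdaptationInfo that is returned for debugging.

import jax
import jax.numpy as jnp
from typing import NamedTuple

class AdaptationInfo(NamedTuple):
    state: jnp.ndarray
    info: jnp.ndarray
    adaptation_state: jnp.ndarray

def one_step(carry, rng_key):
    state, adaptation_state = carry
    new_state = state + jax.random.normal(rng_key)  # stand-in for the kernel step
    info = jnp.abs(new_state - state)  # stand-in for the kernel's per-step info
    new_adaptation_state = 0.9 * adaptation_state + 0.1 * info  # stand-in for the parameter update
    # The carry only keeps what the next iteration needs, while the stacked
    # per-step output records the whole trajectory for debugging.
    return (new_state, new_adaptation_state), AdaptationInfo(new_state, info, new_adaptation_state)

keys = jax.random.split(jax.random.PRNGKey(0), 100)
(last_state, last_adaptation_state), infos = jax.lax.scan(
    one_step, (jnp.array(0.0), jnp.array(1.0)), keys
)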

@rlouf (Member, Author) commented on Jan 16, 2023

> +1 to deprecating returning a kernel. We had a similar discussion some time back, because returning a kernel makes it difficult to vmap the warmup result (which motivated the change in #276).
>
> I think it should be reasonable to do this instead:
>
> kernel = warmup._kernel_fn(**parameters)

We might as well call `blackjax.nuts` directly here; `warmup._kernel_fn` is not doing any work.

@junpenglao previously approved these changes on Jan 16, 2023
@rlouf force-pushed the return-warmup-info branch 8 times, most recently from e76f9e3 to 5f615bc on January 17, 2023
rlouf added 3 commits January 17, 2023 14:43
We currently only return the last state, the values of the parameters and
the adapted kernel. However, the full chain and intermediate adaptation
states can be useful when debugging inference.

In addition, adaptation currently returns a `kernel` where the
parameters have already been specified. This is, however, a bit too
high-level for Blackjax and can make vmap-ing adaptation difficult.

Finally, MEADS is currently only implemented as an adaptation scheme for
GHMC, we change its name to reflect this.

In this PR we make `window_adaptation`, `meads_adaptation` and
`pathfinder_adaptation` return extra information, and do not return the
kernel directly anymore.
@rlouf merged commit c0f9687 into blackjax-devs:main on Jan 17, 2023
@rlouf deleted the return-warmup-info branch on January 17, 2023
Labels: adaptation (Issue related to the adaptation), refactoring (Change that adds no functionality but improves code quality)
Linked issue this pull request may close: Return sampling info with the window adaptation
2 participants