Resolve 'Functions to run kernels' #598
Conversation
Could you revert the change in the quick start example (it is good to demonstrate how the basic implementation works), and apply the function in the tests where similar helper inference loop functions are used?
@junpenglao @albcab I ran into an issue while going through the tests converting them to use the wrapper. This is what is happening: here is a univariate normal test for samplers. This particular case uses a dummy logdensity function:

```python
import jax
import jax.numpy as jnp

import blackjax

algo = blackjax.elliptical_slice(lambda _: 1.0, **{"cov": jnp.array([2.0**2]), "mean": 1.0})
rng_key = jax.random.PRNGKey(123)
initial_position = 5.0
initial_state = blackjax.elliptical_slice.init(initial_position, lambda _: 1.0)
print(f"Initial State before try|except block {initial_state}")

# THIS IS BASICALLY WHAT `run_inference_algorithm` is doing.
try:
    initial_state = blackjax.elliptical_slice.init(initial_state, lambda _: 1.0)
except TypeError:
    pass
print(f"Initial State after try|except block: {initial_state}")
```

which prints:

```
Initial State before try|except block EllipSliceState(position=5.0, logdensity=1.0)
Initial State after try|except block: EllipSliceState(position=EllipSliceState(position=5.0, logdensity=1.0), logdensity=1.0)
```

The comment indicates that the `try`/`except` block is essentially what `run_inference_algorithm` does internally. The problem is that the second call to `init` does not raise a `TypeError` when it receives an already-initialized state: it silently wraps it in another `EllipSliceState` instead.

Sorry if this is a bit disorganized 😅 but basically I want to know what you both @junpenglao @albcab would like to do about this. Should the test not include the dummy logdensity fn? This would be the quickest fix I think. Or do we need to rethink the wrapper's initialization logic?

If the latter, I still think this could be a useful option:

```python
def run_inference_algorithm(
    rng_key: PRNGKey,
    inference_algorithm: Union[SamplingAlgorithm, VIAlgorithm],
    num_steps: int,
    initial_position: ArrayLikeTree = None,
    initial_state: ArrayLikeTree = None,
) -> Tuple[State, State, Info]:
    """Wrapper to run an inference algorithm.

    Parameters
    ----------
    rng_key : PRNGKey
        The random state used by JAX's random number generator.
    inference_algorithm : Union[SamplingAlgorithm, VIAlgorithm]
        One of blackjax's sampling algorithms or variational inference algorithms.
    num_steps : int
        Number of learning steps.
    initial_position : ArrayLikeTree, optional
        The initial position used to initialize the state of the inference
        algorithm, by default None. Note that either `initial_position` or
        `initial_state` must be passed in, but not both.
    initial_state : ArrayLikeTree, optional
        The initial state of the inference algorithm, by default None.
        Note that either `initial_position` or `initial_state` must be
        passed in, but not both.

    Returns
    -------
    Tuple[State, State, Info]
        1. The final state of the inference algorithm.
        2. The history of states of the inference algorithm.
        3. The history of the info of the inference algorithm.
    """
    if (initial_position is None) == (initial_state is None):
        raise ValueError(
            "Either `initial_position` or `initial_state` must be specified, but not both."
        )
    if initial_state is None:
        initial_state = inference_algorithm.init(initial_position)

    keys = jax.random.split(rng_key, num_steps)

    @jax.jit
    def one_step(state, rng_key):
        state, info = inference_algorithm.step(rng_key, state)
        return state, (state, info)

    final_state, (state_history, info_history) = jax.lax.scan(
        one_step, initial_state, keys
    )
    return final_state, state_history, info_history
```

@junpenglao you mentioned that having `initial_position` as an explicit argument would make it easy to run multiple chains, e.g.:

```python
jax.vmap(
    lambda rng_key, initial_position: run_inference_algorithm(
        rng_key, inference_algorithm, num_steps, initial_position=initial_position
    )
)(rng_keys, initial_positions)
```

I am also a beginner so I could be wrong; I am trying to learn so apologies in advance! Thanks for all the help :)
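The double-wrapping can be reproduced without blackjax. This is a toy sketch (the `SliceState` tuple and `init` below are stand-ins I made up, not the real blackjax API) of why exception-based dispatch fails when `init` happily accepts any position-like input:

```python
from typing import Any, NamedTuple


class SliceState(NamedTuple):
    """Toy stand-in for EllipSliceState."""
    position: Any
    logdensity: float


def init(position, logdensity_fn):
    # Like elliptical_slice.init: accepts *any* position, never raises TypeError.
    return SliceState(position, logdensity_fn(position))


state = init(5.0, lambda _: 1.0)

# Dispatch by exception, mirroring the try/except pattern discussed above:
try:
    state = init(state, lambda _: 1.0)  # does NOT raise...
except TypeError:
    pass

# ...so the state is silently wrapped a second time.
print(isinstance(state.position, SliceState))  # → True
```

Because `init` is duck-typed, the `TypeError` branch never fires, and the corruption only shows up later when the sampler reads a nested state.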
We should modify the test instead, try:
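The suggested snippet did not survive extraction. As a hedged illustration of the direction only (have the caller pass a raw position so `init` is called exactly once, rather than guessing via `try`/`except`), here is a toy stand-in rather than real blackjax code; `SliceState`, `init`, and `run` are all hypothetical names:

```python
from typing import Any, NamedTuple


class SliceState(NamedTuple):
    """Toy stand-in for an algorithm state."""
    position: Any
    logdensity: float


def init(position, logdensity_fn):
    return SliceState(position, logdensity_fn(position))


def run(initial_position=None, initial_state=None):
    # Mutually exclusive arguments: exactly one must be provided.
    if (initial_position is None) == (initial_state is None):
        raise ValueError(
            "Either `initial_position` or `initial_state` must be specified, but not both."
        )
    if initial_state is None:
        # init is called at most once, so no double-wrapping can occur.
        initial_state = init(initial_position, lambda _: 1.0)
    return initial_state


state = run(initial_position=5.0)
print(state)  # → SliceState(position=5.0, logdensity=1.0)
```

Making the caller state its intent explicitly replaces the fragile exception-based dispatch with a simple, checkable contract.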
@junpenglao I think it is ready for review.
Could you also update this one:
Nice work, thank you!
Codecov Report: all modified and coverable lines are covered by tests ✅

Additional details and impacted files:

```
@@           Coverage Diff           @@
##             main     #598   +/-  ##
=======================================
  Coverage   99.16%   99.16%
=======================================
  Files          54       54
  Lines        2513     2527    +14
=======================================
+ Hits         2492     2506    +14
  Misses         21       21
```

☔ View full report in Codecov by Sentry.
Thank you for your contribution @PaulScemama! And congrats on your first PR to Blackjax :-) Could you also open a PR to change the inference loop usage in sampling-book?
@junpenglao thank you! :) Thanks for all the guidance. The library is wonderful and I'm excited to continue contributing. I will open a PR to change the inference loop usage in the sampling book.
* Add function, modify to account for change
* Revert back quickstart.md
* Add run_inference wrapper to tests/mcmc/sampling; get rid of arg types for run_inference wrapper
* Get rid of unused imports
* Add run_inference wrapper to tests/benchmark
* Change import style
* Import style for benchmarks test; add wrapper to adaptation test
* Replace 'kernel' variable name with 'inference algorithm' when using wrapper run_inference_algorithm

Co-authored-by: Paul Scemama <[email protected]>
Thank you for opening a PR!

Closes #591. Implements a `run_inference_algorithm` wrapper function in `blackjax/util.py` for convenience. See the discussion in #591 for more details.

A few important guidelines and requirements before we can merge your PR:
- your branch is rebased on `main`;
- your changes are squashed into a single commit;
- `pre-commit` is installed and configured on your machine, and you ran it before opening the PR.

Consider opening a Draft PR if your work is still in progress but you would like some feedback from other contributors.