
Resolve 'Functions to run kernels' #598

Merged (9 commits) on Dec 6, 2023
Conversation

@PaulScemama (Contributor) commented Dec 1, 2023

Thank you for opening a PR!

Closes #591. Implements a run_inference_algorithm wrapper function in blackjax/util.py for convenience. See discussion in #591 for more details.

A few important guidelines and requirements before we can merge your PR:

  • We should be able to understand what the PR does from its title only;
  • There is a high-level description of the changes;
  • There are links to all the relevant issues, discussions and PRs;
  • The branch is rebased on the latest main commit;
  • Commit messages follow these guidelines;
  • The code respects the current naming conventions;
  • Docstrings follow the numpy style guide
  • pre-commit is installed and configured on your machine, and you ran it before opening the PR;
  • There are tests covering the changes;
  • The doc is up-to-date;
  • If a new sampler is added, related examples are added/updated

Consider opening a Draft PR if your work is still in progress but you would like some feedback from other contributors.

@junpenglao (Member) left a comment:

Could you revert the change in the quick start example (it is good to demonstrate how the basic implementation works), and apply the function in the tests where similar helper inference-loop functions are used?

@PaulScemama (Contributor, Author)

@junpenglao @albcab I ran into an issue while going through the tests, converting the inference_loops to the new run_inference_algorithm. This pertains to the discussion in the issue here.

This is what is happening: there is a univariate normal test for samplers. This particular case uses elliptical_slice and passes in a dummy logdensity_fn: lambda _: 1.0. Now consider this self-contained code:

import jax
import jax.numpy as jnp

import blackjax

algo = blackjax.elliptical_slice(lambda _: 1.0, **{"cov": jnp.array([2.0**2]), "mean": 1.0})
rng_key = jax.random.PRNGKey(123)

initial_position = 5.0
initial_state = blackjax.elliptical_slice.init(initial_position, lambda _: 1.0)
print(f"Initial State before try|except block {initial_state}")

# THIS IS BASICALLY WHAT `run_inference_algorithm` is doing. 
try:
    initial_state = blackjax.elliptical_slice.init(initial_state, lambda _: 1.0)
except TypeError:
    pass

print(f"Initial State after try|except block: {initial_state}")

# Output:
# Initial State before try|except block EllipSliceState(position=5.0, logdensity=1.0)
# Initial State after try|except block: EllipSliceState(position=EllipSliceState(position=5.0, logdensity=1.0), logdensity=1.0)

The comment indicates that run_inference_algorithm uses a try|except block to decide whether the initial_position_or_state argument is an initial position or an initial state. If it is an initial position, the try block succeeds. If it is an initial state, the try block fails with a TypeError, because of the assumption that the initial state object is structurally different from the initial position object (which is a good assumption: every init algorithm returns a custom NamedTuple State class, while a position is a PyTree).

The problem is that the init for elliptical_slice applies logdensity_fn to the position argument. But in the test, the dummy logdensity_fn ignores its input and returns 1.0 regardless: it is lambda _: 1.0. This results in the unexpected and unwanted behavior that we can call init on any argument, any number of times, and it will never error out.

Sorry if this is a bit disorganized 😅 but basically I want to know what you both @junpenglao @albcab would like to do about this. Should the test not include the dummy logdensity fn? This would be the quickest fix I think. Or do we need to rethink the run_inference_algorithm?

If the latter, I still think this could be a useful option:

from typing import Tuple, Union

import jax

# (Type aliases such as PRNGKey, ArrayLikeTree, State, and Info would come
# from blackjax's types/base modules.)
def run_inference_algorithm(
    rng_key: PRNGKey,
    inference_algorithm: Union[SamplingAlgorithm, VIAlgorithm],
    num_steps: int,
    initial_position: ArrayLikeTree = None,
    initial_state: ArrayLikeTree = None,
) -> Tuple[State, State, Info]:
    """
    Wrapper to run an inference algorithm.

    Parameters
    ----------
    rng_key : PRNGKey
        The random state used by JAX's random numbers generator.
    inference_algorithm : Union[SamplingAlgorithm, VIAlgorithm]
        One of blackjax's sampling algorithms or variational inference algorithms.
    num_steps : int
        Number of learning steps.
    initial_position : ArrayLikeTree, optional
        The initial position to initialize the state of an inference algorithm,
        by default None. Note that either `initial_position` or `initial_state`
        must be passed in, but not both.
    initial_state: ArrayLikeTree, optional
        The initial state of the inference algorithm, by default None.
        Note that either `initial_position` or `initial_state` must be passed in,
        but not both.

    Returns
    -------
    Tuple[State, State, Info]
        1. The final state of the inference algorithm.
        2. The history of states of the inference algorithm.
        3. The history of the info of the inference algorithm.
    """
    if (initial_position is None) == (initial_state is None):
        raise ValueError(
            "Either `initial_position` or `initial_state` must be specified, but not both."
        )
    if initial_state is None:
        initial_state = inference_algorithm.init(initial_position)

    keys = jax.random.split(rng_key, num_steps)

    @jax.jit
    def one_step(state, rng_key):
        state, info = inference_algorithm.step(rng_key, state)
        return state, (state, info)

    final_state, (state_history, info_history) = jax.lax.scan(
        one_step, initial_state, keys
    )
    return final_state, state_history, info_history
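To sanity-check the proposed loop, here is a self-contained sketch (not part of the PR) that drives the same scan-based loop with a toy random-walk "algorithm". ToyState, ToyAlgorithm, toy_init, and toy_step are invented stand-ins for blackjax's State NamedTuple and the (init, step) interface of SamplingAlgorithm:

```python
from typing import Callable, NamedTuple

import jax
import jax.numpy as jnp


class ToyState(NamedTuple):  # stand-in for a blackjax State NamedTuple
    position: jnp.ndarray


class ToyAlgorithm(NamedTuple):  # stand-in for SamplingAlgorithm
    init: Callable
    step: Callable


def toy_init(position):
    return ToyState(jnp.asarray(position))


def toy_step(rng_key, state):
    # Gaussian random-walk move; return the proposal noise as the "info".
    noise = jax.random.normal(rng_key)
    return ToyState(state.position + noise), noise


algo = ToyAlgorithm(toy_init, toy_step)
keys = jax.random.split(jax.random.PRNGKey(0), 100)


def one_step(state, rng_key):
    state, info = algo.step(rng_key, state)
    return state, (state, info)


# Same pattern as the proposed run_inference_algorithm body.
final_state, (state_history, info_history) = jax.lax.scan(
    one_step, algo.init(0.0), keys
)
print(state_history.position.shape)  # (100,)
```

The scan carries the current state and stacks per-step (state, info) pairs, so the history arrays gain a leading axis of length num_steps.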

@junpenglao you mentioned that having initial_position in the second arg makes vmapping easier, but if we wanted to vmap over initial positions couldn't we just:

jax.vmap(
    lambda rng_key, initial_position: run_inference_algorithm(
        rng_key, inference_algorithm, num_steps, initial_position=initial_position
    )
)(rng_keys, initial_positions)

I am also a beginner so I could be wrong; I am trying to learn so apologies in advance! Thanks for all the help :)
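For illustration, here is a minimal, self-contained sketch (not from the thread) of that vmap pattern, with a hypothetical run function standing in for run_inference_algorithm:

```python
import jax
import jax.numpy as jnp


# Hypothetical stand-in for run_inference_algorithm: fold num_steps
# Gaussian increments into the initial position, return the final position.
def run(rng_key, num_steps, initial_position):
    keys = jax.random.split(rng_key, num_steps)

    def one_step(position, key):
        return position + jax.random.normal(key), None

    final_position, _ = jax.lax.scan(one_step, initial_position, keys)
    return final_position


rng_keys = jax.random.split(jax.random.PRNGKey(0), 4)
initial_positions = jnp.arange(4.0)

# vmap over (rng_key, initial_position) pairs, closing over num_steps.
finals = jax.vmap(lambda key, pos: run(key, 10, pos))(rng_keys, initial_positions)
print(finals.shape)  # (4,)
```

Closing over the static arguments in a lambda keeps vmap's in_axes simple: only the key and position are batched.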

@junpenglao (Member)

We should modify the test instead, try: logdensity_fn: lambda x: jnp.ones_like(x)

@PaulScemama PaulScemama changed the title Draft: Resolve 'Functions to run kernels' Resolve 'Functions to run kernels' Dec 3, 2023
@PaulScemama (Contributor, Author)

@junpenglao I think it is ready for review

@junpenglao (Member)

Could you also update this one:

_, infos = jax.lax.scan(one_step, last_states, keys)

@junpenglao (Member)

Nice work, thank you!

codecov bot commented Dec 6, 2023

Codecov Report

All modified and coverable lines are covered by tests ✅

Comparison: base (0a84b22) 99.16% vs. head (f0ca07e) 99.16%.
Report is 1 commit behind head on main.

Additional details and impacted files
@@           Coverage Diff           @@
##             main     #598   +/-   ##
=======================================
  Coverage   99.16%   99.16%           
=======================================
  Files          54       54           
  Lines        2513     2527   +14     
=======================================
+ Hits         2492     2506   +14     
  Misses         21       21           


@junpenglao junpenglao merged commit 435bffd into blackjax-devs:main Dec 6, 2023
7 checks passed
@junpenglao (Member) commented Dec 6, 2023

Thank you for your contribution @PaulScemama! And congrats on your first PR to Blackjax :-)

Could you also open a PR to change the inference loop usage in sampling-book?

@PaulScemama (Contributor, Author)

@junpenglao thank you! :) Thanks for all the guidance. The library is wonderful and I'm excited to continue contributing. I will open a PR to change the inference loop usage in the sampling book.

junpenglao pushed a commit that referenced this pull request Mar 12, 2024
* Add  function, modify  to account for change

* Revert back quickstart.md

* Add run_inference wrapper to tests/mcmc/sampling; Get rid of arg types for run_inference wrapper

* Get rid of unused imports

* Add run_inference wrapper to tests/benchmark

* Change import style

* Import style for benchmarks test; add wrapper to adaptation test

* Replace 'kernel' variable name with 'inference algorithm' when using wrapper run_inference_algorithm

---------

Co-authored-by: Paul Scemama <[email protected]>