Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

NumPyro and Other Samplers Tutorial #1842

Merged
merged 9 commits into from
Aug 9, 2024

Conversation

juanitorduz
Copy link
Contributor

@juanitorduz juanitorduz commented Aug 3, 2024

Related to #1485, see #1485 (comment)

Tutorial showing how to use numpyro with other libraries via initialize_model.

Port example notebook from https://juanitorduz.github.io/numpyro_pathfinder/

Copy link

Check out this pull request on  ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.


Powered by ReviewNB

@fehiepsi
Copy link
Member

fehiepsi commented Aug 4, 2024

Hi @juanitorduz, could we make the tutorial not specific to pathfinder? Maybe add an additional section for other library like flowMC https://github.com/kazewong/flowMC/blob/main/example/dualmoon.py

I'm not sure those libraries support api like init update. If they do, it would be nice to show how to construct custom MCMC kernel on the top of them - this way we can reuse api for parallel, vectorized,... that numpyro.infer.MCMC offers.

@juanitorduz
Copy link
Contributor Author

Maybe add an additional section for other library like flowMC https://github.com/kazewong/flowMC/blob/main/example/dualmoon.py

Ok! I will give it a try!

I'm not sure those libraries support api like init update. If they do, it would be nice to show how to construct custom MCMC kernel on the top of them - this way we can reuse api for parallel, vectorized,... that numpyro.infer.MCMC offers.

Ok! I need to investigate! Would this be part of the same notebook or a different one?

@juanitorduz
Copy link
Contributor Author

ok! I think I made it work with flowMC in 49f12c6 (probably not the most elegant way so feedback is welcome). I will clean the notebook and add text next.

@juanitorduz
Copy link
Contributor Author

ok! @fehiepsi, I think this iteration is ready for review. I suggest including the MCMC kernels in a different PR if that is fine :)

@juanitorduz juanitorduz changed the title Pathfinder example notebook NumPyro and Other Samplers Tutorial Aug 5, 2024
@juanitorduz
Copy link
Contributor Author

Shall we also mention a work about https://jax-ml.github.io/bayeux/ ?

@fehiepsi
Copy link
Member

fehiepsi commented Aug 5, 2024

Yeah, mentioning bayeux would be great. It is a really nice library.

@fehiepsi
Copy link
Member

fehiepsi commented Aug 5, 2024

Ok! I need to investigate! Would this be part of the same notebook or a different one?

I think adding a new section in the same notebook is better: https://github.com/blackjax-devs/blackjax/blob/441412a09e39f514189be84813f812d95709365c/blackjax/mcmc/nuts.py#L170

@juanitorduz
Copy link
Contributor Author

@fehiepsi What do you think about merging this tutorial as it is and adding the MCMC kernel details later, as I need to get used to it and figure out some details :) ?

@juanitorduz
Copy link
Contributor Author

So, based on the class signature and examples, I got to the following class definition:

from functools import partial

from numpyro.infer import init_to_uniform
from numpyro.infer.mcmc import MCMCKernel


class PathfinderKernel(MCMCKernel):
    def __init__(
        self,
        model,
        init_strategy=init_to_uniform,
    ):
        self._model = model
        self._init_strategy = init_strategy
        # Set on first call to init
        self._potential_fn = None
        self._postprocess_fn = None

    def _init_state(self, rng_key, model_args, model_kwargs):
        param_info, potential_fn, postprocess_fn, *_ = initialize_model(
            rng_key,
            self._model,
            model_args=model_args,
            model_kwargs=model_kwargs,
            init_strategy=self._init_strategy,
            dynamic_args=True,
        )
        self._potential_fn = potential_fn
        self._postprocess_fn = postprocess_fn
        return param_info.z

    def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs):
        return self._init_state(rng_key, model_args, model_kwargs)

    def sample(self, state, model_args, model_kwargs):
        rng_key = random.PRNGKey(0)

        def _logdensity_fn(position, model_args, model_kwargs):
            func = self._potential_fn(*model_args, **model_kwargs)
            return -func(position)

        logdensity_fn = partial(
            _logdensity_fn, model_args=model_args, model_kwargs=model_kwargs
        )

        rng_key, rng_subkey = random.split(rng_key)
        pathfinder_state, _ = blackjax.vi.pathfinder.approximate(
            rng_key=rng_subkey,
            logdensity_fn=logdensity_fn,
            initial_position=state,
            num_samples=1,
            ftol=1e-4,
        )

        rng_key, rng_subkey = random.split(rng_key)
        posterior_samples_pathfinder, _ = blackjax.vi.pathfinder.sample(
            rng_key=rng_subkey,
            state=pathfinder_state,
            num_samples=1,
        )

        return posterior_samples_pathfinder

pathfinder_kernel = PathfinderKernel(model)

state = pathfinder_kernel.init(rng_key, 10_000, initial_position, (x, y), {})
pathfinder_kernel.sample(state, (x, y), {})

>> {'a': Array([1.03468165], dtype=float64),
 'b': Array([0.68710244], dtype=float64),
 'sigma': Array([-0.66870928], dtype=float64)}

It basically implements init and sample. How does it look? How can we then use this to run full inference (via a scan loop?)

@fehiepsi
Copy link
Member

fehiepsi commented Aug 7, 2024

Yup, let's do it. I dont think we need to use pathfinder. Any other third party api that offers init and update apis is already good for a demonstration.

Copy link

View / edit / reply to this conversation on ReviewNB

fehiepsi commented on 2024-08-07T21:13:59Z
----------------------------------------------------------------

It is surprised to me that we didn't expose initialize_model in the doc...


docs/source/index.rst Outdated Show resolved Hide resolved
@juanitorduz juanitorduz requested a review from fehiepsi August 8, 2024 08:45
@fehiepsi fehiepsi merged commit b6e4629 into pyro-ppl:master Aug 9, 2024
4 checks passed
@juanitorduz juanitorduz deleted the pathfinder_example branch August 9, 2024 21:43
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants