-
Notifications
You must be signed in to change notification settings - Fork 246
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
NumPyro and Other Samplers Tutorial #1842
Conversation
Check out this pull request on See visual diffs & provide feedback on Jupyter Notebooks. Powered by ReviewNB |
Hi @juanitorduz, could we make the tutorial not specific to pathfinder? Maybe add an additional section for other library like flowMC https://github.com/kazewong/flowMC/blob/main/example/dualmoon.py I'm not sure those libraries support api like |
Ok! I will give it a try!
Ok! I need to investigate! Would this be part of the same notebook or a different one? |
ok! I think I made it work with |
ok! @fehiepsi, I think this iteration is ready for review. I suggest including the MCMC kernels in a different PR if that is fine :) |
Shall we also mention a work about https://jax-ml.github.io/bayeux/ ? |
Yeah, mentioning bayeux would be great. It is a really nice library. |
I think adding a new section in the same notebook is better: https://github.com/blackjax-devs/blackjax/blob/441412a09e39f514189be84813f812d95709365c/blackjax/mcmc/nuts.py#L170 |
@fehiepsi What do you think about merging this tutorial as it is and adding the MCMC kernel details later, as I need to get used to it and figure out some details :) ? |
So, based on the class signature and examples, I got to the following class definition: from functools import partial
from numpyro.infer import init_to_uniform
from numpyro.infer.mcmc import MCMCKernel
class PathfinderKernel(MCMCKernel):
def __init__(
self,
model,
init_strategy=init_to_uniform,
):
self._model = model
self._init_strategy = init_strategy
# Set on first call to init
self._potential_fn = None
self._postprocess_fn = None
def _init_state(self, rng_key, model_args, model_kwargs):
param_info, potential_fn, postprocess_fn, *_ = initialize_model(
rng_key,
self._model,
model_args=model_args,
model_kwargs=model_kwargs,
init_strategy=self._init_strategy,
dynamic_args=True,
)
self._potential_fn = potential_fn
self._postprocess_fn = postprocess_fn
return param_info.z
def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs):
return self._init_state(rng_key, model_args, model_kwargs)
def sample(self, state, model_args, model_kwargs):
rng_key = random.PRNGKey(0)
def _logdensity_fn(position, model_args, model_kwargs):
func = self._potential_fn(*model_args, **model_kwargs)
return -func(position)
logdensity_fn = partial(
_logdensity_fn, model_args=model_args, model_kwargs=model_kwargs
)
rng_key, rng_subkey = random.split(rng_key)
pathfinder_state, _ = blackjax.vi.pathfinder.approximate(
rng_key=rng_subkey,
logdensity_fn=logdensity_fn,
initial_position=state,
num_samples=1,
ftol=1e-4,
)
rng_key, rng_subkey = random.split(rng_key)
posterior_samples_pathfinder, _ = blackjax.vi.pathfinder.sample(
rng_key=rng_subkey,
state=pathfinder_state,
num_samples=1,
)
return posterior_samples_pathfinder
pathfinder_kernel = PathfinderKernel(model)
state = pathfinder_kernel.init(rng_key, 10_000, initial_position, (x, y), {})
pathfinder_kernel.sample(state, (x, y), {})
>> {'a': Array([1.03468165], dtype=float64),
'b': Array([0.68710244], dtype=float64),
'sigma': Array([-0.66870928], dtype=float64)} It basically implements |
Yup, let's do it. I dont think we need to use pathfinder. Any other third party api that offers init and update apis is already good for a demonstration. |
View / edit / reply to this conversation on ReviewNB fehiepsi commented on 2024-08-07T21:13:59Z It is surprised to me that we didn't expose |
Related to #1485, see #1485 (comment)
Tutorial showing how to use numpyro with other libraries via
initialize_model
.Port example notebook from https://juanitorduz.github.io/numpyro_pathfinder/