diff --git a/examples/Pathfinder.md b/examples/Pathfinder.md
index 0dc9e9630..6b9fb7531 100644
--- a/examples/Pathfinder.md
+++ b/examples/Pathfinder.md
@@ -4,9 +4,9 @@ jupytext:
     extension: .md
     format_name: myst
     format_version: 0.13
-    jupytext_version: 1.14.0
+    jupytext_version: 1.14.1
 kernelspec:
-  display_name: Python 3.9.7 ('blackjax')
+  display_name: Python 3 (ipykernel)
   language: python
   name: python3
 ---
@@ -115,7 +115,8 @@ To help understand the approximations that pathfinder evaluates during its run,
 ```{code-cell} ipython3
 rng_key = random.PRNGKey(314)
 w0 = random.multivariate_normal(rng_key, 2.0 + jnp.zeros(M), jnp.eye(M))
-path = blackjax.vi.pathfinder.init(rng_key, logprob_fn, w0, return_path=True, ftol=1e-4)
+_, info = blackjax.vi.pathfinder.approximate(rng_key, logprob_fn, w0, ftol=1e-4)
+path = info.path
 ```
 
 ```{code-cell} ipython3
@@ -156,7 +157,7 @@ for i, ax in zip(range(1, steps + 1), axs.flatten()):
     ax.contour(x_, y_, logp_, levels=levels_)
 
     state = jax.tree_map(lambda x: x[i], path)
-    sample_state, _ = blackjax.vi.pathfinder.sample_from_state(rng_key, state, 10_000)
+    sample_state, _ = blackjax.vi.pathfinder.sample(rng_key, state, 10_000)
     position_path = path.position[: i + 1]
     ax.plot(
         position_path[:, 0],
@@ -176,28 +177,15 @@ fig.show()
 Pathfinder can be used as a variational inference method, using its kernel API:
 
 ```{code-cell} ipython3
-pathfinder = blackjax.kernels.pathfinder(rng_key, logprob_fn, ftol=1e-4)
-state = pathfinder.init(w0)
+pathfinder = blackjax.kernels.pathfinder(logprob_fn)
+state, _ = pathfinder.approximate(rng_key, w0, ftol=1e-4)
 ```
 
-Since `blackjax` does not provide an inference loop we need to implement one ourselves:
-
-```{code-cell} ipython3
-def inference_loop(rng_key, kernel, initial_state, num_samples):
-    @jax.jit
-    def one_step(state, rng_key):
-        state, info = kernel(rng_key, state)
-        return state, (state, info)
-
-    keys = jax.random.split(rng_key, num_samples)
-    return jax.lax.scan(one_step, initial_state, keys)
-```
-
-We can now run the inference:
+We can now draw samples from the approximation:
 
 ```{code-cell} ipython3
 _, rng_key = random.split(rng_key)
-_, (_, samples) = inference_loop(rng_key, pathfinder.step, state, 5_000)
+samples, _ = pathfinder.sample(rng_key, state, 5_000)
 ```
 
 And display the trace:
@@ -220,13 +208,38 @@ Please note that pathfinder is implemented as follows:
-hence it makes sense to `jit` the `init` function and then use the `blackjax.vi.pathfinder.sample_from_state` helper function instead of implementing the inference loop:
+hence it makes sense to `jit` the `approximate` function and then use the `sample` helper function instead of implementing an inference loop:
 
 ```{code-cell} ipython3
-state = jax.jit(pathfinder.init)(w0)
-samples, _ = blackjax.vi.pathfinder.sample_from_state(rng_key, state, 5_000)
+%%time
+
+state, _ = jax.jit(pathfinder.approximate)(rng_key, w0)
+samples, _ = pathfinder.sample(rng_key, state, 5_000)
 ```
 
 Quick comparison against `rmh` kernel:
 
 ```{code-cell} ipython3
+def inference_loop(rng_key, kernel, initial_state, num_samples):
+    @jax.jit
+    def one_step(state, rng_key):
+        state, info = kernel(rng_key, state)
+        return state, (state, info)
+
+    keys = jax.random.split(rng_key, num_samples)
+    return jax.lax.scan(one_step, initial_state, keys)
+
 rmh = blackjax.kernels.rmh(logprob_fn, sigma=jnp.ones(M) * 0.7)
 state_rmh = rmh.init(w0)
 _, (samples_rmh, _) = inference_loop(rng_key, rmh.step, state_rmh, 5_000)
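+
+As a rough visual check on this comparison, the two sets of draws can be overlaid on the first two coordinates (a minimal sketch, assuming `matplotlib.pyplot` is available as `plt`, as in the plotting cells above):
+
+```{code-cell} ipython3
+# Sketch: overlay the pathfinder draws and the RMH draws on the first two
+# coordinates to eyeball how closely the two approximations agree.
+fig, ax = plt.subplots()
+ax.scatter(samples[:, 0], samples[:, 1], s=2, alpha=0.2, label="pathfinder")
+ax.scatter(
+    samples_rmh.position[:, 0], samples_rmh.position[:, 1], s=2, alpha=0.2, label="rmh"
+)
+ax.legend()
+fig.show()
+```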
@@ -267,7 +266,17 @@ This estimation of the inverse mass matrix, coupled with Nesterov's dual averagi
-This scheme is implemented in `blackjax.kernel.pathfinder_adaptation` function:
+This scheme is implemented in the `blackjax.kernels.pathfinder_adaptation` function:
 
 ```{code-cell} ipython3
-adapt = blackjax.kernels.pathfinder_adaptation(blackjax.nuts, logprob_fn)
+adapt = blackjax.kernels.pathfinder_adaptation(jax.jit(blackjax.nuts), logprob_fn)
 state, kernel, info = adapt.run(rng_key, w0, 400)
 ```
 
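+
+The adaptation returns a warmed-up `state` and a ready-to-use `kernel`, so sampling can reuse the `inference_loop` helper defined above (a minimal sketch, assuming the returned kernel has the usual `kernel(rng_key, state)` signature and that its states carry a `position` field, like the kernels used earlier):
+
+```{code-cell} ipython3
+# Sketch: draw samples with the adapted NUTS kernel, starting from the
+# warmed-up state returned by the adaptation run.
+_, rng_key = random.split(rng_key)
+_, (states, infos) = inference_loop(rng_key, kernel, state, 1_000)
+samples_nuts = states.position
+```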