Skip to content

Commit

Permalink
Update the Pathfinder example
Browse files Browse the repository at this point in the history
  • Loading branch information
rlouf committed Nov 12, 2022
1 parent d2513e1 commit 5b6e262
Showing 1 changed file with 23 additions and 24 deletions.
47 changes: 23 additions & 24 deletions examples/Pathfinder.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
jupytext_version: 1.14.0
jupytext_version: 1.14.1
kernelspec:
display_name: Python 3.9.7 ('blackjax')
display_name: Python 3 (ipykernel)
language: python
name: python3
---
Expand Down Expand Up @@ -115,7 +115,8 @@ To help understand the approximations that pathfinder evaluates during its run,
```{code-cell} ipython3
rng_key = random.PRNGKey(314)
w0 = random.multivariate_normal(rng_key, 2.0 + jnp.zeros(M), jnp.eye(M))
path = blackjax.vi.pathfinder.init(rng_key, logprob_fn, w0, return_path=True, ftol=1e-4)
_, info = blackjax.vi.pathfinder.approximate(rng_key, logprob_fn, w0, ftol=1e-4)
path = info.path
```

```{code-cell} ipython3
Expand Down Expand Up @@ -156,7 +157,7 @@ for i, ax in zip(range(1, steps + 1), axs.flatten()):
ax.contour(x_, y_, logp_, levels=levels_)
state = jax.tree_map(lambda x: x[i], path)
sample_state, _ = blackjax.vi.pathfinder.sample_from_state(rng_key, state, 10_000)
sample_state, _ = blackjax.vi.pathfinder.sample(rng_key, state, 10_000)
position_path = path.position[: i + 1]
ax.plot(
position_path[:, 0],
Expand All @@ -176,28 +177,15 @@ fig.show()
Pathfinder can be used as a variational inference method, using its kernel API:

```{code-cell} ipython3
pathfinder = blackjax.kernels.pathfinder(rng_key, logprob_fn, ftol=1e-4)
state = pathfinder.init(w0)
pathfinder = blackjax.kernels.pathfinder(logprob_fn)
state, _ = pathfinder.approximate(rng_key, w0, ftol=1e-4)
```

Since `blackjax` does not provide an inference loop we need to implement one ourselves:

```{code-cell} ipython3
def inference_loop(rng_key, kernel, initial_state, num_samples):
@jax.jit
def one_step(state, rng_key):
state, info = kernel(rng_key, state)
return state, (state, info)
keys = jax.random.split(rng_key, num_samples)
return jax.lax.scan(one_step, initial_state, keys)
```

We can now run the inference:
We can now get samples from the approximation:

```{code-cell} ipython3
_, rng_key = random.split(rng_key)
_, (_, samples) = inference_loop(rng_key, pathfinder.step, state, 5_000)
samples, _ = pathfinder.sample(rng_key, state, 5_000)
```

And display the trace:
Expand All @@ -220,13 +208,24 @@ Please note that pathfinder is implemented as follows:
hence it makes sense to `jit` the `init` function and then use the `blackjax.vi.pathfinder.sample_from_state` helper function instead of implementing the inference loop:

```{code-cell} ipython3
state = jax.jit(pathfinder.init)(w0)
samples, _ = blackjax.vi.pathfinder.sample_from_state(rng_key, state, 5_000)
%%time
state, _ = jax.jit(pathfinder.approximate)(rng_key, w0)
samples, _ = pathfinder.sample(rng_key, state, 5_000)
```

Quick comparison against `rmh` kernel:

```{code-cell} ipython3
def inference_loop(rng_key, kernel, initial_state, num_samples):
@jax.jit
def one_step(state, rng_key):
state, info = kernel(rng_key, state)
return state, (state, info)
keys = jax.random.split(rng_key, num_samples)
return jax.lax.scan(one_step, initial_state, keys)
rmh = blackjax.kernels.rmh(logprob_fn, sigma=jnp.ones(M) * 0.7)
state_rmh = rmh.init(w0)
_, (samples_rmh, _) = inference_loop(rng_key, rmh.step, state_rmh, 5_000)
Expand Down Expand Up @@ -267,7 +266,7 @@ This estimation of the inverse mass matrix, coupled with Nesterov's dual averagi
This scheme is implemented in `blackjax.kernel.pathfinder_adaptation` function:

```{code-cell} ipython3
adapt = blackjax.kernels.pathfinder_adaptation(blackjax.nuts, logprob_fn)
adapt = blackjax.kernels.pathfinder_adaptation(jax.jit(blackjax.nuts), logprob_fn)
state, kernel, info = adapt.run(rng_key, w0, 400)
```

Expand Down

0 comments on commit 5b6e262

Please sign in to comment.