Fix the Tempered SMC notebook
ciguaran authored and rlouf committed Jan 19, 2023
1 parent e21fa82 commit 7df4d1b
Showing 1 changed file with 24 additions and 24 deletions.
48 changes: 24 additions & 24 deletions docs/examples/TemperedSMC.md
@@ -4,7 +4,7 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
-jupytext_version: 1.14.0
+jupytext_version: 1.14.4
kernelspec:
display_name: Python 3 (ipykernel)
language: python
@@ -27,7 +27,7 @@ is not well calibrated (too small a step size, etc.) like in the example below.

## Imports

-```{code-cell} python
+```{code-cell} ipython3
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
@@ -58,7 +58,7 @@ $$

This corresponds to the following distribution. We plot the resulting tempered density for 5 different values of $\lambda_k$: from $\lambda_k = 1$, which corresponds to the original density, to $\lambda_k = 0$. The lower the value of $\lambda_k$ the easier it is for the sampler to jump between the modes of the posterior density.

-```{code-cell} python
+```{code-cell} ipython3
def V(x):
return 5 * jnp.square(jnp.sum(x**2) - 1)
@@ -83,16 +83,15 @@ normalizing_factor = jnp.sum(density, axis=1, keepdims=True) * (
density /= normalizing_factor
```
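The middle of this cell is collapsed in the diff. For reference, a minimal sketch of how the tempered densities could be evaluated on a grid; the `lambdas` and `linspace` definitions are assumptions inferred from the plotting cell below, not lines from the commit:

```python
# Hypothetical reconstruction of the collapsed lines (not part of the diff):
# evaluate the tempered densities exp(-lambda_k * V(x)) on a 1D grid.
lambdas = jnp.linspace(0.0, 1.0, 5)  # 5 tempering levels, lambda_k in [0, 1]
linspace = jnp.linspace(-2, 2, 5000)[..., None]  # positions, shape (5000, 1)

potential = jax.vmap(V)(linspace)  # V on the grid, shape (5000,)
density = jnp.exp(-lambdas[..., None] * potential)  # shape (5, 5000)

# Normalize each tempered density with a Riemann sum over the grid.
normalizing_factor = jnp.sum(density, axis=1, keepdims=True) * (
    linspace[1] - linspace[0]
)
density /= normalizing_factor
```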

-
-```{code-cell} python
+```{code-cell} ipython3
:tags: [hide-input]
fig, ax = plt.subplots(figsize=(12, 8))
ax.plot(linspace.squeeze(), density.T)
ax.legend(list(lambdas))
```

-```{code-cell} python
+```{code-cell} ipython3
def inference_loop(rng_key, mcmc_kernel, initial_state, num_samples):
@jax.jit
def one_step(state, k):
@@ -117,7 +116,7 @@ n_samples = 10_000

We first try to sample from the posterior density using an HMC kernel.

-```{code-cell} python
+```{code-cell} ipython3
%%time
key = jax.random.PRNGKey(42)
@@ -131,7 +130,7 @@ hmc_state = hmc.init(jnp.ones((1,)))
hmc_samples = inference_loop(key, hmc.step, hmc_state, n_samples)
```
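The kernel construction is collapsed above. A plausible sketch, assuming blackjax's top-level `blackjax.hmc` constructor and a deliberately miscalibrated step size; the exact parameter values are assumptions, not lines from the commit:

```python
# Hypothetical reconstruction of the collapsed setup (not part of the diff).
inv_mass_matrix = jnp.eye(1)
hmc_parameters = dict(
    step_size=1e-4,  # deliberately small: HMC gets stuck in one mode
    inverse_mass_matrix=inv_mass_matrix,
    num_integration_steps=50,
)
hmc = blackjax.hmc(lambda x: -V(x), **hmc_parameters)
```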

-```{code-cell} python
+```{code-cell} ipython3
:tags: [hide-input]
samples = np.array(hmc_samples.position[:, 0])
@@ -143,7 +142,7 @@ _ = plt.plot(linspace.squeeze(), density[-1])

We now use a NUTS kernel.

-```{code-cell} python
+```{code-cell} ipython3
%%time
nuts_parameters = dict(step_size=1e-4, inverse_mass_matrix=inv_mass_matrix)
@@ -153,7 +152,7 @@ nuts_state = nuts.init(jnp.ones((1,)))
nuts_samples = inference_loop(key, nuts.step, nuts_state, n_samples)
```

-```{code-cell} python
+```{code-cell} ipython3
:tags: [hide-input]
samples = np.array(nuts_samples.position[:, 0])
@@ -165,7 +164,7 @@ _ = plt.plot(linspace.squeeze(), density[-1])

We now use the adaptive tempered SMC algorithm with an HMC kernel. We only take one HMC step before resampling. The algorithm is run until $\lambda_k$ crosses the $\lambda_k = 1$ limit.

-```{code-cell} python
+```{code-cell} ipython3
def smc_inference_loop(rng_key, smc_kernel, initial_state):
"""Run the temepered SMC algorithm.
@@ -191,7 +190,7 @@ def smc_inference_loop(rng_key, smc_kernel, initial_state):
return n_iter, final_state
```
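The loop body is collapsed in the diff. A sketch of what such a loop typically looks like, assuming the tempered SMC state exposes a `lmbda` field; `jax.lax.while_loop` is used because the number of tempering steps is not known in advance:

```python
# Hypothetical sketch of the collapsed loop body (not part of the diff).
def smc_inference_loop(rng_key, smc_kernel, initial_state):
    """Run the tempered SMC algorithm until lambda_k crosses 1."""

    def cond(carry):
        _i, state, _key = carry
        return state.lmbda < 1  # keep tempering until lambda_k reaches 1

    def one_step(carry):
        i, state, key = carry
        key, subkey = jax.random.split(key)
        state, _info = smc_kernel(subkey, state)
        return i + 1, state, key

    n_iter, final_state, _ = jax.lax.while_loop(
        cond, one_step, (0, initial_state, rng_key)
    )
    return n_iter, final_state
```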

-```{code-cell} python
+```{code-cell} ipython3
%%time
loglikelihood = lambda x: -V(x)
@@ -203,11 +202,12 @@ hmc_parameters = dict(
tempered = blackjax.adaptive_tempered_smc(
prior_log_prob,
loglikelihood,
-blackjax.hmc,
+blackjax.hmc.kernel(),
+blackjax.hmc.init,
hmc_parameters,
resampling.systematic,
0.5,
-mcmc_iter=1,
+num_mcmc_steps=1,
)
initial_smc_state = jax.random.multivariate_normal(
@@ -219,7 +219,7 @@
print("Number of steps in the adaptive algorithm: ", n_iter.item())
```
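This hunk is the substance of the fix: `adaptive_tempered_smc` now takes the lower-level kernel factory `blackjax.hmc.kernel()` and the `blackjax.hmc.init` state initializer instead of the high-level `blackjax.hmc` object, and the keyword `mcmc_iter` has been renamed `num_mcmc_steps`. An annotated restatement of the updated call; the comments are interpretive, not from the commit:

```python
# The updated call, with each positional argument annotated (interpretation).
tempered = blackjax.adaptive_tempered_smc(
    prior_log_prob,         # log-density of the prior p_0
    loglikelihood,          # log-likelihood, tempered by lambda_k
    blackjax.hmc.kernel(),  # MCMC kernel factory (previously: blackjax.hmc)
    blackjax.hmc.init,      # MCMC state initializer (new argument)
    hmc_parameters,         # parameters forwarded to the HMC kernel
    resampling.systematic,  # resampling scheme
    0.5,                    # target effective sample size for the adaptation
    num_mcmc_steps=1,       # previously: mcmc_iter=1
)
```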

-```{code-cell} python
+```{code-cell} ipython3
:tags: [hide-input]
samples = np.array(smc_samples.particles[:, 0])
@@ -235,7 +235,7 @@ We consider a prior distribution $p_0(x) = \mathcal{N}(x \mid 0_2, 2 I_2)$ and w

We plot the resulting tempered density for 5 different values of $\lambda_k$: from $\lambda_k = 1$, which corresponds to the original density, to $\lambda_k = 0$. The lower the value of $\lambda_k$ the easier it is to sample from the posterior log-density.

-```{code-cell} python
+```{code-cell} ipython3
def prior_log_prob(x):
d = x.shape[0]
return multivariate_normal.logpdf(x, jnp.zeros((d,)), 2 * jnp.eye(d))
@@ -259,15 +259,15 @@ normalizing_factor = jnp.sum(density, axis=1, keepdims=True) * (
density /= normalizing_factor
```

-```{code-cell} python
+```{code-cell} ipython3
:tags: [hide-input]
fig, ax = plt.subplots(figsize=(12, 8))
ax.semilogy(linspace.squeeze(), density.T)
ax.legend(list(lambdas))
```

-```{code-cell} python
+```{code-cell} ipython3
def inference_loop(rng_key, mcmc_kernel, initial_state, num_samples):
def one_step(state, k):
state, _ = mcmc_kernel(k, state)
@@ -287,7 +287,7 @@ n_samples = 1_000

We first try to sample from the posterior density using an HMC kernel.

-```{code-cell} python
+```{code-cell} ipython3
%%time
key = jax.random.PRNGKey(42)
@@ -303,7 +303,7 @@ hmc_state = hmc.init(jnp.ones((1,)))
hmc_samples = inference_loop(key, hmc.step, hmc_state, n_samples)
```

-```{code-cell} python
+```{code-cell} ipython3
:tags: [hide-input]
samples = np.array(hmc_samples.position[:, 0])
@@ -316,7 +316,7 @@ _ = plt.yscale("log")

We do the same using a NUTS kernel.

-```{code-cell} python
+```{code-cell} ipython3
%%time
nuts_parameters = dict(step_size=1e-2, inverse_mass_matrix=inv_mass_matrix)
@@ -326,7 +326,7 @@ nuts_state = nuts.init(jnp.ones((1,)))
nuts_samples = inference_loop(key, nuts.step, nuts_state, n_samples)
```

-```{code-cell} python
+```{code-cell} ipython3
:tags: [hide-input]
samples = np.array(nuts_samples.position[:, 0])
@@ -340,7 +340,7 @@ _ = plt.yscale("log")
We now use the adaptive tempered SMC algorithm with an HMC kernel. We only take one HMC step before resampling. The algorithm is run until $\lambda_k$ crosses the $\lambda_k = 1$ limit.
We correct the bias introduced by the (arbitrary) prior.

-```{code-cell} python
+```{code-cell} ipython3
%%time
loglikelihood = lambda x: -V(x)
@@ -368,7 +368,7 @@ n_iter, smc_samples = smc_inference_loop(key, tempered.step, initial_smc_state)
print("Number of steps in the adaptive algorithm: ", n_iter.item())
```
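The setup lines are collapsed above. The bias correction mentioned before this cell follows from the tempering path $\pi_k(x) \propto p_0(x) L(x)^{\lambda_k}$: for $\pi_1$ to match the target $\propto \exp(-V(x))$ even though $p_0$ was chosen arbitrarily, the prior must be divided out of the likelihood. A hedged sketch; the collapsed code may differ:

```python
# Hypothetical sketch of the prior bias correction (collapsed in the diff).
# pi_k(x) ∝ p0(x) * L(x)^lambda_k, so pi_1(x) ∝ exp(-V(x)) requires
# log L(x) = -V(x) - log p0(x).
loglikelihood = lambda x: -V(x) - prior_log_prob(x)
```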

-```{code-cell} python
+```{code-cell} ipython3
:tags: [hide-input]
samples = np.array(smc_samples.particles[:, 0])
