From 7df4d1b932b4c7f2ce205a1369a0e5bbdff8ae5e Mon Sep 17 00:00:00 2001
From: Carlos Iguaran
Date: Wed, 18 Jan 2023 16:18:11 -0300
Subject: [PATCH] Fix the Tempered SMC notebook

---
 docs/examples/TemperedSMC.md | 48 ++++++++++++++++++------------------
 1 file changed, 24 insertions(+), 24 deletions(-)

diff --git a/docs/examples/TemperedSMC.md b/docs/examples/TemperedSMC.md
index a3349c39b..d168b27b7 100644
--- a/docs/examples/TemperedSMC.md
+++ b/docs/examples/TemperedSMC.md
@@ -4,7 +4,7 @@ jupytext:
     extension: .md
     format_name: myst
     format_version: 0.13
-    jupytext_version: 1.14.0
+    jupytext_version: 1.14.4
 kernelspec:
   display_name: Python 3 (ipykernel)
   language: python
@@ -27,7 +27,7 @@ is not well calibrated (too small step size, etc) like in the example below.
 
 ## Imports
 
-```{code-cell} python
+```{code-cell} ipython3
 import jax
 import jax.numpy as jnp
 import matplotlib.pyplot as plt
@@ -58,7 +58,7 @@ $$
 
 This corresponds to the following distribution. We plot the resulting tempered density for 5 different values of $\lambda_k$: from $\lambda_k = 1$, which corresponds to the original density, to $\lambda_k = 0$. The lower the value of $\lambda_k$, the easier it is for the sampler to jump between the modes of the posterior density.
 
-```{code-cell} python
+```{code-cell} ipython3
 def V(x):
     return 5 * jnp.square(jnp.sum(x**2) - 1)
 
@@ -83,8 +83,7 @@ normalizing_factor = jnp.sum(density, axis=1, keepdims=True) * (
 density /= normalizing_factor
 ```
 
-
-```{code-cell} python
+```{code-cell} ipython3
 :tags: [hide-input]
 
 fig, ax = plt.subplots(figsize=(12, 8))
@@ -92,7 +91,7 @@ ax.plot(linspace.squeeze(), density.T)
 ax.legend(list(lambdas))
 ```
 
-```{code-cell} python
+```{code-cell} ipython3
 def inference_loop(rng_key, mcmc_kernel, initial_state, num_samples):
     @jax.jit
     def one_step(state, k):
@@ -117,7 +116,7 @@ n_samples = 10_000
 
 We first try to sample from the posterior density using an HMC kernel.
 
-```{code-cell} python
+```{code-cell} ipython3
 %%time
 
 key = jax.random.PRNGKey(42)
@@ -131,7 +130,7 @@ hmc_state = hmc.init(jnp.ones((1,)))
 hmc_samples = inference_loop(key, hmc.step, hmc_state, n_samples)
 ```
 
-```{code-cell} python
+```{code-cell} ipython3
 :tags: [hide-input]
 
 samples = np.array(hmc_samples.position[:, 0])
@@ -143,7 +142,7 @@ _ = plt.plot(linspace.squeeze(), density[-1])
 
 We now use a NUTS kernel.
 
-```{code-cell} python
+```{code-cell} ipython3
 %%time
 
 nuts_parameters = dict(step_size=1e-4, inverse_mass_matrix=inv_mass_matrix)
@@ -153,7 +152,7 @@ nuts_state = nuts.init(jnp.ones((1,)))
 nuts_samples = inference_loop(key, nuts.step, nuts_state, n_samples)
 ```
 
-```{code-cell} python
+```{code-cell} ipython3
 :tags: [hide-input]
 
 samples = np.array(nuts_samples.position[:, 0])
@@ -165,7 +164,7 @@ _ = plt.plot(linspace.squeeze(), density[-1])
 
 We now use the adaptive tempered SMC algorithm with an HMC kernel. We only take one HMC step before resampling. The algorithm is run until $\lambda_k$ crosses the $\lambda_k = 1$ limit.
 
-```{code-cell} python
+```{code-cell} ipython3
 def smc_inference_loop(rng_key, smc_kernel, initial_state):
     """Run the tempered SMC algorithm.
@@ -191,7 +190,7 @@ def smc_inference_loop(rng_key, smc_kernel, initial_state):
     return n_iter, final_state
 ```
 
-```{code-cell} python
+```{code-cell} ipython3
 %%time
 
 loglikelihood = lambda x: -V(x)
@@ -203,11 +202,12 @@ hmc_parameters = dict(
 tempered = blackjax.adaptive_tempered_smc(
     prior_log_prob,
     loglikelihood,
-    blackjax.hmc,
+    blackjax.hmc.kernel(),
+    blackjax.hmc.init,
     hmc_parameters,
     resampling.systematic,
     0.5,
-    mcmc_iter=1,
+    num_mcmc_steps=1,
 )
 
 initial_smc_state = jax.random.multivariate_normal(
@@ -219,7 +219,7 @@ n_iter, smc_samples = smc_inference_loop(key, tempered.step, initial_smc_state)
 print("Number of steps in the adaptive algorithm: ", n_iter.item())
 ```
 
-```{code-cell} python
+```{code-cell} ipython3
 :tags: [hide-input]
 
 samples = np.array(smc_samples.particles[:, 0])
@@ -235,7 +235,7 @@ We consider a prior distribution $p_0(x) = \mathcal{N}(x \mid 0_2, 2 I_2)$ and w
 
 We plot the resulting tempered density for 5 different values of $\lambda_k$: from $\lambda_k = 1$, which corresponds to the original density, to $\lambda_k = 0$. The lower the value of $\lambda_k$, the easier it is to sample from the posterior log-density.
 
-```{code-cell} python
+```{code-cell} ipython3
 def prior_log_prob(x):
     d = x.shape[0]
     return multivariate_normal.logpdf(x, jnp.zeros((d,)), 2 * jnp.eye(d))
@@ -259,7 +259,7 @@ normalizing_factor = jnp.sum(density, axis=1, keepdims=True) * (
 density /= normalizing_factor
 ```
 
-```{code-cell} python
+```{code-cell} ipython3
 :tags: [hide-input]
 
 fig, ax = plt.subplots(figsize=(12, 8))
@@ -267,7 +267,7 @@ ax.semilogy(linspace.squeeze(), density.T)
 ax.legend(list(lambdas))
 ```
 
-```{code-cell} python
+```{code-cell} ipython3
 def inference_loop(rng_key, mcmc_kernel, initial_state, num_samples):
     def one_step(state, k):
         state, _ = mcmc_kernel(k, state)
         return state, state
 
     keys = jax.random.split(rng_key, num_samples)
     _, states = jax.lax.scan(one_step, initial_state, keys)
 
     return states
@@ -287,7 +287,7 @@ n_samples = 1_000
 
 We first try to sample from the posterior density using an HMC kernel.
 
-```{code-cell} python
+```{code-cell} ipython3
 %%time
 
 key = jax.random.PRNGKey(42)
@@ -303,7 +303,7 @@ hmc_state = hmc.init(jnp.ones((1,)))
 hmc_samples = inference_loop(key, hmc.step, hmc_state, n_samples)
 ```
 
-```{code-cell} python
+```{code-cell} ipython3
 :tags: [hide-input]
 
 samples = np.array(hmc_samples.position[:, 0])
@@ -316,7 +316,7 @@ _ = plt.yscale("log")
 
 We do the same using a NUTS kernel.
 
-```{code-cell} python
+```{code-cell} ipython3
 %%time
 
 nuts_parameters = dict(step_size=1e-2, inverse_mass_matrix=inv_mass_matrix)
@@ -326,7 +326,7 @@ nuts_state = nuts.init(jnp.ones((1,)))
 nuts_samples = inference_loop(key, nuts.step, nuts_state, n_samples)
 ```
 
-```{code-cell} python
+```{code-cell} ipython3
 :tags: [hide-input]
 
 samples = np.array(nuts_samples.position[:, 0])
@@ -340,7 +340,7 @@ _ = plt.yscale("log")
 
 We now use the adaptive tempered SMC algorithm with an HMC kernel. We only take one HMC step before resampling. The algorithm is run until $\lambda_k$ crosses the $\lambda_k = 1$ limit. We correct the bias introduced by the (arbitrary) prior.
 
-```{code-cell} python
+```{code-cell} ipython3
 %%time
 
 loglikelihood = lambda x: -V(x)
@@ -368,7 +368,7 @@ n_iter, smc_samples = smc_inference_loop(key, tempered.step, initial_smc_state)
 print("Number of steps in the adaptive algorithm: ", n_iter.item())
 ```
 
-```{code-cell} python
+```{code-cell} ipython3
 :tags: [hide-input]
 
 samples = np.array(smc_samples.particles[:, 0])
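For reviewers who want to see the updated API in one place, the sketch below assembles the pieces this patch touches into a single self-contained snippet. The `adaptive_tempered_smc` call mirrors the hunk above exactly: the MCMC kernel and its init function are passed separately (`blackjax.hmc.kernel()`, `blackjax.hmc.init`) and the per-stage MCMC budget is `num_mcmc_steps` rather than `mcmc_iter`. The HMC parameter values, the `tempered.init` call, and the body of `smc_inference_loop` (including the `lmbda` field of the SMC state) are not shown in the hunks, so they are assumptions based on the notebook's description and the blackjax version this patch targets; treat this as a sketch rather than the notebook's verbatim code.

```python
# Minimal sketch of the updated adaptive tempered SMC call.
# Anything marked "assumed" is illustrative and not part of the hunks above.
import jax
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal

import blackjax
import blackjax.smc.resampling as resampling


def V(x):
    return 5 * jnp.square(jnp.sum(x**2) - 1)


def prior_log_prob(x):
    # Prior taken from the notebook's second example.
    d = x.shape[0]
    return multivariate_normal.logpdf(x, jnp.zeros((d,)), 2 * jnp.eye(d))


loglikelihood = lambda x: -V(x)

# Assumed illustrative HMC parameters; the notebook's exact values are not in the hunks.
hmc_parameters = dict(
    step_size=1e-4, inverse_mass_matrix=jnp.eye(1), num_integration_steps=50
)

# The call as changed by this patch: kernel and init passed separately,
# `num_mcmc_steps` instead of `mcmc_iter`.
tempered = blackjax.adaptive_tempered_smc(
    prior_log_prob,
    loglikelihood,
    blackjax.hmc.kernel(),
    blackjax.hmc.init,
    hmc_parameters,
    resampling.systematic,
    0.5,
    num_mcmc_steps=1,
)

key = jax.random.PRNGKey(42)
initial_particles = jax.random.multivariate_normal(
    key, jnp.zeros((1,)), jnp.eye(1), (1_000,)
)
# Assumed: wrap the particles into the SMC state; this call is not shown in the hunks.
initial_smc_state = tempered.init(initial_particles)


def smc_inference_loop(rng_key, smc_kernel, initial_state):
    """Run tempered SMC until the tempering parameter crosses 1 (sketch)."""

    def cond(carry):
        _i, state, _k = carry
        # Assumed: the tempered SMC state exposes the tempering parameter as `lmbda`.
        return state.lmbda < 1

    def one_step(carry):
        i, state, k = carry
        k, subkey = jax.random.split(k)
        state, _info = smc_kernel(subkey, state)
        return i + 1, state, k

    n_iter, final_state, _ = jax.lax.while_loop(
        cond, one_step, (0, initial_state, rng_key)
    )
    return n_iter, final_state


n_iter, smc_samples = smc_inference_loop(key, tempered.step, initial_smc_state)
print("Number of steps in the adaptive algorithm: ", n_iter.item())
```

The point of the signature change applied here is that the SMC layer receives a kernel factory and an init function rather than the `blackjax.hmc` object itself, so the same tempered SMC wrapper can drive any MCMC kernel with its own parameter dictionary.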