From becd2d20b987cb496f873d84fa4f6265f17d4483 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Thu, 27 Oct 2022 18:08:19 +0200 Subject: [PATCH] Small updates to the change of variable example --- examples/change_of_variable_hmc.md | 34 ++++++++++++++++++------------ 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/examples/change_of_variable_hmc.md b/examples/change_of_variable_hmc.md index a50298368..9cf277074 100644 --- a/examples/change_of_variable_hmc.md +++ b/examples/change_of_variable_hmc.md @@ -6,7 +6,7 @@ jupytext: format_version: 0.13 jupytext_version: 1.14.1 kernelspec: - display_name: Python 3.9.7 ('blackjax') + display_name: Python 3 (ipykernel) language: python name: python3 mystnb: @@ -28,7 +28,6 @@ In particular we use following binomial hierarchical model where $y_{j}$ and $N_ \end{align} ``` - ```{code-cell} ipython3 :tags: [hide-cell] @@ -37,7 +36,6 @@ import jax import jax.numpy as jnp import matplotlib.pyplot as plt import pandas as pd -import seaborn as sns pd.set_option("display.max_rows", 80) @@ -217,21 +215,29 @@ n_rat_tumors = len(group_size) ```{code-cell} ipython3 :tags: [hide-input] -plt.figure(figsize=(12, 3)) -plt.bar(range(n_rat_tumors), n_of_positives) +fig = plt.figure(figsize=(12, 3)) +ax = fig.add_subplot(111) +ax.bar(range(n_rat_tumors), n_of_positives) + +ax.set_xlabel("tumor type", fontsize=12) +ax.spines["top"].set_visible(False) +ax.spines["right"].set_visible(False) + plt.title("No. of positives for each tumor type", fontsize=14) -plt.xlabel("tumor type", fontsize=12) -sns.despine() ``` ```{code-cell} ipython3 :tags: [hide-input] -plt.figure(figsize=(14, 4)) -plt.bar(range(n_rat_tumors), group_size) +fig = plt.figure(figsize=(14, 4)) +ax = fig.add_subplot(111) + +ax.bar(range(n_rat_tumors), group_size) plt.title("Group size for each tumor type", fontsize=14) -plt.xlabel("tumor type", fontsize=12) -sns.despine() + +ax.set_xlabel("tumor type", fontsize=12) +ax.spines["top"].set_visible(False) +ax.spines["right"].set_visible(False) ``` ## Posterior Sampling @@ -298,7 +304,7 @@ def call_warmup(seed, param): initial_states, _, tuned_params = warmup.run(seed, param, 1000) return initial_states, tuned_params -initial_states, tuned_params = call_warmup(keys, init_params) +initial_states, tuned_params = jax.jit(call_warmup)(keys, init_params) ``` Now we write inference loop for multiple chains @@ -312,7 +318,6 @@ def inference_loop_multiple_chains( def kernel(key, state, **params): return step_fn(key, state, log_prob_fn, **params) - @jax.jit def one_step(states, rng_key): keys = jax.random.split(rng_key, num_chains) states, infos = jax.vmap(kernel)(keys, states, **tuned_params) @@ -428,7 +433,7 @@ def joint_logprob_change_of_var(params): return logprob_ab + logprob_thetas + logprob_y + log_det_jacob ``` -except change of variable in `joint_logprob()` function, everthing will remain same +except for the change of variable in `joint_logprob()` function, everthing will remain same ```{code-cell} ipython3 rng_key = jax.random.PRNGKey(0) @@ -555,6 +560,7 @@ init_key, warmup_key = jax.random.split(rng_key, 2) init_params = bijectors.inverse(pinned.sample_unpinned(n_chains, seed=init_key)) keys = jax.random.split(warmup_key, n_chains) + @jax.vmap def call_warmup(seed, param): initial_states, _, tuned_params = warmup.run(seed, param, 1000)