From 6481f693bb9a1100ea972f8551879ef358ddc3a3 Mon Sep 17 00:00:00 2001 From: junpenglao Date: Sat, 17 Dec 2022 19:36:34 +0100 Subject: [PATCH] Minor doc clean up example using oyrx is currently broken (https://github.com/jax-ml/oryx/issues/25) --- blackjax/kernels.py | 4 ++-- docs/howto_use_ppl.rst | 1 + examples/SGMCMC.md | 2 +- examples/howto_use_numpyro.md | 2 +- examples/howto_use_tfp.md | 2 +- 5 files changed, 6 insertions(+), 5 deletions(-) diff --git a/blackjax/kernels.py b/blackjax/kernels.py index 8ed4e9f78..508c9c0ed 100644 --- a/blackjax/kernels.py +++ b/blackjax/kernels.py @@ -418,13 +418,13 @@ class mgrad_gaussian: .. code:: mgrad_gaussian = blackjax.mgrad_gaussian(f, C, use_inverse=False, mean=m) - state = latent_gaussian.init(zeros) # Starting at the mean of the prior + state = mgrad_gaussian.init(zeros) # Starting at the mean of the prior new_state, info = mgrad_gaussian.step(rng_key, state, delta) We can JIT-compile the step function for better performance .. code:: - step = jax.jit(latent_gaussian.step) + step = jax.jit(mgrad_gaussian.step) new_state, info = step(rng_key, state, delta) Parameters diff --git a/docs/howto_use_ppl.rst b/docs/howto_use_ppl.rst index c318ffc8a..2c0a96300 100644 --- a/docs/howto_use_ppl.rst +++ b/docs/howto_use_ppl.rst @@ -7,4 +7,5 @@ Use the model I built with X? examples/howto_use_aesara.md examples/howto_use_numpyro.md examples/howto_use_oryx.md + examples/howto_use_pymc.md examples/howto_use_tfp.md diff --git a/examples/SGMCMC.md b/examples/SGMCMC.md index 409aeef41..996fbc266 100644 --- a/examples/SGMCMC.md +++ b/examples/SGMCMC.md @@ -11,7 +11,7 @@ kernelspec: name: python3 file_format: mystnb mystnb: - execution_timeout: 200 + execution_timeout: 300 merge_streams: true --- diff --git a/examples/howto_use_numpyro.md b/examples/howto_use_numpyro.md index 71a683f21..eab440541 100644 --- a/examples/howto_use_numpyro.md +++ b/examples/howto_use_numpyro.md @@ -163,7 +163,7 @@ fig.set_size_inches(12, 10) for i in range(J): axes[i][0].plot(samples["theta_base"][:, i]) axes[i][0].title.set_text(f"School {i} relative treatment effect chain") - sns.kdeplot(samples["theta_base"][:, i], ax=axes[i][1], shade=True) + sns.kdeplot(samples["theta_base"][:, i], ax=axes[i][1], fill=True) axes[i][1].title.set_text(f"School {i} relative treatment effect distribution") axes[J - 1][0].set_xlabel("Iteration") axes[J - 1][1].set_xlabel("School effect") diff --git a/examples/howto_use_tfp.md b/examples/howto_use_tfp.md index dd6990ffe..e0e446986 100644 --- a/examples/howto_use_tfp.md +++ b/examples/howto_use_tfp.md @@ -176,7 +176,7 @@ fig.set_size_inches(12, 10) for i in range(num_schools): axes[i][0].plot(school_effects_samples[:, i]) axes[i][0].title.set_text(f"School {i} treatment effect chain") - sns.kdeplot(school_effects_samples[:, i], ax=axes[i][1], shade=True) + sns.kdeplot(school_effects_samples[:, i], ax=axes[i][1], fill=True) axes[i][1].title.set_text(f"School {i} treatment effect distribution") axes[num_schools - 1][0].set_xlabel("Iteration") axes[num_schools - 1][1].set_xlabel("School effect")