Skip to content

Commit

Permalink
fix tutorials
Browse files Browse the repository at this point in the history
  • Loading branch information
damonbayer committed Jun 14, 2024
1 parent d0ca000 commit 6cd0583
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 8 deletions.
8 changes: 4 additions & 4 deletions model/docs/example_with_datasets.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ axs[0].plot(sim_data.Rt)
axs[0].set_ylabel("Rt")
# Infections plot
axs[1].plot(sim_data.sampled_observed_hosp_admissions)
axs[1].plot(sim_data.observed_hosp_admissions)
axs[1].set_ylabel("Infections")
axs[1].set_yscale("log")
Expand All @@ -236,7 +236,7 @@ import jax
hosp_model.run(
num_samples=2000,
num_warmup=2000,
observed_hosp_admissions=dat["daily_hosp_admits"].to_numpy(),
data_observed_hosp_admissions=dat["daily_hosp_admits"].to_numpy(),
rng_key=jax.random.PRNGKey(54),
mcmc_args=dict(progress_bar=False),
)
Expand Down Expand Up @@ -281,7 +281,7 @@ dat_w_padding = np.pad(
hosp_model.run(
num_samples=2000,
num_warmup=2000,
observed_hosp_admissions=dat_w_padding,
data_observed_hosp_admissions=dat_w_padding,
rng_key=jax.random.PRNGKey(54),
mcmc_args=dict(progress_bar=False),
padding=days_to_impute, # Padding the model
Expand Down Expand Up @@ -382,7 +382,7 @@ Running the model (with the same padding as before):
hosp_model_weekday.run(
num_samples=2000,
num_warmup=2000,
observed_hosp_admissions=dat_w_padding,
data_observed_hosp_admissions=dat_w_padding,
rng_key=jax.random.PRNGKey(54),
mcmc_args=dict(progress_bar=False),
padding=days_to_impute,
Expand Down
4 changes: 2 additions & 2 deletions model/docs/getting_started.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ axs[0].plot(sim_data.Rt)
axs[0].set_ylabel("Rt")
# Infections plot
axs[1].plot(sim_data.sampled_observed_infections)
axs[1].plot(sim_data.observed_infections)
axs[1].set_ylabel("Infections")
fig.suptitle("Basic renewal model")
Expand All @@ -190,7 +190,7 @@ import jax
model1.run(
num_warmup=2000,
num_samples=1000,
observed_infections=sim_data.sampled_observed_infections,
data_observed_infections=sim_data.observed_infections,
rng_key=jax.random.PRNGKey(54),
mcmc_args=dict(progress_bar=False),
)
Expand Down
4 changes: 2 additions & 2 deletions model/docs/pyrenew_demo.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ fig, ax = plt.subplots(nrows=3, sharex=True)
ax[0].plot(x.latent_infections)
ax[0].set_ylim([1 / 5, 5])
ax[1].plot(x.latent_hosp_admissions)
ax[2].plot(x.sampled_observed_hosp_admissions, "o")
ax[2].plot(x.observed_hosp_admissions, "o")
for axis in ax[:-1]:
axis.set_yscale("log")
```
Expand All @@ -174,7 +174,7 @@ To fit the `hospmodel` to the simulated data, we call `hospmodel.run()`, an MCMC
hospmodel.run(
num_warmup=1000,
num_samples=1000,
observed_hosp_admissions=x.sampled_observed_hosp_admissions,
data_observed_hosp_admissions=x.observed_hosp_admissions,
rng_key=jax.random.PRNGKey(54),
mcmc_args=dict(progress_bar=False),
)
Expand Down

0 comments on commit 6cd0583

Please sign in to comment.