Skip to content

Commit

Permalink
Update tutorials formatting (#220)
Browse files Browse the repository at this point in the history
  • Loading branch information
damonbayer authored Jun 27, 2024
1 parent 5748007 commit ae4424f
Show file tree
Hide file tree
Showing 4 changed files with 95 additions and 86 deletions.
64 changes: 35 additions & 29 deletions model/docs/example_with_datasets.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,7 @@ We can use [ArviZ](https://www.arviz.org/) to visualize the results. Let's start
```{python}
# | label: convert-inferenceData
import arviz as az
idata = az.from_numpyro(hosp_model.mcmc)
```
We obtain the summary of model diagnostics and print the diagnostics for `latent_hospital_admissions[1]`
Expand All @@ -314,48 +315,50 @@ We obtain the summary of model diagnostics and print the diagnostics for `latent
# | warning: false
diagnostic_stats_summary = az.summary(
idata.posterior,
kind='diagnostics',
)
kind="diagnostics",
)
print(diagnostic_stats_summary.loc['latent_hospital_admissions[1]'])
print(diagnostic_stats_summary.loc["latent_hospital_admissions[1]"])
```

Below we plot 90% and 50% highest density intervals for latent hospital admissions using [plot_hdi](https://python.arviz.org/en/stable/api/generated/arviz.plot_hdi.html):

```{python}
# | label: fig-output-admission-distribution
# | fig-cap: Hospital Admissions posterior distribution
x_data = idata.posterior['latent_hospital_admissions_dim_0']
y_data = idata.posterior['latent_hospital_admissions']
x_data = idata.posterior["latent_hospital_admissions_dim_0"]
y_data = idata.posterior["latent_hospital_admissions"]
fig, axes = plt.subplots(figsize=(6,5))
fig, axes = plt.subplots(figsize=(6, 5))
az.plot_hdi(
x_data,
y_data,
hdi_prob=0.9,
color='C0',
color="C0",
smooth=False,
fill_kwargs={'alpha':0.3},
fill_kwargs={"alpha": 0.3},
ax=axes,
)
az.plot_hdi(
x_data,
y_data,
hdi_prob=0.5,
color='C0',
color="C0",
smooth=False,
fill_kwargs={'alpha':0.6},
fill_kwargs={"alpha": 0.6},
ax=axes,
)
#Add mean of the posterior to the figure
mean_latent_hosp_admission = np.mean(idata.posterior['latent_hospital_admissions'],axis=1)
axes.plot(x_data,mean_latent_hosp_admission[0], color='C0', label='Mean')
# Add mean of the posterior to the figure
mean_latent_hosp_admission = np.mean(
idata.posterior["latent_hospital_admissions"], axis=1
)
axes.plot(x_data, mean_latent_hosp_admission[0], color="C0", label="Mean")
axes.legend()
axes.set_title('Posterior Hospital Admissions', fontsize=10)
axes.set_xlabel('Time', fontsize=10)
axes.set_ylabel('Hospital Admissions',fontsize=10);
axes.set_title("Posterior Hospital Admissions", fontsize=10)
axes.set_xlabel("Time", fontsize=10)
axes.set_ylabel("Hospital Admissions", fontsize=10);
```

We can also take a look at the latent infections:
Expand All @@ -372,37 +375,39 @@ and the distribution of latent infections
```{python}
# | label: fig-output-infections-distribution
# | fig-cap: Posterior Latent Infections
x_data = idata.posterior['all_latent_infections_dim_0']
y_data = idata.posterior['all_latent_infections']
x_data = idata.posterior["all_latent_infections_dim_0"]
y_data = idata.posterior["all_latent_infections"]
fig, axes = plt.subplots(figsize=(6,5))
fig, axes = plt.subplots(figsize=(6, 5))
az.plot_hdi(
x_data,
y_data,
hdi_prob=0.9,
color='C0',
color="C0",
smooth=False,
fill_kwargs={'alpha':0.3},
fill_kwargs={"alpha": 0.3},
ax=axes,
)
az.plot_hdi(
x_data,
y_data,
hdi_prob=0.5,
color='C0',
color="C0",
smooth=False,
fill_kwargs={'alpha':0.6},
fill_kwargs={"alpha": 0.6},
ax=axes,
)
#Add mean of the posterior to the figure
mean_latent_infection = np.mean(idata.posterior['all_latent_infections'],axis=1)
axes.plot(x_data,mean_latent_infection[0], color='C0', label='Mean')
# Add mean of the posterior to the figure
mean_latent_infection = np.mean(
idata.posterior["all_latent_infections"], axis=1
)
axes.plot(x_data, mean_latent_infection[0], color="C0", label="Mean")
axes.legend()
axes.set_title('Posterior Latent Infections', fontsize=10)
axes.set_xlabel('Time', fontsize=10)
axes.set_ylabel('Latent Infections',fontsize=10);
axes.set_title("Posterior Latent Infections", fontsize=10)
axes.set_xlabel("Time", fontsize=10)
axes.set_ylabel("Latent Infections", fontsize=10);
```

## Round 2: Incorporating day-of-the-week effects
Expand All @@ -415,6 +420,7 @@ Note a similar weekday effect is implemented in its own module, with example cod
from pyrenew import metaclass
import numpyro as npro
class DayOfWeekEffect(metaclass.RandomVariable):
"""Day of the week effect"""
Expand Down
87 changes: 45 additions & 42 deletions model/docs/getting_started.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -199,42 +199,43 @@ model1.run(
Now, let's investigate the output, particularly the posterior distribution of the $R_t$ estimates:

```{python}
#| label: fig-output-rt
#| fig-cap: Rt posterior distribution
# | label: fig-output-rt
# | fig-cap: Rt posterior distribution
out = model1.plot_posterior(var="Rt")
```

We can use [ArviZ](https://www.arviz.org/) package to create model diagnostics and visualizations. We start by converting the fitted model to ArviZ InferenceData object:

```{python}
#| label: convert-inference-data
# | label: convert-inference-data
import arviz as az
idata = az.from_numpyro(model1.mcmc)
```

and use the InferenceData to compute the model-fit diagnostics. Here, we show diagnostic summary for the first 10 effective reproduction number $R_t$.

```{python}
#| label: diagnostics
# | label: diagnostics
diagnostic_stats_summary = az.summary(
idata.posterior['Rt'],
kind='diagnostics',
)
idata.posterior["Rt"],
kind="diagnostics",
)
print(diagnostic_stats_summary[:10])
```

Below we use `plot_trace` to inspect the trace of the first 10 $R_t$ estimates.

```{python}
#| label: fig-trace-Rt
#| fig-cap: Trace plot of Rt posterior distribution
plt.rcParams['figure.constrained_layout.use'] = True
# | label: fig-trace-Rt
# | fig-cap: Trace plot of Rt posterior distribution
plt.rcParams["figure.constrained_layout.use"] = True
az.plot_trace(
idata.posterior,
var_names=['Rt'],
coords={'Rt_dim_0': np.arange(10)},
var_names=["Rt"],
coords={"Rt_dim_0": np.arange(10)},
compact=False,
);
```
Expand All @@ -243,75 +244,77 @@ az.plot_trace(
We inspect the posterior distribution of $R_t$ by plotting the 90% and 50% highest density intervals:

```{python}
#| label: fig-hdi-Rt
#| fig-cap: High density interval for Effective Reproduction Number
x_data = idata.posterior['Rt_dim_0']
y_data = idata.posterior['Rt']
# | label: fig-hdi-Rt
# | fig-cap: High density interval for Effective Reproduction Number
x_data = idata.posterior["Rt_dim_0"]
y_data = idata.posterior["Rt"]
fig, axes = plt.subplots(figsize=(6,5))
fig, axes = plt.subplots(figsize=(6, 5))
az.plot_hdi(
x_data,
y_data,
hdi_prob=0.9,
color='C0',
fill_kwargs={'alpha':0.3},
color="C0",
fill_kwargs={"alpha": 0.3},
ax=axes,
)
az.plot_hdi(
x_data,
y_data,
hdi_prob=0.5,
color='C0',
fill_kwargs={'alpha':0.6},
color="C0",
fill_kwargs={"alpha": 0.6},
ax=axes,
)
#Add mean of the posterior to the figure
mean_Rt = np.mean(idata.posterior['Rt'],axis=1)
axes.plot(x_data,mean_Rt[0], color='C0', label='Mean')
# Add mean of the posterior to the figure
mean_Rt = np.mean(idata.posterior["Rt"], axis=1)
axes.plot(x_data, mean_Rt[0], color="C0", label="Mean")
axes.legend()
axes.set_title('Posterior Effective Reproduction Number', fontsize=10)
axes.set_xlabel('Time', fontsize=10)
axes.set_ylabel('$R_t$', fontsize=10);
axes.set_title("Posterior Effective Reproduction Number", fontsize=10)
axes.set_xlabel("Time", fontsize=10)
axes.set_ylabel("$R_t$", fontsize=10);
```

and latent infections:

```{python}
#| label: fig-hdi-latent-infections
#| fig-cap: High density interval for Latent Infections
x_data = idata.posterior['all_latent_infections_dim_0']
y_data = idata.posterior['all_latent_infections']
# | label: fig-hdi-latent-infections
# | fig-cap: High density interval for Latent Infections
x_data = idata.posterior["all_latent_infections_dim_0"]
y_data = idata.posterior["all_latent_infections"]
fig, axes = plt.subplots(figsize=(6,5))
fig, axes = plt.subplots(figsize=(6, 5))
az.plot_hdi(
x_data,
y_data,
hdi_prob=0.9,
color='C0',
color="C0",
smooth=False,
fill_kwargs={'alpha':0.3},
fill_kwargs={"alpha": 0.3},
ax=axes,
)
az.plot_hdi(
x_data,
y_data,
hdi_prob=0.5,
color='C0',
color="C0",
smooth=False,
fill_kwargs={'alpha':0.6},
fill_kwargs={"alpha": 0.6},
ax=axes,
)
#Add mean of the posterior to the figure
mean_latent_infection = np.mean(idata.posterior['all_latent_infections'],axis=1)
axes.plot(x_data,mean_latent_infection[0], color='C0', label='Mean')
# Add mean of the posterior to the figure
mean_latent_infection = np.mean(
idata.posterior["all_latent_infections"], axis=1
)
axes.plot(x_data, mean_latent_infection[0], color="C0", label="Mean")
axes.legend()
axes.set_title('Posterior Latent Infections',fontsize=10)
axes.set_xlabel('Time', fontsize=10)
axes.set_ylabel('Latent Infections',fontsize=10);
axes.set_title("Posterior Latent Infections", fontsize=10)
axes.set_xlabel("Time", fontsize=10)
axes.set_ylabel("Latent Infections", fontsize=10);
```

## Architecture of pyrenew
Expand Down
24 changes: 12 additions & 12 deletions model/docs/periodic_effects.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,16 @@ from pyrenew import process, deterministic
```{python}
# The random process for Rt
rt_proc = process.RtWeeklyDiffProcess(
offset = 0,
log_rt_prior = deterministic.DeterministicVariable(
jnp.array([0.1, 0.2]), name="log_rt_prior"
),
autoreg = deterministic.DeterministicVariable(
jnp.array([0.7]), name="autoreg"
offset=0,
log_rt_prior=deterministic.DeterministicVariable(
jnp.array([0.1, 0.2]), name="log_rt_prior"
),
autoreg=deterministic.DeterministicVariable(
jnp.array([0.7]), name="autoreg"
),
periodic_diff_sd=deterministic.DeterministicVariable(
jnp.array([0.1]), name="periodic_diff_sd"
),
periodic_diff_sd = deterministic.DeterministicVariable(
jnp.array([0.1]), name="periodic_diff_sd"
),
)
```

Expand Down Expand Up @@ -69,13 +69,13 @@ from pyrenew import transformation, metaclass
# Building the transformed prior: Dirichlet * 7
mysimplex = dist.TransformedDistribution(
dist.Dirichlet(concentration=jnp.ones(7)),
transformation.AffineTransform(loc=0, scale=7.0)
transformation.AffineTransform(loc=0, scale=7.0),
)
# Constructing the day of week effect
dayofweek = process.DayOfWeekEffect(
offset = 0,
quantity_to_broadcast=metaclass.DistributionalRV(mysimplex, "simp")
offset=0,
quantity_to_broadcast=metaclass.DistributionalRV(mysimplex, "simp"),
)
```

Expand Down
6 changes: 3 additions & 3 deletions model/docs/pyrenew_demo.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -125,9 +125,9 @@ admissions_process = PoissonObservation()
# 6) A random walk process (it could be deterministic using
# pyrenew.process.DeterministicProcess())
Rt_process = RtRandomWalkProcess(
Rt0_dist = dist.TruncatedNormal(loc=1.2, scale=0.2, low=0),
Rt_transform = t.ExpTransform().inv,
Rt_rw_dist = dist.Normal(0, 0.025),
Rt0_dist=dist.TruncatedNormal(loc=1.2, scale=0.2, low=0),
Rt_transform=t.ExpTransform().inv,
Rt_rw_dist=dist.Normal(0, 0.025),
)
```

Expand Down

0 comments on commit ae4424f

Please sign in to comment.