Skip to content

Commit

Permalink
Issue 47 - pyrenew_demo.qmd (#77)
Browse files Browse the repository at this point in the history
* Adding text to the demo

* Little change

* Little change

* Little change

* Documented first code block

* Documented first code block

* Edits to first code chunk

* Define model inputs and HospitalizationsModel run

* Complete documentation for pyrenew_demo
  • Loading branch information
cshelley authored Apr 15, 2024
1 parent 08035cc commit c26f35c
Show file tree
Hide file tree
Showing 4 changed files with 150 additions and 51 deletions.
132 changes: 89 additions & 43 deletions model/docs/pyrenew_demo.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,25 @@
This demo simulates some basic renewal process data and then fits to it
using `pyrenew`.

You’ll need to install `pyrenew` first. You’ll also need working
installations of `matplotlib`, `numpy`, `jax`, `numpyro`, and `polars`
Assuming you’ve already installed Python and pip, you’ll need to first
install `pyrenew`:

``` python
python3 -m pip install "pyrenew"
```

You’ll also need working installations of `matplotlib`, `numpy`, `jax`,
`numpyro`, and `polars`:

``` python
python -m pip install "matplotlib" "numpy" "jax" "numpyro" "polars"
```

Run the following import section to call external modules and functions
necessary to run the `pyrenew` demo. The `import` statement imports the
module and the `as` statement renames the module for use within this
script. The `from` statement imports a specific function from a module
(named after the `.`) within a package (named before the `.`).

``` python
import matplotlib as mpl
Expand All @@ -20,6 +37,7 @@ import numpyro.distributions as dist
``` python
from pyrenew.process import SimpleRandomWalkProcess
```
To understand the simple random walk process underlying the sampling within the renewal process model, we first examine a single random walk path. Using the `sample` method from an instance of the `SimpleRandomWalkProcess` class, we first create an instance of the `SimpleRandomWalkProcess` class with a normal distribution of mean = 0 and standard deviation = 0.0001 as its input. Next, the `with` statement sets the seed for the random number generator for the duration of the block that follows. Inside the `with` block, the `q_samp = q.sample(duration=100)` generates the sample instance over a duration of 100 time units. Finally, this single random walk process is visualized using `matplot.pyplot` to plot the exponential of the sample instance.

``` python
np.random.seed(3312)
Expand All @@ -32,49 +50,77 @@ plt.plot(np.exp(q_samp[0]))

![](pyrenew_demo_files/figure-commonmark/fig-randwalk-output-1.png)

``` python

Next, import several additional functions from the `latent` module of the `pyrenew` package to model infections, hospital admissions, initial infections, and hospitalization rate due to infection.

```{python}
from pyrenew.latent import (
Infections, HospitalAdmissions, Infections0, InfectHospRate,
)
```

Additionally, import several classes from Pyrenew, including a Poisson observation process, determininstic PMF and variable classes, the Pyrenew hospitalization model, and a renewal modle (Rt) random walk process:

```{python}
from pyrenew.observation import PoissonObservation
from pyrenew.deterministic import DeterministicPMF, DeterministicVariable
from pyrenew.model import HospitalizationsModel
from pyrenew.process import RtRandomWalkProcess
```

To initialize a model run, we first define initial conditions, including:

1) deterministic generation time, defined as an instance of the `DeterministicPMF` class, which gives the probability of each possible outcome for a discrete random variable given as a JAX NumPy array of four possible outcomes

2) initial infections at the start of simulation as a log-normal distribution with mean = 0 and standard deviation = 1

3) latent infections as an instance of the `Infections` class with default settings

4) latent hospitalization process, modeled by first defining the time interval from infections to hospitalizations as a `DeterministicPMF` input with 18 possible outcomes and corresponding probabilities given by the values in the array. The `HospitalAdmissions` function then takes in this defined time interval, as well as defining the rate at which infections are admitted to the hospital due to infection, modeled as a log-normal distribution with mean = `jnp.log(0.05)` and standard deviation = 0.05.

5) hospitalization observation process, modeled with a Poisson distribution

6) an Rt random walk process with default settings

``` python
# Initializing model components:

# A deterministic generation time
# 1) A deterministic generation time
gen_int = DeterministicPMF(
(jnp.array([0.25, 0.25, 0.25, 0.25]),),
)

# Initial infections
# 2) Initial infections
I0 = Infections0(I0_dist=dist.LogNormal(0, 1))

# The latent infections process
# 3) The latent infections process
latent_infections = Infections()

# A deterministic infection to hosp pmf
# 4) The latent hospitalization process:

# First, define a deterministic infection to hosp pmf
inf_hosp_int = DeterministicPMF(
(jnp.array([0, 0, 0,0,0,0,0,0,0,0,0,0,0, 0.25, 0.5, 0.1, 0.1, 0.05]),),
)

# The latent hospitalization process
latent_hospitalizations = HospitalAdmissions(
infection_to_admission_interval=inf_hosp_int,
infect_hosp_rate_dist = InfectHospRate(
dist=dist.LogNormal(jnp.log(0.05), 0.05),
),
)

# And observation process for the hospitalizations
# 5) An observation process for the hospitalizations
observed_hospitalizations = PoissonObservation()

# And a random walk process (it could be deterministic using
# 6) A random walk process (it could be deterministic using
# pyrenew.process.DeterministicProcess())
Rt_process = RtRandomWalkProcess()
```

The `HospitalizationsModel` is then initialized using the initial conditions just defined:

``` python
# Initializing the model
hospmodel = HospitalizationsModel(
gen_int=gen_int,
Expand Down Expand Up @@ -142,39 +188,39 @@ hospmodel.print_summary()


mean std median 5.0% 95.0% n_eff r_hat
I0 1.27 1.10 0.97 0.10 2.42 1132.34 1.00
IHR 0.05 0.00 0.05 0.05 0.05 2306.45 1.00
Rt0 1.23 0.17 1.23 0.93 1.48 1327.22 1.00
Rt_transformed_rw_diffs[0] -0.00 0.02 -0.00 -0.04 0.04 1404.95 1.00
Rt_transformed_rw_diffs[1] 0.00 0.03 0.00 -0.04 0.04 2280.86 1.00
Rt_transformed_rw_diffs[2] -0.00 0.02 -0.00 -0.04 0.04 2119.83 1.00
Rt_transformed_rw_diffs[3] 0.00 0.02 -0.00 -0.04 0.04 2196.86 1.00
Rt_transformed_rw_diffs[4] 0.00 0.02 -0.00 -0.03 0.04 2391.45 1.00
Rt_transformed_rw_diffs[5] 0.00 0.03 0.00 -0.04 0.04 2043.02 1.00
Rt_transformed_rw_diffs[6] 0.00 0.02 0.00 -0.04 0.04 1514.40 1.00
Rt_transformed_rw_diffs[7] -0.00 0.02 -0.00 -0.04 0.04 2619.69 1.00
Rt_transformed_rw_diffs[8] 0.00 0.03 0.00 -0.04 0.04 1883.84 1.00
Rt_transformed_rw_diffs[9] 0.00 0.03 0.00 -0.04 0.04 2015.66 1.00
Rt_transformed_rw_diffs[10] 0.00 0.02 0.00 -0.04 0.04 2045.47 1.00
Rt_transformed_rw_diffs[11] -0.00 0.03 0.00 -0.04 0.04 1615.10 1.00
Rt_transformed_rw_diffs[12] 0.00 0.02 0.00 -0.04 0.04 2206.32 1.00
Rt_transformed_rw_diffs[13] 0.00 0.03 0.00 -0.04 0.04 1175.93 1.00
Rt_transformed_rw_diffs[14] -0.00 0.03 -0.00 -0.04 0.04 1606.26 1.00
Rt_transformed_rw_diffs[15] -0.00 0.03 -0.00 -0.04 0.04 2344.62 1.00
Rt_transformed_rw_diffs[16] -0.00 0.02 0.00 -0.04 0.04 1522.33 1.00
Rt_transformed_rw_diffs[17] 0.00 0.03 0.00 -0.04 0.04 2157.17 1.00
Rt_transformed_rw_diffs[18] -0.00 0.02 -0.00 -0.04 0.04 1594.95 1.00
Rt_transformed_rw_diffs[19] 0.00 0.03 -0.00 -0.04 0.04 1698.70 1.00
Rt_transformed_rw_diffs[20] 0.00 0.02 0.00 -0.04 0.04 1726.18 1.00
Rt_transformed_rw_diffs[21] 0.00 0.02 -0.00 -0.04 0.04 2386.35 1.00
Rt_transformed_rw_diffs[22] 0.00 0.03 0.00 -0.04 0.04 2028.63 1.00
Rt_transformed_rw_diffs[23] 0.00 0.02 0.00 -0.04 0.03 1669.71 1.00
Rt_transformed_rw_diffs[24] 0.00 0.02 0.00 -0.04 0.04 2126.33 1.00
Rt_transformed_rw_diffs[25] -0.00 0.02 -0.00 -0.04 0.04 2119.74 1.00
Rt_transformed_rw_diffs[26] 0.00 0.03 0.00 -0.04 0.04 2657.91 1.00
Rt_transformed_rw_diffs[27] -0.00 0.03 0.00 -0.04 0.04 1939.30 1.00
Rt_transformed_rw_diffs[28] -0.00 0.02 -0.00 -0.04 0.04 1737.84 1.00
Rt_transformed_rw_diffs[29] -0.00 0.03 -0.00 -0.04 0.04 2105.55 1.00
I0 1.26 1.09 0.96 0.09 2.41 1114.64 1.00
IHR 0.05 0.00 0.05 0.05 0.05 2747.53 1.00
Rt0 1.23 0.17 1.23 0.95 1.52 1533.77 1.00
Rt_transformed_rw_diffs[0] 0.00 0.03 -0.00 -0.04 0.04 1574.38 1.00
Rt_transformed_rw_diffs[1] 0.00 0.03 0.00 -0.04 0.04 2557.65 1.00
Rt_transformed_rw_diffs[2] 0.00 0.02 0.00 -0.04 0.04 2245.16 1.00
Rt_transformed_rw_diffs[3] 0.00 0.02 0.00 -0.03 0.04 2423.80 1.00
Rt_transformed_rw_diffs[4] 0.00 0.02 0.00 -0.03 0.04 2461.65 1.00
Rt_transformed_rw_diffs[5] 0.00 0.02 0.00 -0.04 0.04 2363.57 1.00
Rt_transformed_rw_diffs[6] 0.00 0.02 0.00 -0.04 0.04 1720.93 1.00
Rt_transformed_rw_diffs[7] -0.00 0.02 -0.00 -0.04 0.04 3851.66 1.00
Rt_transformed_rw_diffs[8] -0.00 0.03 -0.00 -0.04 0.04 1824.74 1.00
Rt_transformed_rw_diffs[9] 0.00 0.02 0.00 -0.04 0.04 1739.32 1.00
Rt_transformed_rw_diffs[10] 0.00 0.02 0.00 -0.04 0.04 1944.43 1.00
Rt_transformed_rw_diffs[11] -0.00 0.03 0.00 -0.04 0.04 1558.14 1.00
Rt_transformed_rw_diffs[12] 0.00 0.02 0.00 -0.04 0.04 2182.35 1.00
Rt_transformed_rw_diffs[13] -0.00 0.03 0.00 -0.04 0.04 1175.35 1.00
Rt_transformed_rw_diffs[14] -0.00 0.03 -0.00 -0.04 0.04 1540.25 1.00
Rt_transformed_rw_diffs[15] -0.00 0.03 -0.00 -0.04 0.04 2367.82 1.00
Rt_transformed_rw_diffs[16] -0.00 0.02 -0.00 -0.04 0.04 1636.30 1.00
Rt_transformed_rw_diffs[17] 0.00 0.03 0.00 -0.04 0.04 1978.96 1.00
Rt_transformed_rw_diffs[18] 0.00 0.02 -0.00 -0.04 0.04 1589.27 1.00
Rt_transformed_rw_diffs[19] 0.00 0.03 -0.00 -0.04 0.04 1691.06 1.00
Rt_transformed_rw_diffs[20] -0.00 0.02 -0.00 -0.04 0.04 2562.99 1.00
Rt_transformed_rw_diffs[21] 0.00 0.02 -0.00 -0.04 0.04 2352.40 1.00
Rt_transformed_rw_diffs[22] 0.00 0.03 0.00 -0.04 0.04 1971.40 1.00
Rt_transformed_rw_diffs[23] 0.00 0.02 0.00 -0.04 0.04 2013.90 1.00
Rt_transformed_rw_diffs[24] 0.00 0.03 0.00 -0.04 0.04 2022.94 1.00
Rt_transformed_rw_diffs[25] -0.00 0.02 -0.00 -0.04 0.03 1981.62 1.00
Rt_transformed_rw_diffs[26] 0.00 0.03 0.00 -0.04 0.05 2696.36 1.00
Rt_transformed_rw_diffs[27] -0.00 0.03 0.00 -0.04 0.04 2003.38 1.00
Rt_transformed_rw_diffs[28] -0.00 0.02 -0.00 -0.04 0.04 1843.27 1.00
Rt_transformed_rw_diffs[29] -0.00 0.03 -0.00 -0.04 0.04 1780.88 1.00

Number of divergences: 0

Expand Down
69 changes: 61 additions & 8 deletions model/docs/pyrenew_demo.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,20 @@ engine: jupyter

This demo simulates some basic renewal process data and then fits to it using `pyrenew`.

You'll need to install `pyrenew` first. You'll also need working installations of `matplotlib`, `numpy`, `jax`, `numpyro`, and `polars`
Assuming you've already installed Python and pip, you’ll need to first install `pyrenew`:

```{python}
pip install pyrenew
```

You’ll also need working
installations of `matplotlib`, `numpy`, `jax`, `numpyro`, and `polars`:

```{python}
pip install matplotlib numpy jax numpyro polars
```

To begin, run the following import section to call external modules and functions necessary to run the `pyrenew` demo. The `import` statement imports the module and the `as` statement renames the module for use within this script. The `from` statement imports a specific function from a module (named after the `.`) within a package (named before the `.`).

```{python}
#| output: false
Expand All @@ -25,6 +38,8 @@ import numpyro.distributions as dist
from pyrenew.process import SimpleRandomWalkProcess
```

To understand the simple random walk process underlying the sampling within the renewal process model, we first examine a single random walk path. Using the `sample` method from an instance of the `SimpleRandomWalkProcess` class, we first create an instance of the `SimpleRandomWalkProcess` class with a normal distribution of mean = 0 and standard deviation = 0.0001 as its input. Next, the `with` statement sets the seed for the random number generator for the duration of the block that follows. Inside the `with` block, the `q_samp = q.sample(duration=100)` generates the sample instance over a duration of 100 time units. Finally, this single random walk process is visualized using `matplot.pyplot` to plot the exponential of the sample instance.

```{python}
#| label: fig-randwalk
#| fig-cap: Random walk example
Expand All @@ -36,49 +51,76 @@ with seed(rng_seed=np.random.randint(0,1000)):
plt.plot(np.exp(q_samp[0]))
```

Next, import several additional functions from the `latent` module of the `pyrenew` package to model infections, hospital admissions, initial infections, and hospitalization rate due to infection.

```{python}
from pyrenew.latent import (
Infections, HospitalAdmissions, Infections0, InfectHospRate,
)
```

Additionally, import several classes from Pyrenew, including a Poisson observation process, determininstic PMF and variable classes, the Pyrenew hospitalization model, and a renewal modle (Rt) random walk process:

```{python}
from pyrenew.observation import PoissonObservation
from pyrenew.deterministic import DeterministicPMF, DeterministicVariable
from pyrenew.model import HospitalizationsModel
from pyrenew.process import RtRandomWalkProcess
```

To initialize the model, we first define initial conditions, including:

1) deterministic generation time, defined as an instance of the `DeterministicPMF` class, which gives the probability of each possible outcome for a discrete random variable given as a JAX NumPy array of four possible outcomes

2) initial infections at the start of simulation as a log-normal distribution with mean = 0 and standard deviation = 1

3) latent infections as an instance of the `Infections` class with default settings

4) latent hospitalization process, modeled by first defining the time interval from infections to hospitalizations as a `DeterministicPMF` input with 18 possible outcomes and corresponding probabilities given by the values in the array. The `HospitalAdmissions` function then takes in this defined time interval, as well as defining the rate at which infections are admitted to the hospital due to infection, modeled as a log-normal distribution with mean = `jnp.log(0.05)` and standard deviation = 0.05.

5) hospitalization observation process, modeled with a Poisson distribution

6) an Rt random walk process with default settings

```{python}
# Initializing model components:
# A deterministic generation time
# 1) A deterministic generation time
gen_int = DeterministicPMF(
(jnp.array([0.25, 0.25, 0.25, 0.25]),),
)
# Initial infections
# 2) Initial infections
I0 = Infections0(I0_dist=dist.LogNormal(0, 1))
# The latent infections process
# 3) The latent infections process
latent_infections = Infections()
# A deterministic infection to hosp pmf
# 4) The latent hospitalization process:
# First, define a deterministic infection to hosp pmf
inf_hosp_int = DeterministicPMF(
(jnp.array([0, 0, 0,0,0,0,0,0,0,0,0,0,0, 0.25, 0.5, 0.1, 0.1, 0.05]),),
)
# The latent hospitalization process
latent_hospitalizations = HospitalAdmissions(
infection_to_admission_interval=inf_hosp_int,
infect_hosp_rate_dist = InfectHospRate(
dist=dist.LogNormal(jnp.log(0.05), 0.05),
),
)
# And observation process for the hospitalizations
# 5) An observation process for the hospitalizations
observed_hospitalizations = PoissonObservation()
# And a random walk process (it could be deterministic using
# 6) A random walk process (it could be deterministic using
# pyrenew.process.DeterministicProcess())
Rt_process = RtRandomWalkProcess()
```

The `HospitalizationsModel` is then initialized using the initial conditions just defined:

```{python}
# Initializing the model
hospmodel = HospitalizationsModel(
gen_int=gen_int,
Expand All @@ -90,13 +132,16 @@ hospmodel = HospitalizationsModel(
)
```

Next, we sample from the `hospmodel` for 30 time steps and view the output of a single run:

```{python}
with seed(rng_seed=np.random.randint(1, 60)):
x = hospmodel.sample(n_timepoints=30)
x
```

Visualizations of the single model output show (top) infections over the 30 time steps, (middle) hospitalizations over the 30 time steps, and (bottom)

```{python}
#| label: fig-hosp
#| fig-cap: Infections
Expand All @@ -109,6 +154,8 @@ for axis in ax[:-1]:
axis.set_yscale("log")
```

To fit the `hospmodel` to the simulated data, we call `hospmodel.run()`, an MCMC algorithm, with the arguments generated in `hospmodel` object, using 1000 warmup stepts and 1000 samples to draw from the posterior distribution of the model parameters. The model is run for `len(x.sampled)-1` time steps with the seed set by `jax.random.PRNGKey()`

```{python}
# from numpyro.infer import MCMC, NUTS
hospmodel.run(
Expand All @@ -121,15 +168,21 @@ hospmodel.run(
)
```

Print a summary of the model:

```{python}
hospmodel.print_summary()
```

Next, we will use the `spread_draws` function from the `pyrenew.mcmcutils` module to process the MCMC samples. The `spread_draws` function reformats the samples drawn from the `mcmc.get_samples()` from the `hospmodel`. The samples are simulated Rt values over time.

```{python}
from pyrenew.mcmcutils import spread_draws
samps = spread_draws(hospmodel.mcmc.get_samples(), [("Rt", "time")])
```

We visualize these samples below, with individual possible Rt estimates over time shown in light blue, and the overall mean estimate Rt shown in dark blue.

```{python}
#| label: fig-sampled-rt
#| fig-cap: Posterior Rt
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

0 comments on commit c26f35c

Please sign in to comment.