
Commit

Formally addresses implementing the observed hospitalizations module (#…
…39)

* Copy doc from wastewater model on Hosp Admin to class

* Rename metaclasses to metaclass (#36)

* Create typos.yaml (#42)

* Create typos.yaml

* Does this make it run on the whole repo?

* remove trailing whitespace

* Delete .github/workflows/typos.yaml

trying to use typos in pre-commit instead

* Update .pre-commit-config.yaml

add typos to pre-commit

* fixed typos

* 34 port notebooks under modeldocs to quarto (#35)

* Porting to quarto and adding a makefile entry to render them

* Now it should be easier to deal with pre-commit under docs (only skipping cache)

* Forgot to remove the pyrenew_demo notebook

* Update model/Makefile

* Fixing typos

* Removes progress bar - adds doc/ figures - sets jupyter as kernel

* Update model/docs/getting-started.qmd

Co-authored-by: Dylan H. Morris <[email protected]>

---------

Co-authored-by: Dylan H. Morris <[email protected]>

* Extra mathematical description of discrete delay distributions (#44)

* Update description of discrete delay distributions

* remove double description

* minor eq fix

* Update equations.md

* update equations.md contents

* Update equations.md

* fix contents

* Escaping tau

* Adding deterministic obs and process to the equation

* Cleaning the quarto documents and working on the getting started diagram

* Flexible IHR (now RandomVariable)

* Adding weekday and phosp effect to latent hosp

* Adding back figures

* Adding a test for deterministic/stochastic weekday effect

* Typo

Co-authored-by: Dylan H. Morris <[email protected]>

* Correcting tests (class name) and improving readme a bit

* Adding deterministic module (midway, expected to fail) [skip ci]

* Refactoring I0 and gen_int (expected to fail) [skip ci]

* gen_int and I0 now are directly passed to the models

* In latent hosp, change inf_hosp_int to inform_hosp (clearer name)

* Adding missing figures (pyrenew demo was not compiling)

* Renaming inform_hosp

* Removing defaults for hosp rate

* Changing language (initial infections) + adding section to getting-started

* Update model/src/pyrenew/latent/hospitaladmissions.py

Co-authored-by: Dylan H. Morris <[email protected]>

* Addressing comments on default priors and varnames

* Rt is not default now for basic model

* Commas and title

* Update model/src/pyrenew/latent/hospitaladmissions.py

Co-authored-by: Dylan H. Morris <[email protected]>

* Renaming hosp reporting variable in latent var

* Renaming hosp report

* Update model/src/pyrenew/latent/hospitaladmissions.py

Co-authored-by: Dylan H. Morris <[email protected]>

* Final renaming of vars in tests

* Different vector for hosp_report_prob_dist in tests

---------

Co-authored-by: George G. Vega Yon <[email protected]>
Co-authored-by: George G Vega Yon <[email protected]>
Co-authored-by: Nate McIntosh <[email protected]>
Co-authored-by: Dylan H. Morris <[email protected]>
Co-authored-by: Samuel Brand <[email protected]>
Co-authored-by: George G. Vega Yon <[email protected]>
Co-authored-by: Dylan H. Morris <[email protected]>
7 people authored Apr 3, 2024
1 parent cf487d3 commit 5a06ab2
Showing 26 changed files with 1,322 additions and 362 deletions.
11 changes: 8 additions & 3 deletions model/Makefile
@@ -7,9 +7,14 @@ test:
docs: docs/pyrenew_demo.md docs/getting-started.md

docs/pyrenew_demo.md: docs/pyrenew_demo.qmd
quarto render docs/pyrenew_demo.qmd
poetry run quarto render docs/pyrenew_demo.qmd

docs/getting-started.md: docs/getting-started.qmd
quarto render docs/getting-started.qmd
poetry run quarto render docs/getting-started.qmd

.PHONY: install test docs
clean:
rm -rf docs/*_files/
rm -f docs/getting-started.ipynb
rm -f docs/pyrenew_demo.ipynb

.PHONY: install test docs clean
9 changes: 6 additions & 3 deletions model/README.md
@@ -1,12 +1,15 @@
# PyRenew
# PyRenew: A Package for Bayesian Renewal Modeling with JAX and Numpyro.

A package for Bayesian renewal modeling with JAX and Numpyro.
`pyrenew` is a flexible tool for simulation and statistical inference with epidemiological models, emphasizing renewal models. Built on top of the [`numpyro`](https://num.pyro.ai/) Python library, `pyrenew` provides core components for model building, including pre-defined models for processing various types of observational processes.

## Installation

Install via pip with

```bash
pip install git+https://github.com/cdcent/cfa-pyrenew.git
```

## Demo
The `docs` folder contains a Jupyter notebook with an interactive demo to get you started. It simulates observed hospitalizations using a simple renewal process model and then fits to it using a No-U-Turn Sampler.

The [`docs`](docs) folder contains Quarto documents to get you started. They simulate observed hospitalizations using a simple renewal process model and then fit that model using a No-U-Turn Sampler.
1 change: 1 addition & 0 deletions model/docs/.gitignore
@@ -1 +1,2 @@
!*png
*_files/libs
289 changes: 182 additions & 107 deletions model/docs/getting-started.md
@@ -1,127 +1,142 @@
# Getting started with pyrenew


This document illustrates two features of `pyrenew`: (a) the set of
included `RandomVariable`s, and (b) model composition.
`pyrenew` is a flexible tool for simulation and statistical inference
with epidemiological models, emphasizing renewal models. Built
on `numpyro`, `pyrenew` provides core components for model building and
pre-defined models for processing various observational processes. This
document illustrates how `pyrenew` can be used to build a basic renewal
model.

## The fundamentals

`pyrenew`’s core components are the metaclasses `RandomVariable` and
`Model`. From the package’s perspective, a `RandomVariable` is a
quantity models can sample and estimate, **including deterministic
quantities**. Sampling from a `RandomVariable` means calling its
`sample()` method. The benefit of this design is that the definition of
the sample function can be arbitrary, allowing the user to either sample
from a distribution using `numpyro.sample()`, compute fixed quantities
(like a mechanistic equation), or return a fixed value (like a
pre-computed PMF). For instance, we may be interested in estimating a
PMF, in which case a `RandomVariable` sampling function may roughly be
defined as:

## Hospitalizations model

`pyrenew` has five main components:

- Utility and math functions,
- The `processes` sub-module,
- The `observations` sub-module,
- The `latent` sub-module, and
- The `models` sub-module

All three of `process`, `observation`, and `latent` contain classes that
inherit from the meta class `RandomVariable`. The classes under `model`
inherit from the meta class `Model`. The following diagram illustrates
the composition the model `pyrenew.models.HospitalizationsModel`:

``` mermaid
flowchart TB
subgraph randprocmod["Processes module"]
direction TB
simprw["SimpleRandomWalkProcess"]
rtrw["RtRandomWalkProcess"]
end
subgraph latentmod["Latent module"]
direction TB
hosp_latent["Hospitalizations"]
inf_latent["Infections"]
end
subgraph obsmod["Observations module"]
direction TB
pois["PoissonObservation"]
nb["NegativeBinomialObservation"]
end
subgraph models["Models module"]
direction TB
basic["RtInfectionsRenewalModel"]
hosp["HospitalizationsModel"]
end
rp(("RandomVariable")) --> |Inherited by| randprocmod
rp -->|Inherited by| latentmod
rp -->|Inherited by| obsmod
model(("Model")) -->|Inherited by| models
``` python
class MyRandVar(RandomVariable):
    def sample(...) -> ArrayLike:
        return numpyro.sample(...)
```

simprw -->|Composes| rtrw
rtrw -->|Composes| basic
inf_latent -->|Composes| basic
basic -->|Composes| hosp
Whereas, in some other cases, we may instead use a fixed quantity for
that variable (like a pre-computed PMF), where the `RandomVariable`’s
sample function could be defined like:

``` python
class MyRandVar(RandomVariable):
    def sample(...) -> ArrayLike:
        return jax.numpy.array([0.2, 0.7, 0.1])
```

obsmod -->|Composes|models
hosp_latent -->|Composes| hosp
This way, when a `Model` samples from `MyRandVar`, it could be either
adding random variables to be estimated (first case) or just retrieving
some quantity needed for other calculations (second case).
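
The two cases can be illustrated with a minimal, self-contained sketch. It does not use `pyrenew` or `numpyro`: the base class below is a hypothetical stand-in, the class names `EstimatedPMF` and `FixedPMF` and the helper `mean_delay` are invented for illustration, and Python's standard `random` module takes the place of `numpyro.sample()`. The point is only that downstream code calls `sample()` without knowing whether the quantity is random or fixed.

``` python
import random
from abc import ABC, abstractmethod


class RandomVariable(ABC):
    """Hypothetical stand-in for a base class with a sample() method."""

    @abstractmethod
    def sample(self, rng: random.Random) -> list[float]:
        """Return a value for this quantity."""


class EstimatedPMF(RandomVariable):
    """First case: the PMF is drawn at random on every call."""

    def __init__(self, n_bins: int) -> None:
        self.n_bins = n_bins

    def sample(self, rng: random.Random) -> list[float]:
        # Normalized exponential draws give a flat Dirichlet sample.
        weights = [rng.expovariate(1.0) for _ in range(self.n_bins)]
        total = sum(weights)
        return [w / total for w in weights]


class FixedPMF(RandomVariable):
    """Second case: the PMF is a fixed, pre-computed quantity."""

    def __init__(self, pmf: list[float]) -> None:
        self.pmf = list(pmf)

    def sample(self, rng: random.Random) -> list[float]:
        # The generator is ignored; the same value is returned every time.
        return list(self.pmf)


def mean_delay(pmf_rv: RandomVariable, rng: random.Random) -> float:
    """Downstream code: works with either kind of RandomVariable."""
    pmf = pmf_rv.sample(rng)
    return sum(day * p for day, p in enumerate(pmf))


rng = random.Random(0)
fixed_mean = mean_delay(FixedPMF([0.2, 0.7, 0.1]), rng)
random_mean = mean_delay(EstimatedPMF(3), rng)
```

Swapping `FixedPMF` for `EstimatedPMF` changes what is treated as unknown without touching `mean_delay`, which is the kind of replacement the design is meant to allow.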

%% Metaclasses
classDef Metaclass color:black,fill:white
class rp,model Metaclass
The `Model` metaclass provides basic functionality for estimation and
simulation. Like `RandomVariable`, the `Model` metaclass has a
`sample()` method that defines the model structure. Models
can also be nested (or inherited), providing a straightforward way to add
layers of complexity.

%% Random process
classDef Randproc fill:purple,color:white
class rtrw,simprw Randproc
## ‘Hello world’ model

%% Models
classDef Models fill:teal,color:white
class basic,hosp Models
```

We start by loading the needed components to build a basic renewal
model:
This section will show the steps to build a simple renewal model
featuring a latent infection process, a random walk Rt process, and an
observation process for the reported infections. We start by loading the
needed components to build a basic renewal model:

``` python
import jax.numpy as jnp
import numpy as np
import numpyro as npro
import numpyro.distributions as dist
from pyrenew.process import RtRandomWalkProcess
from pyrenew.latent import Infections
from pyrenew.latent import Infections, Infections0
from pyrenew.observation import PoissonObservation
from pyrenew.deterministic import DeterministicPMF
from pyrenew.model import RtInfectionsRenewalModel
```

/mnt/c/Users/xrd4/Documents/repos/msr/model/.venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
from .autonotebook import tqdm as notebook_tqdm
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.

In the basic renewal model we can define three components: Rt, latent
infections, and observed infections.
The basic renewal model defines five components: generation interval,
initial infections, Rt, latent infections, and observed infections. In
this example, the generation interval is not estimated but passed as a
deterministic instance of `RandomVariable`. Here is the code to
initialize the five components:

``` python
latent_infections = Infections(
    gen_int=jnp.array([0.25, 0.25, 0.25, 0.25]),
)
# (1) The generation interval (deterministic)
gen_int = DeterministicPMF(
    (jnp.array([0.25, 0.25, 0.25, 0.25]),),
)

observed_infections = PoissonObservation(
    rate_varname='latent',
    counts_varname='observed_infections',
)
# (2) Initial infections (inferred with a prior)
I0 = Infections0(I0_dist=dist.LogNormal(0, 1))

# (3) The random process for Rt
rt_proc = RtRandomWalkProcess()

# (4) Latent infection process (which will use 1 and 2)
latent_infections = Infections()

# (5) The observed infections process (with mean at the latent infections)
observed_infections = PoissonObservation(
    rate_varname = 'latent',
    counts_varname = 'observed_infections',
)
```

With observation process for the latent infections, we can build the
basic renewal model, and generate a sample calling the `sample()`
method:
With these five pieces, we can build the basic renewal model:

``` python
model1 = RtInfectionsRenewalModel(
    Rt_process=rt_proc,
    latent_infections=latent_infections,
    observed_infections=observed_infections,
    gen_int = gen_int,
    I0 = I0,
    Rt_process = rt_proc,
    latent_infections = latent_infections,
    observed_infections = observed_infections,
)
```
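
For intuition about what the latent infection component represents, the standard discrete renewal equation sets new infections at time t to Rt multiplied by the sum of past infections weighted by the generation interval. The sketch below is plain Python and is not `pyrenew`'s `Infections` class: the function name `renewal_infections`, the convention that `gen_int[0]` is the weight for a lag of one time step, and the choice to seed the history with zeros before the initial infections are all assumptions made for illustration.

``` python
def renewal_infections(
    i0: float, rt: list[float], gen_int: list[float]
) -> list[float]:
    """Iterate I_t = R_t * sum_s gen_int[s - 1] * I_{t - s}.

    Assumption: infections before the seeding time are zero, and
    gen_int[0] is the weight for a lag of one time step.
    """
    n_lags = len(gen_int)
    # Seed the history: zeros followed by the initial infections.
    history = [0.0] * (n_lags - 1) + [float(i0)]
    new_infections = []
    for r_t in rt:
        # Most recent infections first, to line up with gen_int.
        recent = history[-n_lags:][::-1]
        infectiousness = sum(g * i for g, i in zip(gen_int, recent))
        infections_t = r_t * infectiousness
        new_infections.append(infections_t)
        history.append(infections_t)
    return new_infections


# Same flat generation interval as in the example above.
trajectory = renewal_infections(
    i0=4.0, rt=[1.0, 1.0], gen_int=[0.25, 0.25, 0.25, 0.25]
)
```

With a generation interval concentrated on a single lag and a constant Rt of 2, the same function doubles infections at every step, which is a quick sanity check on the recursion.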

The following diagram summarizes how the modules interact via
composition; notably, `gen_int`, `I0`, `rt_proc`, `latent_infections`,
and `observed_infections` are instances of `RandomVariable`, which means
these can be easily replaced to generate a different version of
`RtInfectionsRenewalModel`:

``` mermaid
flowchart TB
genint["(1) gen_int\n(DeterministicPMF)"]
i0["(2) I0\n(Infections0)"]
rt["(3) rt_proc\n(RtRandomWalkProcess)"]
inf["(4) latent_infections\n(Infections)"]
obs["(5) observed_infections\n(PoissonObservation)"]
model1["model1\n(RtInfectionsRenewalModel)"]
i0-->|Composes|model1
genint-->|Composes|model1
rt-->|Composes|model1
obs-->|Composes|model1
inf-->|Composes|model1
```

Using `numpyro`, we can simulate data using the `sample()` member
function of `RtInfectionsRenewalModel`:

``` python
np.random.seed(223)
with npro.handlers.seed(rng_seed=np.random.randint(1, 60)):
    sim_data = model1.sample(constants=dict(n_timepoints=30))
with npro.handlers.seed(rng_seed = np.random.randint(1, 60)):
    sim_data = model1.sample(constants = dict(n_timepoints=30))

sim_data
```
@@ -131,14 +146,14 @@ sim_data
1.271196 , 1.3189521, 1.3054799, 1.3165426, 1.291952 , 1.3026639,
1.2619467, 1.2852622, 1.3121517, 1.2888998, 1.2641873, 1.2580931,
1.2545817, 1.3092988, 1.2488269, 1.2397509, 1.2071848, 1.2334517,
1.21868 ], dtype=float32), latent=Array([ 3.7023427, 4.850682 , 6.4314823, 8.26245 , 6.9874763,
7.940377 , 9.171101 , 10.051114 , 10.633459 , 11.729475 ,
12.559867 , 13.422887 , 15.364211 , 17.50132 , 19.206314 ,
21.556652 , 23.78112 , 26.719398 , 28.792412 , 32.40454 ,
36.641006 , 40.135487 , 43.60607 , 48.055103 , 52.829704 ,
60.43277 , 63.97854 , 69.82776 , 74.564415 , 82.88904 ,
88.73811 ], dtype=float32), observed=Array([ 4, 3, 6, 5, 7, 7, 10, 11, 6, 9, 7, 13, 16, 19, 20, 27, 23,
31, 28, 30, 43, 42, 55, 57, 44, 52, 64, 52, 77, 85, 94], dtype=int32))
1.21868 ], dtype=float32), latent=Array([ 2.3215084, 3.0415602, 4.0327816, 5.180868 , 4.381411 ,
4.978916 , 5.750626 , 6.3024273, 6.66758 , 7.354823 ,
7.8755097, 8.416656 , 9.63394 , 10.973988 , 12.043082 ,
13.516833 , 14.911659 , 16.75407 , 18.053928 , 20.318869 ,
22.975292 , 25.166464 , 27.34265 , 30.13236 , 33.126217 ,
37.89362 , 40.11695 , 43.784634 , 46.754696 , 51.974545 ,
55.642136 ], dtype=float32), observed=Array([ 1, 2, 3, 5, 4, 4, 7, 4, 8, 4, 7, 3, 8, 12, 13, 18, 14,
20, 17, 18, 28, 27, 36, 37, 26, 31, 40, 27, 48, 54, 60], dtype=int32))

The `sample()` method of the `RtInfectionsRenewalModel` returns a list
composed of the `Rt` and `infections` sequences.
@@ -162,11 +177,12 @@ plt.tight_layout()
plt.show()
```

<img
src="getting-started_files/figure-commonmark/basic-fig-output-1.png"
id="basic-fig" />
![Rt and
Infections](getting-started_files/figure-commonmark/basic-fig-output-1.png)

Let’s see how the estimation would go
To fit the model, we can use the `run()` method of
`RtInfectionsRenewalModel`, a method inherited from the metaclass
`Model`:

``` python
import jax
@@ -183,7 +199,8 @@ model1.run(
)
```

Now, let’s investigate the output
Now, let’s investigate the output, particularly the posterior
distribution of the Rt estimates:

``` python
import polars as pl
@@ -202,6 +219,64 @@ ax.set_yticks([0.5, 1, 2])
ax.set_yscale("log")
```

<img
src="getting-started_files/figure-commonmark/output-rt-output-1.png"
id="output-rt" />
![Rt posterior
distribution](getting-started_files/figure-commonmark/output-rt-output-1.png)

## Architecture of pyrenew

`pyrenew` leverages `numpyro`’s flexibility to build models via
composition. As a principle, most objects in `pyrenew` can be treated as
random variables we can sample. At the top level, `pyrenew` has two
metaclasses from which most objects inherit: `RandomVariable` and `Model`.
From them, the following five sub-modules arise:

- The `process` sub-module,
- The `deterministic` sub-module,
- The `observation` sub-module,
- The `latent` sub-module, and
- The `model` sub-module

The first four are collections of subclasses of `RandomVariable`, and the
last is a collection of subclasses of `Model`. The following diagram
shows a detailed view of how metaclasses, modules, and classes interact
to create the `RtInfectionsRenewalModel` instantiated in the previous
section:

``` mermaid
flowchart LR
rand((RandomVariable\nmetaclass))
models((Model\nmetaclass))
subgraph observations[Observations module]
obs["observed_infections\n(PoissonObservation)"]
end
subgraph latent[Latent module]
inf["latent_infections\n(Infections)"]
i0["I0\n(Infections0)"]
end
subgraph process[Process module]
rt["rt_proc\n(RtRandomWalkProcess)"]
end
subgraph deterministic[Deterministic module]
detpmf["gen_int\n(DeterministicPMF)"]
end
subgraph model[Model module]
model1["model1\n(RtInfectionsRenewalModel)"]
end
rand-->|Inherited by|observations
rand-->|Inherited by|process
rand-->|Inherited by|latent
rand-->|Inherited by|deterministic
models-->|Inherited by|model
detpmf-->|Composes|model1
i0-->|Composes|model1
rt-->|Composes|model1
obs-->|Composes|model1
inf-->|Composes|model1
```