Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rt with infection feedback #123

Merged
merged 47 commits into from
Jun 3, 2024
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
Show all changes
47 commits
Select commit Hold shift + click to select a range
27b65d2
Working on Rt docs
gvegayon May 16, 2024
1c306bd
Starting off module
gvegayon May 16, 2024
f9951b6
Adding test and starting tutorial
gvegayon May 17, 2024
26fb259
Removing extra file
gvegayon May 17, 2024
8ed3eca
Adding a test and extending documentation
gvegayon May 20, 2024
d8b55f5
Adding test for datautils
gvegayon May 21, 2024
0beb654
Adding more content to the tutorial
gvegayon May 21, 2024
0665b63
Adding a test checking for the calculations of the double conv (expec…
gvegayon May 21, 2024
49d803d
Update model/src/test/test_infectionsrtfeedback.py
gvegayon May 21, 2024
1865216
Clarify variable names in sample_infections_with_feedback()
dylanhmorris May 21, 2024
38e7028
Sign consistency and documentation
dylanhmorris May 21, 2024
880d4c6
Add basic test that infections with feedback reduces to regular infec…
dylanhmorris May 21, 2024
fafe2cf
Add pytest-mpl to pyproject.toml given use of @pytest.mark.mpl_image_…
dylanhmorris May 22, 2024
0119e92
Run precommit
dylanhmorris May 22, 2024
7b96e4a
Clarify pmf input format for sample_infections_with_feedback in docum…
dylanhmorris May 22, 2024
da6ff39
Update lockfile, style infections.py
dylanhmorris May 22, 2024
a814f4a
clarify use of reversed PMFs, remove padding operations that should n…
dylanhmorris May 22, 2024
52bfdb6
Harmonize indexing and remove autopadding; fix manual renewal process…
dylanhmorris May 22, 2024
b63acbe
Update latent admissions test
dylanhmorris May 23, 2024
219e710
Fixing broken tests
gvegayon May 23, 2024
98de6a5
Merge branch 'main' into 9-infection-feedback-in-rt
gvegayon May 23, 2024
6db0fb7
Adding needed padding for convo
gvegayon May 23, 2024
db551da
Adding direction of padding to the docs
gvegayon May 23, 2024
429b988
Adding missing note
gvegayon May 28, 2024
174fbc3
Update model/src/test/test_model_basic_renewal.py
gvegayon May 28, 2024
82a4674
Splitting tests and using pytest.raises
gvegayon May 28, 2024
39cec88
Removing unnecesary call to test_*
gvegayon May 28, 2024
ebe13ed
Update model/src/test/test_datautils.py
gvegayon May 28, 2024
a6041f0
Merge branch 'main' into 9-infection-feedback-in-rt
gvegayon May 28, 2024
fa098f7
Fixing tests and adding test for Infections.sample() raise error
gvegayon May 28, 2024
7888887
Fixing tutorial using old name of sample_infections
gvegayon May 28, 2024
20a71f5
Fixing docstring (missing r""")
gvegayon May 28, 2024
9614971
Deleting call to test
gvegayon May 28, 2024
60dfec4
inf_feedback: either len 1 or len Rt (otherwise error)
gvegayon May 29, 2024
d974c65
solve merge conflicts and add some ignores
gvegayon May 29, 2024
f495487
Replacing I0 for DistributionalRVSample
gvegayon May 29, 2024
6f4a551
Replacing I0 for DistributionalRVSample (vis)
gvegayon May 29, 2024
bd50175
Ensuring DeterministicRV/DistributionalRV.sample returns at least 1d …
gvegayon May 29, 2024
c087b6b
Apply suggestions from code review @damonbayer
gvegayon May 29, 2024
35d0341
Fixing pre-commit
gvegayon May 29, 2024
5eb31cc
Addressing reviewer's comments
gvegayon May 30, 2024
7ea53df
Minor change in docstring for infections test
gvegayon May 30, 2024
b603eb5
Clarifying tutorial extending pyrenew and making sure models match
gvegayon May 30, 2024
b0f72cf
Removing duplicated docstring
gvegayon May 31, 2024
5973703
add docs/extending_pyrenew.md to makefile
damonbayer Jun 3, 2024
3cded51
grammar fix
damonbayer Jun 3, 2024
6f70084
Merge branch 'main' into 9-infection-feedback-in-rt
damonbayer Jun 3, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions model/Makefile
gvegayon marked this conversation as resolved.
Show resolved Hide resolved
dylanhmorris marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,16 @@ docs/example-with-datasets.md: docs/example-with-datasets.qmd
poetry run quarto render docs/example-with-datasets.qmd

docs/py: docs/notebooks
jupyter nbconvert --to python docs/pyrenew_demo.ipynb
jupyter nbconvert --to python docs/getting-started.ipynb
jupyter nbconvert --to python docs/example-with-datasets.ipynb
for i in docs/*.ipynb; do \
jupyter nbconvert --to python $$i ; \
damonbayer marked this conversation as resolved.
Show resolved Hide resolved
done

docs/notebooks:
quarto convert docs/pyrenew_demo.qmd --output docs/pyrenew_demo.ipynb
quarto convert docs/getting-started.qmd --output docs/getting-started.ipynb
quarto convert docs/example-with-datasets.qmd --output \
docs/example-with-datasets.ipynb
for i in docs/*.qmd; do \
if [ $$i -nt $$(basename $$i .qmd).ipynb ]; then \
quarto convert $$i --output docs/$$(basename $$i .qmd).ipynb ; \
fi \
done

test_images:
echo "Generating reference images for tests"
Expand Down
256 changes: 256 additions & 0 deletions model/docs/extending_pyrenew.qmd
Original file line number Diff line number Diff line change
@@ -0,0 +1,256 @@
---
title: Extending pyrenew
format: gfm
---

This tutorial illustrates how to extend `pyrenew` with custom `RandomVariable` classes. We will use the `InfectionsWithFeedback` class as an example. The `InfectionsWithFeedback` class is a `RandomVariable` that models the number of infections at time $t$ as a function of the number of infections at time $t - \tau$ and the reproduction number at time $t$. The reproduction number at time $t$ is a function of the *unadjusted* reproduction number at time $t - \tau$ and the number of infections at time $t - \tau$:

$$
\begin{align*}
I(t) & = \mathcal{R}(t)\sum_{\tau=1}^{T_g}I(t - \tau)g(\tau) \\
\mathcal{R}(t) & = \mathcal{R}^u(t)\exp\left(-\gamma(t)\sum_{\tau=1}^{T_f}I(t - \tau)f(\tau)\right)
\end{align*}
$$

Where $\mathcal{R}^u(t)$ is the unadjusted reproduction number, $g(t)$ is the generation interval, $\gamma(t)$ is the infection feedback strength, and $f(t)$ is the infection feedback pmf.

## The expected outcome

Before we start, let's simulate the model with the original `InfectionsWithFeedback` class. We will simulate the model with no observation process. The following code-chunk loads the required libraries and defines the model components:
gvegayon marked this conversation as resolved.
Show resolved Hide resolved

```{python}
#| label: setup
import jax
import jax.numpy as jnp
import numpy as np
import numpyro as npro
import numpyro.distributions as dist
from pyrenew.deterministic import DeterministicPMF, DeterministicVariable
from pyrenew.latent import Infections0, InfectionsWithFeedback
from pyrenew.model import RtInfectionsRenewalModel
from pyrenew.process import RtRandomWalkProcess
```

The following code-chunk defines the model components. Notice that for both the generation interval and the infection feedback, we use a deterministic PMF with equal probabilities:

```{python}
#| label: model-components
gen_int = DeterministicPMF(jnp.array([0.25, 0.5, 0.15, 0.1]))

I0 = Infections0(I0_dist=dist.LogNormal(0, 1))

latent_infections = InfectionsWithFeedback(
infection_feedback_strength = DeterministicVariable(1.1),
infection_feedback_pmf = gen_int,
)

rt = RtRandomWalkProcess()
```

With all the components defined, we can build the model:

```{python}
#| label: build1
model0 = RtInfectionsRenewalModel(
gen_int=gen_int,
I0=I0,
latent_infections=latent_infections,
Rt_process=rt,
observation_process=None,
)
```

And simulate it:
gvegayon marked this conversation as resolved.
Show resolved Hide resolved

```{python}
#| label: simulate1
# Sampling and fitting model 0 (with no obs for infections)
np.random.seed(223)
with npro.handlers.seed(rng_seed=np.random.randint(1, 600)):
model0_samp = model0.sample(n_timepoints=30)
gvegayon marked this conversation as resolved.
Show resolved Hide resolved
```

```{python}
#| label: fig-simulate1
#| fig-cap: Simulated infections with no observation process
import matplotlib.pyplot as plt
fig, ax = plt.subplots()
ax.plot(model0_samp.latent_infections)
ax.set_xlabel("Time")
ax.set_ylabel("Infections")
plt.show()
```

## Pyrenew's random variable class

### Fundamentals

All instances of PyRenew's `RandomVariable` should have at least three functions: `__init__()`, `validate()`, and `sample()`. The `__init__()` function is the constructor and initializes the class. The `validate()` function checks if the class is correctly initialized. Finally, the `sample()` method contains the core of the class; it should return a tuple or named tuple. The following is a minimal example of a `RandomVariable` class based on `numpyro.distributions.Normal`:

```{python}
#| label: example-rv
from pyrenew.metaclass import RandomVariable

class MyNormal(RandomVariable):
def __init__(self, loc, scale):
self.validate(scale)
self.loc = loc
self.scale = scale
return None

@staticmethod
def validate(self):
if self.scale <= 0:
raise ValueError("Scale must be positive")
return None

def sample(self, **kwargs):
return (dist.Normal(loc=self.loc, scale=self.scale),)
```

The `@staticmethod` decorator exposes the `validate` function to be used outside the class. Next, we show how to build a more complex `RandomVariable` class; the `InfectionsWithFeedback` class.

### The `InfectionsWithFeedback` class

Although returning namedtuples are not strictly required, they are recommended as they make the code more readable. The following code-chunk shows how to create a named tuple for the `InfectionsWithFeedback` class:
damonbayer marked this conversation as resolved.
Show resolved Hide resolved

```{python}
#| label: data-class
from collections import namedtuple

# Creating a tuple to store the output
InfFeedbackSample = namedtuple(
typename='InfFeedbackSample',
field_names=['infections', 'rt'],
defaults=(None, None)
)
```

The next step is to create the actual class. The bulk of its implementation lies in the function `pyrenew.latent.sample_infections_with_feedback()`. We will also use the `pyrenew.datautils.pad_x_to_match_y()` function to ensure the passed vectors match their lengths. The following code-chunk shows most of the implementation of the `InfectionsWithFeedback` class:

```{python}
#| label: new-model-def
#| code-line-numbers: true
# Creating the class
from pyrenew.metaclass import RandomVariable
from pyrenew.latent import sample_infections_with_feedback
from pyrenew import datautils as du
from jax.typing import ArrayLike

class InfFeedback(RandomVariable):
"""Latent infections"""

def __init__(
self,
infection_feedback_strength: RandomVariable,
infection_feedback_pmf: RandomVariable,
infections_mean_varname: str = "latent_infections",
) -> None:
"""Constructor"""

self.infection_feedback_strength = infection_feedback_strength
self.infection_feedback_pmf = infection_feedback_pmf
self.infections_mean_varname = infections_mean_varname

return None

def validate(self):
"""
Generally, this method should be more meaningful, but we will skip it for now
"""
return None

def sample(
self,
Rt: ArrayLike,
I0: ArrayLike,
gen_int: ArrayLike,
**kwargs,
) -> tuple:
"""Sample infections with feedback"""

# Baseline infections
I0_vec = du.pad_x_to_match_y(x=I0, y=Rt)
I0_vec = jnp.flip(I0_vec)
gvegayon marked this conversation as resolved.
Show resolved Hide resolved

# Generation interval
gen_int = du.pad_x_to_match_y(x=gen_int, y=Rt)
gen_int_rev = jnp.flip(gen_int)

# Sampling inf feedback strength and adjusting the shape
inf_feedback_strength, *_ = self.infection_feedback_strength.sample(
**kwargs,
)
inf_feedback_strength = du.pad_x_to_match_y(
x=inf_feedback_strength, y=Rt
)

# Sampling inf feedback and adjusting the shape
inf_feedback_pmf, *_ = self.infection_feedback_pmf.sample(**kwargs)
inf_feedback_pmf = du.pad_x_to_match_y(x=inf_feedback_pmf, y=Rt)

# Generating the infections with feedback
all_infections, Rt_adj = sample_infections_with_feedback(
I0=I0_vec,
Rt_raw=Rt,
infection_feedback_strength=inf_feedback_strength,
generation_interval_pmf=gen_int_rev,
infection_feedback_pmf=inf_feedback_pmf,
)

# Storing adjusted Rt for future use
npro.deterministic("Rt_adjusted", Rt_adj)

# Preparing theoutput

return InfFeedbackSample(
infections=all_infections,
rt=Rt_adj,
)
```

The core of the class is implemented in the `sample()` method. Things to highlight from the above code:

1. **Arguments of `sample`**: The `InfFeedback` class will be used within `RtInfectionsRenewalModel` to generate latent infections. During the sampling process, `InfFeedback.sample()` will receive the reproduction number, the initial number of infections, and the generation interval. `RandomVariable.sample()` calls are expected to include the `**kwargs` argument, even if unused.

2. **Calls to `RandomVariable.sample()`**: All calls to `RandomVariable.sample()` are expected to return a tuple or named tuple. In our implementation, we capture the output of `infection_feedback_strength.sample()` and `infection_feedback_pmf.sample()` in the variables `inf_feedback_strength` and `inf_feedback_pmf`, respectively, disregarding the other outputs (i.e., using `*_`).

3. **Saving computed quantities**: Since `Rt_adj` is not generated via `numpyro.sample()`, we use `numpyro.deterministic()` to record the quantity to a site; allowing us to access it later.

4. **Return type of `InfFeedback.sample()`**: As said before, the `sample()` method should return a tuple or named tuple. In our case, we return a named tuple `InfFeedbackSample` with two fields: `infections` and `rt`.

```{python}
#| label: simulation2
latent_infections2 = InfFeedback(
infection_feedback_strength = gen_int,
infection_feedback_pmf = gen_int,
)

model1 = RtInfectionsRenewalModel(
gen_int=gen_int,
I0=I0,
latent_infections=latent_infections2,
Rt_process=rt,
observation_process=None,
)

# Sampling and fitting model 0 (with no obs for infections)
np.random.seed(223)
with npro.handlers.seed(rng_seed=np.random.randint(1, 600)):
model1_samp = model1.sample(n_timepoints=30)
```

Comparing `model0` with `model1`:

```{python}
#| label: fig-simulate2
#| fig-cap: Simulated infections with no observation process
import matplotlib.pyplot as plt
fig, ax = plt.subplots(ncols=2)
ax[0].plot(model0_samp.latent_infections)
ax[1].plot(model1_samp.latent_infections)
ax[0].set_xlabel("Time")
ax[1].set_xlabel("Time")
ax[0].set_ylabel("Infections")
plt.show()
```
gvegayon marked this conversation as resolved.
Show resolved Hide resolved
Loading
Loading