Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

188 random walk refactor #275

Merged
merged 38 commits into from
Jul 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
9298493
Refactor simplerandomwalk.py and associated tests, simplify process.rst
dylanhmorris Jul 17, 2024
f6ceb65
Add TransformedRandomVariable metaclass and tests
dylanhmorris Jul 17, 2024
da8ba1b
Clean up metaclass docs, remove DistributionalRVSample class
dylanhmorris Jul 18, 2024
c0d1cdc
fix dist clash
dylanhmorris Jul 18, 2024
91230c7
Rewrite all tests to pass without RtRandomWalkProcess
dylanhmorris Jul 18, 2024
4ba4a8f
Merge branch 'main' into 188-random-walk-refactor
dylanhmorris Jul 18, 2024
fcd8a2c
Adapt all tutorials
dylanhmorris Jul 18, 2024
62b5001
Custom Rt RV in tutorials
dylanhmorris Jul 18, 2024
e714bdd
Tutorial prior tweaks
dylanhmorris Jul 18, 2024
530870c
More tutorial prior tweaks
dylanhmorris Jul 18, 2024
558d84a
Don't plot nan-padded Rt values
dylanhmorris Jul 18, 2024
e686283
$ to $\mathcal{R}(t)$ throughout basic renewal tutorial
dylanhmorris Jul 18, 2024
7875968
Better handling of nan padded Rt
dylanhmorris Jul 18, 2024
7ed29cd
Add reparam option to DistributionalRV, use for basic tutorial
dylanhmorris Jul 18, 2024
6539024
Merge branch 'main' into 188-random-walk-refactor
dylanhmorris Jul 18, 2024
2b3deb9
Merge branch 'main' into 188-random-walk-refactor
dylanhmorris Jul 18, 2024
5a90808
Merge branch 'main' into 188-random-walk-refactor
dylanhmorris Jul 18, 2024
91a7e65
Merge branch 'main' into 188-random-walk-refactor
dylanhmorris Jul 19, 2024
2ef2075
Update docs/source/tutorials/hospital_admissions_model.qmd
dylanhmorris Jul 19, 2024
2c1c7c7
Update model/src/pyrenew/metaclass.py
dylanhmorris Jul 19, 2024
4617963
Update model/pyproject.toml
dylanhmorris Jul 19, 2024
af8a7d9
Escape backlash
dylanhmorris Jul 22, 2024
5869033
Comment on custom RV
dylanhmorris Jul 22, 2024
1081ae2
Rename admissions model test and fix RNG seed pattern
dylanhmorris Jul 22, 2024
74b553d
Update model/src/pyrenew/process/simplerandomwalk.py
dylanhmorris Jul 22, 2024
ba567ba
Update model/src/pyrenew/process/simplerandomwalk.py
dylanhmorris Jul 22, 2024
44857ff
Update other mermaid diagram
dylanhmorris Jul 22, 2024
8ee433f
Linear scale simulated admissions, with points
dylanhmorris Jul 22, 2024
a0fc6f7
Fix composition diagram
dylanhmorris Jul 22, 2024
931977b
Hospitalization => hospital admission throughout
dylanhmorris Jul 22, 2024
6e7a5e7
np.mean to arviz mean for posterior means, and include both chains
dylanhmorris Jul 22, 2024
088c28f
np.mean to arviz mean for posterior means, and include both chains in…
dylanhmorris Jul 22, 2024
c1f1cc9
Fix typo
dylanhmorris Jul 22, 2024
a033b7d
means to medians in hospital_admissions tutorial, clarify plotting code
dylanhmorris Jul 22, 2024
0afa30e
Means to medians in basic_renewal tutorial, avoid issues with data di…
dylanhmorris Jul 22, 2024
f23adb5
n_timepoints ==> n_steps as argument to SimpleRandomWalkProcess
dylanhmorris Jul 22, 2024
40aefbb
update pyrenew_demo.qmd to use n_steps
dylanhmorris Jul 22, 2024
71bd146
Update tutorials
dylanhmorris Jul 22, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 1 addition & 30 deletions docs/source/msei_reference/process.rst
Original file line number Diff line number Diff line change
@@ -1,36 +1,7 @@
Random Process
==============

AR Processes
------------

.. automodule:: pyrenew.process.ar
.. automodule:: pyrenew.process
:members:
:undoc-members:
:show-inheritance:

First Difference (AR)
---------------------

.. automodule:: pyrenew.process.firstdifferencear
:members:
:undoc-members:
:show-inheritance:

Reproduction Number Random Walk
-------------------------------

.. automodule:: pyrenew.process.rtrandomwalk
:members:
:undoc-members:
:show-inheritance:

Simple Random Walk
------------------

.. automodule:: pyrenew.process.simplerandomwalk
:members:
:undoc-members:
:show-inheritance:

.. todo:: Determine order and naming of these modules.
111 changes: 69 additions & 42 deletions docs/source/tutorials/basic_renewal_model.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@ We start by loading the needed components to build a basic renewal model:
# | warning: false
import jax.numpy as jnp
import numpy as np
import numpyro as npro
import numpyro
import numpyro.distributions as dist
from pyrenew.process import RtRandomWalkProcess
from pyrenew.process import SimpleRandomWalkProcess
from pyrenew.latent import (
Infections,
InfectionInitializationProcess,
Expand All @@ -25,10 +25,15 @@ from pyrenew.latent import (
from pyrenew.observation import PoissonObservation
from pyrenew.deterministic import DeterministicPMF
from pyrenew.model import RtInfectionsRenewalModel
from pyrenew.metaclass import DistributionalRV
from pyrenew.metaclass import (
RandomVariable,
DistributionalRV,
TransformedRandomVariable,
)
import pyrenew.transformation as t
from numpyro.infer.reparam import LocScaleReparam

npro.set_host_device_count(2)
numpyro.set_host_device_count(2)
```

## Architecture of `RtInfectionsRenewalModel`
Expand All @@ -51,20 +56,20 @@ flowchart LR
models((Model\nmetaclass))

subgraph observations[Observations module]
obs["observation_process\n(PoissonObservation)"]
obs["infection_obs_process_rv\n(PoissonObservation)"]
end

subgraph latent[Latent module]
inf["latent_infections\n(Infections)"]
i0["I0\n(DistributionalRV)"]
inf["latent_infections_rv\n(Infections)"]
i0["I0_rv\n(DistributionalRV)"]
end

subgraph process[Process module]
rt["rt_proc\n(RtRandomWalkProcess)"]
rt["Rt_process_rv\n(Custom class built using SimpleRandomWalk)"]
end

subgraph deterministic[Deterministic module]
detpmf["gen_int\n(DeterministicPMF)"]
detpmf["gen_int_rv\n(DeterministicPMF)"]
end

subgraph model[Model module]
Expand All @@ -85,13 +90,13 @@ flowchart LR
```


The pyrenew package models the real-time reproductive number $R_t$, the average number of secondary infections caused by an infected individual, as a renewal process model. Our basic renewal process model defines five components:
The pyrenew package models the real-time reproductive number $\mathcal{R}(t)$, the average number of secondary infections caused by an infected individual, as a renewal process model. Our basic renewal process model defines five components:

(1) generation interval, the times between infections

(2) initial infections, occurring prior to time $t = 0$

(3) $R_t$, the real-time reproductive number,
(3) $\mathcal{R}(t)$, the time-varying reproductive number,

(4) latent infections, i.e., those infections which are known to exist but are not observed (or not observable), and

Expand All @@ -103,7 +108,7 @@ To initialize these five components within the renewal modeling framework, we es

(2) an instance of the `InfectionInitializationProcess` class, where the number of latent infections immediately before the renewal process begins follows a log-normal distribution with mean = 0 and standard deviation = 1. By specifying `InitializeInfectionsZeroPad`, the latent infections before this time are assumed to be 0.

(3) an instance of the `RtRandomWalkProcess` class with default values
(3) A process to represent $\mathcal{R}(t)$ as a random walk on the log scale, with an inferred initial value and a fixed Normal step-size distribution. For this, we construct a custom `RandomVariable`, `MyRt`.

(4) an instance of the `Infections` class with default values, and

Expand All @@ -112,23 +117,47 @@ To initialize these five components within the renewal modeling framework, we es
```{python}
# | label: creating-elements
# (1) The generation interval (deterministic)
pmf_array = jnp.array([0.25, 0.25, 0.25, 0.25])
pmf_array = jnp.array([0.4, 0.3, 0.2, 0.1])
gen_int = DeterministicPMF(pmf_array, name="gen_int")

# (2) Initial infections (inferred with a prior)
I0 = InfectionInitializationProcess(
"I0_initialization",
DistributionalRV(dist=dist.LogNormal(0, 1), name="I0"),
DistributionalRV(dist=dist.LogNormal(2.5, 1), name="I0"),
InitializeInfectionsZeroPad(pmf_array.size),
t_unit=1,
)

# (3) The random process for Rt
rt_proc = RtRandomWalkProcess(
Rt0_dist=dist.TruncatedNormal(loc=1.2, scale=0.2, low=0),
Rt_transform=t.ExpTransform().inv,
Rt_rw_dist=dist.Normal(0, 0.025),
)

# (3) The random walk on log Rt, with an inferred s.d. Here, we
# construct a custom RandomVariable.
class MyRt(RandomVariable):

def validate(self):
pass

def sample(self, n_steps: int, **kwargs) -> tuple:
sd_rt = numpyro.sample("Rt_random_walk_sd", dist.HalfNormal(0.025))

rt_rv = TransformedRandomVariable(
"Rt_rv",
base_rv=SimpleRandomWalkProcess(
name="log_rt",
step_rv=DistributionalRV(
dist.Normal(0, sd_rt),
"rw_step_rv",
reparam=LocScaleReparam(0),
),
init_rv=DistributionalRV(
dist.Normal(jnp.log(1), jnp.log(1.2)), "init_log_Rt_rv"
),
),
transforms=t.ExpTransform(),
)
return rt_rv.sample(n_steps=n_steps, **kwargs)


rt_proc = MyRt()

# (4) Latent infection process (which will use 1 and 2)
latent_infections = Infections()
Expand Down Expand Up @@ -158,7 +187,7 @@ The following diagram summarizes how the modules interact via composition; notab
flowchart TB
genint["(1) gen_int\n(DetermnisticPMF)"]
i0["(2) I0\n(InfectionInitializationProcess)"]
rt["(3) rt_proc\n(RtRandomWalkProcess)"]
rt["(3) rt_proc\n(MyRt, the custom RV defined above)"]
inf["(4) latent_infections\n(Infections)"]
obs["(5) observation_process\n(PoissonObservation)"]

Expand All @@ -175,14 +204,13 @@ Using `numpyro`, we can simulate data using the `sample()` member function of `R

```{python}
# | label: simulate
np.random.seed(223)
with npro.handlers.seed(rng_seed=np.random.randint(1, 60)):
sim_data = model1.sample(n_timepoints_to_simulate=30)
with numpyro.handlers.seed(rng_seed=53):
sim_data = model1.sample(n_timepoints_to_simulate=40)

sim_data
```

To understand what has been accomplished here, visualize an $R_t$ sample path (left panel) and infections over time (right panel):
To understand what has been accomplished here, visualize an $\mathcal{R}(t)$ sample path (left panel) and infections over time (right panel):

```{python}
# | label: fig-basic
Expand Down Expand Up @@ -220,7 +248,7 @@ model1.run(
)
```

Now, let's investigate the output, particularly the posterior distribution of the $R_t$ estimates:
Now, let's investigate the output, particularly the posterior distribution of the $\mathcal{R}(t)$ estimates:

```{python}
# | label: fig-output-rt
Expand All @@ -243,19 +271,19 @@ import arviz as az
idata = az.from_numpyro(model1.mcmc)
```

and use the InferenceData to compute the model-fit diagnostics. Here, we show diagnostic summary for the first 10 effective reproduction number $R_t$.
and use the InferenceData to compute the model-fit diagnostics. Here, we show diagnostic summary for the first 10 effective reproduction number $\mathcal{R}(t)$.

```{python}
# | label: diagnostics
diagnostic_stats_summary = az.summary(
idata.posterior["Rt"],
idata.posterior["Rt"][::, ::, 4:], # ignore nan padding
kind="diagnostics",
)

print(diagnostic_stats_summary[:10])
print(diagnostic_stats_summary)
```

Below we use `plot_trace` to inspect the trace of the first 10 $R_t$ estimates.
Below we use `plot_trace` to inspect the trace of the first 10 inferred $\mathcal{R}(t)$ values.

```{python}
# | label: fig-trace-Rt
Expand All @@ -265,20 +293,20 @@ plt.rcParams["figure.constrained_layout.use"] = True
az.plot_trace(
idata.posterior,
var_names=["Rt"],
coords={"Rt_dim_0": np.arange(10)},
coords={"Rt_dim_0": np.arange(4, 14)},
compact=False,
)
plt.show()
```


We inspect the posterior distribution of $R_t$ by plotting the 90% and 50% highest density intervals:
We inspect the posterior distribution of $\mathcal{R}(t)$ by plotting the 90% and 50% highest density intervals:

```{python}
# | label: fig-hdi-Rt
# | fig-cap: High density interval for Effective Reproduction Number
x_data = idata.posterior["Rt_dim_0"]
y_data = idata.posterior["Rt"]
x_data = idata.posterior["Rt_dim_0"][4:]
y_data = idata.posterior["Rt"][::, ::, 4:]

fig, axes = plt.subplots(figsize=(6, 5))
az.plot_hdi(
Expand All @@ -300,12 +328,12 @@ az.plot_hdi(
)

# Add mean of the posterior to the figure
mean_Rt = np.mean(idata.posterior["Rt"], axis=1)
axes.plot(x_data, mean_Rt[0], color="C0", label="Mean")
median_ts = y_data.median(dim=["chain", "draw"])
axes.plot(x_data, median_ts, color="C0", label="Median")
axes.legend()
axes.set_title("Posterior Effective Reproduction Number", fontsize=10)
axes.set_xlabel("Time", fontsize=10)
axes.set_ylabel("$R_t$", fontsize=10)
axes.set_ylabel("$\\mathcal{R}(t)$", fontsize=10)
plt.show()
```

Expand Down Expand Up @@ -338,11 +366,10 @@ az.plot_hdi(
ax=axes,
)

# Add mean of the posterior to the figure
mean_latent_infection = np.mean(
idata.posterior["all_latent_infections"], axis=1
)
axes.plot(x_data, mean_latent_infection[0], color="C0", label="Mean")
# plot the posterior median
median_ts = y_data.median(dim=["chain", "draw"])
axes.plot(x_data, median_ts, color="C0", label="Median")

axes.legend()
axes.set_title("Posterior Latent Infections", fontsize=10)
axes.set_xlabel("Time", fontsize=10)
Expand Down
16 changes: 10 additions & 6 deletions docs/source/tutorials/extending_pyrenew.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ import numpyro.distributions as dist
from pyrenew.deterministic import DeterministicPMF, DeterministicVariable
from pyrenew.latent import InfectionsWithFeedback
from pyrenew.model import RtInfectionsRenewalModel
from pyrenew.process import RtRandomWalkProcess
from pyrenew.metaclass import DistributionalRV
from pyrenew.process import SimpleRandomWalkProcess
from pyrenew.metaclass import DistributionalRV, TransformedRandomVariable
from pyrenew.latent import (
InfectionInitializationProcess,
InitializeInfectionsExponentialGrowth,
Expand Down Expand Up @@ -60,10 +60,14 @@ latent_infections = InfectionsWithFeedback(
infection_feedback_pmf=gen_int,
)

rt = RtRandomWalkProcess(
Rt0_dist=dist.TruncatedNormal(loc=1.2, scale=0.2, low=0),
Rt_transform=t.ExpTransform().inv,
Rt_rw_dist=dist.Normal(0, 0.025),
rt = TransformedRandomVariable(
"Rt_rv",
base_rv=SimpleRandomWalkProcess(
name="log_rt",
step_rv=DistributionalRV(dist.Normal(0, 0.025), "rw_step_rv"),
init_rv=DistributionalRV(dist.Normal(0, 0.2), "init_log_Rt_rv"),
),
transforms=t.ExpTransform(),
)
```

Expand Down
Loading