diff --git a/model/docs/pyrenew_demo.md b/model/docs/pyrenew_demo.md index 410e192e..3aace074 100644 --- a/model/docs/pyrenew_demo.md +++ b/model/docs/pyrenew_demo.md @@ -4,8 +4,25 @@ This demo simulates some basic renewal process data and then fits to it using `pyrenew`. -You’ll need to install `pyrenew` first. You’ll also need working -installations of `matplotlib`, `numpy`, `jax`, `numpyro`, and `polars` +Assuming you’ve already installed Python and pip, you’ll need to first +install `pyrenew`: + +``` python +python3 -m pip install "pyrenew" +``` + +You’ll also need working installations of `matplotlib`, `numpy`, `jax`, +`numpyro`, and `polars`: + +``` python +python -m pip install "matplotlib" "numpy" "jax" "numpyro" "polars" +``` + +Run the following import section to call external modules and functions +necessary to run the `pyrenew` demo. The `import` statement imports the +module and the `as` statement renames the module for use within this +script. The `from` statement imports a specific function from a module +(named after the `.`) within a package (named before the `.`). ``` python import matplotlib as mpl @@ -20,6 +37,7 @@ import numpyro.distributions as dist ``` python from pyrenew.process import SimpleRandomWalkProcess ``` +To understand the simple random walk process underlying the sampling within the renewal process model, we first examine a single random walk path. Using the `sample` method from an instance of the `SimpleRandomWalkProcess` class, we first create an instance of the `SimpleRandomWalkProcess` class with a normal distribution of mean = 0 and standard deviation = 0.0001 as its input. Next, the `with` statement sets the seed for the random number generator for the duration of the block that follows. Inside the `with` block, the `q_samp = q.sample(duration=100)` generates the sample instance over a duration of 100 time units. Finally, this single random walk process is visualized using `matplot.pyplot` to plot the exponential of the sample instance. ``` python np.random.seed(3312) @@ -32,35 +50,59 @@ plt.plot(np.exp(q_samp[0])) ![](pyrenew_demo_files/figure-commonmark/fig-randwalk-output-1.png) -``` python + +Next, import several additional functions from the `latent` module of the `pyrenew` package to model infections, hospital admissions, initial infections, and hospitalization rate due to infection. + +```{python} from pyrenew.latent import ( Infections, HospitalAdmissions, Infections0, InfectHospRate, ) +``` + +Additionally, import several classes from Pyrenew, including a Poisson observation process, determininstic PMF and variable classes, the Pyrenew hospitalization model, and a renewal modle (Rt) random walk process: +```{python} from pyrenew.observation import PoissonObservation from pyrenew.deterministic import DeterministicPMF, DeterministicVariable from pyrenew.model import HospitalizationsModel from pyrenew.process import RtRandomWalkProcess +``` + +To initialize a model run, we first define initial conditions, including: + +1) deterministic generation time, defined as an instance of the `DeterministicPMF` class, which gives the probability of each possible outcome for a discrete random variable given as a JAX NumPy array of four possible outcomes + +2) initial infections at the start of simulation as a log-normal distribution with mean = 0 and standard deviation = 1 + +3) latent infections as an instance of the `Infections` class with default settings + +4) latent hospitalization process, modeled by first defining the time interval from infections to hospitalizations as a `DeterministicPMF` input with 18 possible outcomes and corresponding probabilities given by the values in the array. The `HospitalAdmissions` function then takes in this defined time interval, as well as defining the rate at which infections are admitted to the hospital due to infection, modeled as a log-normal distribution with mean = `jnp.log(0.05)` and standard deviation = 0.05. + +5) hospitalization observation process, modeled with a Poisson distribution +6) an Rt random walk process with default settings + +``` python # Initializing model components: -# A deterministic generation time +# 1) A deterministic generation time gen_int = DeterministicPMF( (jnp.array([0.25, 0.25, 0.25, 0.25]),), ) -# Initial infections +# 2) Initial infections I0 = Infections0(I0_dist=dist.LogNormal(0, 1)) -# The latent infections process +# 3) The latent infections process latent_infections = Infections() -# A deterministic infection to hosp pmf +# 4) The latent hospitalization process: + +# First, define a deterministic infection to hosp pmf inf_hosp_int = DeterministicPMF( (jnp.array([0, 0, 0,0,0,0,0,0,0,0,0,0,0, 0.25, 0.5, 0.1, 0.1, 0.05]),), ) -# The latent hospitalization process latent_hospitalizations = HospitalAdmissions( infection_to_admission_interval=inf_hosp_int, infect_hosp_rate_dist = InfectHospRate( @@ -68,13 +110,17 @@ latent_hospitalizations = HospitalAdmissions( ), ) -# And observation process for the hospitalizations +# 5) An observation process for the hospitalizations observed_hospitalizations = PoissonObservation() -# And a random walk process (it could be deterministic using +# 6) A random walk process (it could be deterministic using # pyrenew.process.DeterministicProcess()) Rt_process = RtRandomWalkProcess() +``` + +The `HospitalizationsModel` is then initialized using the initial conditions just defined: +``` python # Initializing the model hospmodel = HospitalizationsModel( gen_int=gen_int, @@ -142,39 +188,39 @@ hospmodel.print_summary() mean std median 5.0% 95.0% n_eff r_hat - I0 1.27 1.10 0.97 0.10 2.42 1132.34 1.00 - IHR 0.05 0.00 0.05 0.05 0.05 2306.45 1.00 - Rt0 1.23 0.17 1.23 0.93 1.48 1327.22 1.00 - Rt_transformed_rw_diffs[0] -0.00 0.02 -0.00 -0.04 0.04 1404.95 1.00 - Rt_transformed_rw_diffs[1] 0.00 0.03 0.00 -0.04 0.04 2280.86 1.00 - Rt_transformed_rw_diffs[2] -0.00 0.02 -0.00 -0.04 0.04 2119.83 1.00 - Rt_transformed_rw_diffs[3] 0.00 0.02 -0.00 -0.04 0.04 2196.86 1.00 - Rt_transformed_rw_diffs[4] 0.00 0.02 -0.00 -0.03 0.04 2391.45 1.00 - Rt_transformed_rw_diffs[5] 0.00 0.03 0.00 -0.04 0.04 2043.02 1.00 - Rt_transformed_rw_diffs[6] 0.00 0.02 0.00 -0.04 0.04 1514.40 1.00 - Rt_transformed_rw_diffs[7] -0.00 0.02 -0.00 -0.04 0.04 2619.69 1.00 - Rt_transformed_rw_diffs[8] 0.00 0.03 0.00 -0.04 0.04 1883.84 1.00 - Rt_transformed_rw_diffs[9] 0.00 0.03 0.00 -0.04 0.04 2015.66 1.00 - Rt_transformed_rw_diffs[10] 0.00 0.02 0.00 -0.04 0.04 2045.47 1.00 - Rt_transformed_rw_diffs[11] -0.00 0.03 0.00 -0.04 0.04 1615.10 1.00 - Rt_transformed_rw_diffs[12] 0.00 0.02 0.00 -0.04 0.04 2206.32 1.00 - Rt_transformed_rw_diffs[13] 0.00 0.03 0.00 -0.04 0.04 1175.93 1.00 - Rt_transformed_rw_diffs[14] -0.00 0.03 -0.00 -0.04 0.04 1606.26 1.00 - Rt_transformed_rw_diffs[15] -0.00 0.03 -0.00 -0.04 0.04 2344.62 1.00 - Rt_transformed_rw_diffs[16] -0.00 0.02 0.00 -0.04 0.04 1522.33 1.00 - Rt_transformed_rw_diffs[17] 0.00 0.03 0.00 -0.04 0.04 2157.17 1.00 - Rt_transformed_rw_diffs[18] -0.00 0.02 -0.00 -0.04 0.04 1594.95 1.00 - Rt_transformed_rw_diffs[19] 0.00 0.03 -0.00 -0.04 0.04 1698.70 1.00 - Rt_transformed_rw_diffs[20] 0.00 0.02 0.00 -0.04 0.04 1726.18 1.00 - Rt_transformed_rw_diffs[21] 0.00 0.02 -0.00 -0.04 0.04 2386.35 1.00 - Rt_transformed_rw_diffs[22] 0.00 0.03 0.00 -0.04 0.04 2028.63 1.00 - Rt_transformed_rw_diffs[23] 0.00 0.02 0.00 -0.04 0.03 1669.71 1.00 - Rt_transformed_rw_diffs[24] 0.00 0.02 0.00 -0.04 0.04 2126.33 1.00 - Rt_transformed_rw_diffs[25] -0.00 0.02 -0.00 -0.04 0.04 2119.74 1.00 - Rt_transformed_rw_diffs[26] 0.00 0.03 0.00 -0.04 0.04 2657.91 1.00 - Rt_transformed_rw_diffs[27] -0.00 0.03 0.00 -0.04 0.04 1939.30 1.00 - Rt_transformed_rw_diffs[28] -0.00 0.02 -0.00 -0.04 0.04 1737.84 1.00 - Rt_transformed_rw_diffs[29] -0.00 0.03 -0.00 -0.04 0.04 2105.55 1.00 + I0 1.26 1.09 0.96 0.09 2.41 1114.64 1.00 + IHR 0.05 0.00 0.05 0.05 0.05 2747.53 1.00 + Rt0 1.23 0.17 1.23 0.95 1.52 1533.77 1.00 + Rt_transformed_rw_diffs[0] 0.00 0.03 -0.00 -0.04 0.04 1574.38 1.00 + Rt_transformed_rw_diffs[1] 0.00 0.03 0.00 -0.04 0.04 2557.65 1.00 + Rt_transformed_rw_diffs[2] 0.00 0.02 0.00 -0.04 0.04 2245.16 1.00 + Rt_transformed_rw_diffs[3] 0.00 0.02 0.00 -0.03 0.04 2423.80 1.00 + Rt_transformed_rw_diffs[4] 0.00 0.02 0.00 -0.03 0.04 2461.65 1.00 + Rt_transformed_rw_diffs[5] 0.00 0.02 0.00 -0.04 0.04 2363.57 1.00 + Rt_transformed_rw_diffs[6] 0.00 0.02 0.00 -0.04 0.04 1720.93 1.00 + Rt_transformed_rw_diffs[7] -0.00 0.02 -0.00 -0.04 0.04 3851.66 1.00 + Rt_transformed_rw_diffs[8] -0.00 0.03 -0.00 -0.04 0.04 1824.74 1.00 + Rt_transformed_rw_diffs[9] 0.00 0.02 0.00 -0.04 0.04 1739.32 1.00 + Rt_transformed_rw_diffs[10] 0.00 0.02 0.00 -0.04 0.04 1944.43 1.00 + Rt_transformed_rw_diffs[11] -0.00 0.03 0.00 -0.04 0.04 1558.14 1.00 + Rt_transformed_rw_diffs[12] 0.00 0.02 0.00 -0.04 0.04 2182.35 1.00 + Rt_transformed_rw_diffs[13] -0.00 0.03 0.00 -0.04 0.04 1175.35 1.00 + Rt_transformed_rw_diffs[14] -0.00 0.03 -0.00 -0.04 0.04 1540.25 1.00 + Rt_transformed_rw_diffs[15] -0.00 0.03 -0.00 -0.04 0.04 2367.82 1.00 + Rt_transformed_rw_diffs[16] -0.00 0.02 -0.00 -0.04 0.04 1636.30 1.00 + Rt_transformed_rw_diffs[17] 0.00 0.03 0.00 -0.04 0.04 1978.96 1.00 + Rt_transformed_rw_diffs[18] 0.00 0.02 -0.00 -0.04 0.04 1589.27 1.00 + Rt_transformed_rw_diffs[19] 0.00 0.03 -0.00 -0.04 0.04 1691.06 1.00 + Rt_transformed_rw_diffs[20] -0.00 0.02 -0.00 -0.04 0.04 2562.99 1.00 + Rt_transformed_rw_diffs[21] 0.00 0.02 -0.00 -0.04 0.04 2352.40 1.00 + Rt_transformed_rw_diffs[22] 0.00 0.03 0.00 -0.04 0.04 1971.40 1.00 + Rt_transformed_rw_diffs[23] 0.00 0.02 0.00 -0.04 0.04 2013.90 1.00 + Rt_transformed_rw_diffs[24] 0.00 0.03 0.00 -0.04 0.04 2022.94 1.00 + Rt_transformed_rw_diffs[25] -0.00 0.02 -0.00 -0.04 0.03 1981.62 1.00 + Rt_transformed_rw_diffs[26] 0.00 0.03 0.00 -0.04 0.05 2696.36 1.00 + Rt_transformed_rw_diffs[27] -0.00 0.03 0.00 -0.04 0.04 2003.38 1.00 + Rt_transformed_rw_diffs[28] -0.00 0.02 -0.00 -0.04 0.04 1843.27 1.00 + Rt_transformed_rw_diffs[29] -0.00 0.03 -0.00 -0.04 0.04 1780.88 1.00 Number of divergences: 0 diff --git a/model/docs/pyrenew_demo.qmd b/model/docs/pyrenew_demo.qmd index 78e8ad12..2cc39404 100644 --- a/model/docs/pyrenew_demo.qmd +++ b/model/docs/pyrenew_demo.qmd @@ -6,7 +6,20 @@ engine: jupyter This demo simulates some basic renewal process data and then fits to it using `pyrenew`. -You'll need to install `pyrenew` first. You'll also need working installations of `matplotlib`, `numpy`, `jax`, `numpyro`, and `polars` +Assuming you've already installed Python and pip, you’ll need to first install `pyrenew`: + +```{python} +pip install pyrenew +``` + +You’ll also need working +installations of `matplotlib`, `numpy`, `jax`, `numpyro`, and `polars`: + +```{python} +pip install matplotlib numpy jax numpyro polars +``` + +To begin, run the following import section to call external modules and functions necessary to run the `pyrenew` demo. The `import` statement imports the module and the `as` statement renames the module for use within this script. The `from` statement imports a specific function from a module (named after the `.`) within a package (named before the `.`). ```{python} #| output: false @@ -25,6 +38,8 @@ import numpyro.distributions as dist from pyrenew.process import SimpleRandomWalkProcess ``` +To understand the simple random walk process underlying the sampling within the renewal process model, we first examine a single random walk path. Using the `sample` method from an instance of the `SimpleRandomWalkProcess` class, we first create an instance of the `SimpleRandomWalkProcess` class with a normal distribution of mean = 0 and standard deviation = 0.0001 as its input. Next, the `with` statement sets the seed for the random number generator for the duration of the block that follows. Inside the `with` block, the `q_samp = q.sample(duration=100)` generates the sample instance over a duration of 100 time units. Finally, this single random walk process is visualized using `matplot.pyplot` to plot the exponential of the sample instance. + ```{python} #| label: fig-randwalk #| fig-cap: Random walk example @@ -36,35 +51,58 @@ with seed(rng_seed=np.random.randint(0,1000)): plt.plot(np.exp(q_samp[0])) ``` +Next, import several additional functions from the `latent` module of the `pyrenew` package to model infections, hospital admissions, initial infections, and hospitalization rate due to infection. + ```{python} from pyrenew.latent import ( Infections, HospitalAdmissions, Infections0, InfectHospRate, ) +``` + +Additionally, import several classes from Pyrenew, including a Poisson observation process, determininstic PMF and variable classes, the Pyrenew hospitalization model, and a renewal modle (Rt) random walk process: +```{python} from pyrenew.observation import PoissonObservation from pyrenew.deterministic import DeterministicPMF, DeterministicVariable from pyrenew.model import HospitalizationsModel from pyrenew.process import RtRandomWalkProcess +``` + +To initialize the model, we first define initial conditions, including: + +1) deterministic generation time, defined as an instance of the `DeterministicPMF` class, which gives the probability of each possible outcome for a discrete random variable given as a JAX NumPy array of four possible outcomes +2) initial infections at the start of simulation as a log-normal distribution with mean = 0 and standard deviation = 1 + +3) latent infections as an instance of the `Infections` class with default settings + +4) latent hospitalization process, modeled by first defining the time interval from infections to hospitalizations as a `DeterministicPMF` input with 18 possible outcomes and corresponding probabilities given by the values in the array. The `HospitalAdmissions` function then takes in this defined time interval, as well as defining the rate at which infections are admitted to the hospital due to infection, modeled as a log-normal distribution with mean = `jnp.log(0.05)` and standard deviation = 0.05. + +5) hospitalization observation process, modeled with a Poisson distribution + +6) an Rt random walk process with default settings + +```{python} # Initializing model components: -# A deterministic generation time +# 1) A deterministic generation time gen_int = DeterministicPMF( (jnp.array([0.25, 0.25, 0.25, 0.25]),), ) -# Initial infections +# 2) Initial infections I0 = Infections0(I0_dist=dist.LogNormal(0, 1)) -# The latent infections process +# 3) The latent infections process latent_infections = Infections() -# A deterministic infection to hosp pmf +# 4) The latent hospitalization process: + +# First, define a deterministic infection to hosp pmf inf_hosp_int = DeterministicPMF( (jnp.array([0, 0, 0,0,0,0,0,0,0,0,0,0,0, 0.25, 0.5, 0.1, 0.1, 0.05]),), ) -# The latent hospitalization process latent_hospitalizations = HospitalAdmissions( infection_to_admission_interval=inf_hosp_int, infect_hosp_rate_dist = InfectHospRate( @@ -72,13 +110,17 @@ latent_hospitalizations = HospitalAdmissions( ), ) -# And observation process for the hospitalizations +# 5) An observation process for the hospitalizations observed_hospitalizations = PoissonObservation() -# And a random walk process (it could be deterministic using +# 6) A random walk process (it could be deterministic using # pyrenew.process.DeterministicProcess()) Rt_process = RtRandomWalkProcess() +``` +The `HospitalizationsModel` is then initialized using the initial conditions just defined: + +```{python} # Initializing the model hospmodel = HospitalizationsModel( gen_int=gen_int, @@ -90,6 +132,7 @@ hospmodel = HospitalizationsModel( ) ``` +Next, we sample from the `hospmodel` for 30 time steps and view the output of a single run: ```{python} with seed(rng_seed=np.random.randint(1, 60)): @@ -97,6 +140,8 @@ with seed(rng_seed=np.random.randint(1, 60)): x ``` +Visualizations of the single model output show (top) infections over the 30 time steps, (middle) hospitalizations over the 30 time steps, and (bottom) + ```{python} #| label: fig-hosp #| fig-cap: Infections @@ -109,6 +154,8 @@ for axis in ax[:-1]: axis.set_yscale("log") ``` +To fit the `hospmodel` to the simulated data, we call `hospmodel.run()`, an MCMC algorithm, with the arguments generated in `hospmodel` object, using 1000 warmup stepts and 1000 samples to draw from the posterior distribution of the model parameters. The model is run for `len(x.sampled)-1` time steps with the seed set by `jax.random.PRNGKey()` + ```{python} # from numpyro.infer import MCMC, NUTS hospmodel.run( @@ -121,15 +168,21 @@ hospmodel.run( ) ``` +Print a summary of the model: + ```{python} hospmodel.print_summary() ``` +Next, we will use the `spread_draws` function from the `pyrenew.mcmcutils` module to process the MCMC samples. The `spread_draws` function reformats the samples drawn from the `mcmc.get_samples()` from the `hospmodel`. The samples are simulated Rt values over time. + ```{python} from pyrenew.mcmcutils import spread_draws samps = spread_draws(hospmodel.mcmc.get_samples(), [("Rt", "time")]) ``` +We visualize these samples below, with individual possible Rt estimates over time shown in light blue, and the overall mean estimate Rt shown in dark blue. + ```{python} #| label: fig-sampled-rt #| fig-cap: Posterior Rt diff --git a/model/docs/pyrenew_demo_files/figure-commonmark/fig-randwalk-output-1.png b/model/docs/pyrenew_demo_files/figure-commonmark/fig-randwalk-output-1.png index e3a6cc17..9a794124 100644 Binary files a/model/docs/pyrenew_demo_files/figure-commonmark/fig-randwalk-output-1.png and b/model/docs/pyrenew_demo_files/figure-commonmark/fig-randwalk-output-1.png differ diff --git a/model/docs/pyrenew_demo_files/figure-commonmark/fig-sampled-rt-output-1.png b/model/docs/pyrenew_demo_files/figure-commonmark/fig-sampled-rt-output-1.png index 7f94e989..e31e0699 100644 Binary files a/model/docs/pyrenew_demo_files/figure-commonmark/fig-sampled-rt-output-1.png and b/model/docs/pyrenew_demo_files/figure-commonmark/fig-sampled-rt-output-1.png differ