diff --git a/docs/source/tutorials/pyrenew_demo.qmd b/docs/source/tutorials/pyrenew_demo.qmd deleted file mode 100644 index 598308c8..00000000 --- a/docs/source/tutorials/pyrenew_demo.qmd +++ /dev/null @@ -1,235 +0,0 @@ ---- -title: Pyrenew demo -format: gfm -engine: jupyter ---- - -This demo simulates a basic renewal process data and then fits it using `pyrenew`. - -You’ll need to install `pyrenew` using either poetry or pip. To install `pyrenew` using poetry, run the following command from within the directory containing the `pyrenew` project: - -```bash -poetry install -``` - -To install `pyrenew` using pip, run the following command: - -```bash -pip install git+https://github.com/CDCgov/multisignal-epi-inference@main#subdirectory=model -``` - -To begin, run the following import section to call external modules and functions necessary to run the `pyrenew` demo. The `import` statement imports the module and the `as` statement renames the module for use within this script. The `from` statement imports a specific function from a module (named after the `.`) within a package (named before the `.`). - -```{python} -# | output: false -# | label: loading-pkgs -# | warning: false -import matplotlib as mpl -import matplotlib.pyplot as plt -import jax -import jax.numpy as jnp -import numpy as np -import numpyro -from numpyro.handlers import seed -import numpyro.distributions as dist - -numpyro.set_host_device_count(2) -``` - -```{python} -from pyrenew.process import SimpleRandomWalkProcess -from pyrenew.metaclass import DistributionalRV -``` - -To understand the simple random walk process underlying the sampling within the renewal process model, we first examine a single random walk path. Using the `sample` method from an instance of the `SimpleRandomWalkProcess` class, we first create an instance of the `SimpleRandomWalkProcess` class with a normal distribution of mean = 0 and standard deviation = 0.0001 as its input. Next, the `with` statement sets the seed for the random number generator for the block that follows. Inside the `with` block, the `q_samp = q(n_steps=100)` generates the sample instance over a `n_steps` period of 100 time units. Finally, this single random walk process is visualized using `matplot.pyplot` to plot the exponential of the sample instance. - -```{python} -# | label: fig-randwalk -# | fig-cap: Random walk example -q = SimpleRandomWalkProcess( - "example_random_walk", - step_rv=DistributionalRV(dist.Normal(0, 0.001), "step_rv"), - init_rv=DistributionalRV(dist.Normal(0, 0.001), "init_rv"), -) - -with seed(rng_seed=325): - q_samp = q(n_steps=100) - -plt.plot(np.exp(q_samp[0])) -``` - -Next, import several additional functions from the `latent` module of the `pyrenew` package to model infections and hospital admissions. - -```{python} -from pyrenew.latent import ( - Infections, - HospitalAdmissions, -) -``` - -Additionally, import several classes from Pyrenew, including a Poisson observation process, determininstic PMF and variable classes, the Pyrenew hospitalization model, and a renewal model (Rt) random walk process: - -```{python} -from pyrenew.observation import PoissonObservation -from pyrenew.deterministic import DeterministicPMF, DeterministicVariable -from pyrenew.model import HospitalAdmissionsModel -from pyrenew.latent import ( - InfectionInitializationProcess, - InitializeInfectionsZeroPad, -) -from pyrenew.metaclass import TransformedRandomVariable -import pyrenew.transformation as t -``` - -To initialize the model, we first define initial conditions, including: - -1) deterministic generation time, defined as an instance of the `DeterministicPMF` class, which gives the probability of each possible outcome for a discrete random variable given as a JAX NumPy array of four possible outcomes - -2) initial infections at the start of the renewal process as a log-normal distribution with mean = 0 and standard deviation = 1. Infections before this time are assumed to be 0. - -3) latent infections as an instance of the `Infections` class with default settings - -4) latent hospitalization process, modeled by first defining the time interval from infections to hospitalizations as a `DeterministicPMF` input with 18 possible outcomes and corresponding probabilities given by the values in the array. The `HospitalAdmissions` function then takes in this defined time interval, as well as defining the rate at which infections are admitted to the hospital due to infection, modeled as a log-normal distribution with mean = `jnp.log(0.05)` and standard deviation = 0.05. - -5) hospitalization observation process, modeled with a Poisson distribution - -6) A process to represent $\mathcal{R}(t)$ as a random walk on the log scale, with an inferred initial value and a fixed Normal step-size distribution. - - -```{python} -# Initializing model components: - -# 1) A deterministic generation time -pmf_array = jnp.array([0.25, 0.25, 0.25, 0.25]) -gen_int = DeterministicPMF(pmf_array, name="gen_int") - -# 2) Initial infections -I0 = InfectionInitializationProcess( - "I0_initialization", - DistributionalRV(dist=dist.LogNormal(0, 1), name="I0"), - InitializeInfectionsZeroPad(pmf_array.size), - t_unit=1, -) - -# 3) The latent infections process -latent_infections = Infections() - -# 4) The latent hospitalization process: - -# First, define a deterministic infection to hosp pmf -inf_hosp_int = DeterministicPMF( - jnp.array( - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.25, 0.5, 0.1, 0.1, 0.05] - ), - name="inf_hosp_int", -) - -latent_admissions = HospitalAdmissions( - infection_to_admission_interval_rv=inf_hosp_int, - infect_hosp_rate_rv=DistributionalRV( - dist=dist.LogNormal(jnp.log(0.05), 0.05), name="IHR" - ), -) - -# 5) An observation process for the hospital admissions -admissions_process = PoissonObservation("poisson_rv") - -# 6) The random walk on log Rt -Rt_process = TransformedRandomVariable( - "Rt_rv", - base_rv=SimpleRandomWalkProcess( - name="log_rt", - step_rv=DistributionalRV(dist.Normal(0, 0.025), "rw_step_rv"), - init_rv=DistributionalRV(dist.Normal(0, 0.2), "init_log_Rt_rv"), - ), - transforms=t.ExpTransform(), -) -``` - -The `HospitalAdmissionsModel` is then initialized using the initial conditions just defined: - -```{python} -# Initializing the model -hospmodel = HospitalAdmissionsModel( - gen_int_rv=gen_int, - I0_rv=I0, - latent_hosp_admissions_rv=latent_admissions, - hosp_admission_obs_process_rv=admissions_process, - latent_infections_rv=latent_infections, - Rt_process_rv=Rt_process, -) -``` - -Next, we sample from the `hospmodel` for 30 time steps and view the output of a single run: - -```{python} -with seed(rng_seed=np.random.randint(1, 60)): - x = hospmodel.sample(n_timepoints_to_simulate=30) -x -``` - -Visualizations of the single model output show (top) infections over the 30 time steps, (middle) hospital admissions over the 30 time steps, and observed hospital admissions (bottom) - -```{python} -# | label: fig-hosp -# | fig-cap: Infections -fig, ax = plt.subplots(nrows=3, sharex=True) -ax[0].plot(x.latent_infections) -ax[0].set_ylim([1 / 5, 5]) -ax[1].plot(x.latent_hosp_admissions) -ax[2].plot(x.observed_hosp_admissions, "o") -for axis in ax[:-1]: - axis.set_yscale("log") -``` - -To fit the `hospmodel` to the simulated data, we call `hospmodel.run()`, an MCMC algorithm, with the arguments generated in `hospmodel` object, using 1000 warmup stepts and 1000 samples to draw from the posterior distribution of the model parameters. The model is run for `len(x.sampled)-1` time steps with the seed set by `jax.random.PRNGKey()` - -```{python} -# from numpyro.infer import MCMC, NUTS -hospmodel.run( - num_warmup=1000, - num_samples=1000, - data_observed_hosp_admissions=x.observed_hosp_admissions, - rng_key=jax.random.PRNGKey(54), - mcmc_args=dict(progress_bar=False, num_chains=2), -) -``` - -Print a summary of the model: - -```{python} -hospmodel.print_summary() -``` - -Next, we will use the `spread_draws` function from the `pyrenew.mcmcutils` module to process the MCMC samples. The `spread_draws` function reformats the samples drawn from the `mcmc.get_samples()` from the `hospmodel`. The samples are simulated Rt values over time. - -```{python} -from pyrenew.mcmcutils import spread_draws - -samps = spread_draws(hospmodel.mcmc.get_samples(), [("Rt", "time")]) -``` - -We visualize these samples below, with individual possible Rt estimates over time shown in light blue, and the overall mean estimate Rt shown in dark blue. - -```{python} -# | label: fig-sampled-rt -# | fig-cap: Posterior Rt -import numpy as np -import polars as pl - -fig, ax = plt.subplots(figsize=[4, 5]) - -ax.plot(x[0]) -samp_ids = np.random.randint(size=25, low=0, high=999) -for samp_id in samp_ids: - sub_samps = samps.filter(pl.col("draw") == samp_id).sort(pl.col("time")) - ax.plot( - sub_samps.select("time").to_numpy(), - sub_samps.select("Rt").to_numpy(), - color="darkblue", - alpha=0.1, - ) -ax.set_ylim([0.4, 1 / 0.4]) -ax.set_yticks([0.5, 1, 2]) -ax.set_yscale("log") -``` diff --git a/docs/source/tutorials/pyrenew_demo.rst b/docs/source/tutorials/pyrenew_demo.rst deleted file mode 100644 index accf1a3a..00000000 --- a/docs/source/tutorials/pyrenew_demo.rst +++ /dev/null @@ -1,5 +0,0 @@ -.. WARNING -.. Please do not edit this file directly. -.. This file is just a placeholder. -.. For the source file, see: -..